ferron_common/util/
ip_blocklist.rs1use std::cmp::Ordering;
2use std::collections::HashSet;
3use std::net::{IpAddr, Ipv6Addr};
4
5use cidr::IpCidr;
6
7#[derive(Clone, Debug, PartialEq, Eq)]
9pub struct IpBlockList {
10 blocked_ips: HashSet<IpAddr>,
11 blocked_cidrs: HashSet<IpCidr>,
12}
13
14impl Default for IpBlockList {
15 fn default() -> Self {
16 Self::new()
17 }
18}
19
20impl Ord for IpBlockList {
21 fn cmp(&self, other: &Self) -> Ordering {
22 self
23 .blocked_ips
24 .iter()
25 .cmp(other.blocked_ips.iter())
26 .then(self.blocked_cidrs.iter().cmp(other.blocked_cidrs.iter()))
27 }
28}
29
30impl PartialOrd for IpBlockList {
31 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
32 Some(self.cmp(other))
33 }
34}
35
36impl IpBlockList {
37 pub fn new() -> Self {
39 Self {
40 blocked_ips: HashSet::new(),
41 blocked_cidrs: HashSet::new(),
42 }
43 }
44
45 pub fn load_from_vec(&mut self, ip_list: Vec<&str>) {
47 for ip_str in ip_list {
48 match ip_str {
49 "localhost" => {
50 self
51 .blocked_ips
52 .insert(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)));
53 }
54 _ => {
55 if let Ok(ip) = ip_str.parse::<IpAddr>() {
56 self.blocked_ips.insert(ip.to_canonical());
57 } else if let Ok(ip_cidr) = ip_str.parse::<IpCidr>() {
58 self.blocked_cidrs.insert(ip_cidr);
59 }
60 }
61 }
62 }
63 }
64
65 pub fn is_blocked(&self, ip: IpAddr) -> bool {
67 self.blocked_ips.contains(&ip.to_canonical())
68 || self.blocked_cidrs.iter().any(|cidr| cidr.contains(&ip.to_canonical()))
69 }
70}
71
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads `entries` into a fresh block list, then checks that `is_blocked` returns the
    /// expected value for each `(address, expected)` pair.
    fn assert_blocking(entries: Vec<&str>, cases: &[(&str, bool)]) {
        let mut block_list = IpBlockList::new();
        block_list.load_from_vec(entries);
        for &(address, expected) in cases {
            assert_eq!(block_list.is_blocked(address.parse().unwrap()), expected);
        }
    }

    #[test]
    fn test_ip_block_list() {
        assert_blocking(
            vec!["192.168.1.1", "10.0.0.1"],
            &[("192.168.1.1", true), ("10.0.0.1", true), ("8.8.8.8", false)],
        );
    }

    #[test]
    fn test_ip_cidr_block_list() {
        assert_blocking(
            vec!["192.168.1.0/24", "10.0.0.0/8"],
            &[("192.168.1.1", true), ("10.0.0.1", true), ("8.8.8.8", false)],
        );
    }
}