ferron_common/util/
ip_blocklist.rs

1use std::collections::HashSet;
2use std::net::{IpAddr, Ipv6Addr};
3
4use cidr::IpCidr;
5
6/// The IP blocklist
7#[derive(Clone, Debug, PartialEq, Eq)]
8pub struct IpBlockList {
9  blocked_ips: HashSet<IpAddr>,
10  blocked_cidrs: HashSet<IpCidr>,
11}
12
13impl Default for IpBlockList {
14  fn default() -> Self {
15    Self::new()
16  }
17}
18
19impl IpBlockList {
20  /// Creates a new empty block list
21  pub fn new() -> Self {
22    Self {
23      blocked_ips: HashSet::new(),
24      blocked_cidrs: HashSet::new(),
25    }
26  }
27
28  /// Loads the block list from a vector of IP address strings
29  pub fn load_from_vec(&mut self, ip_list: Vec<&str>) {
30    for ip_str in ip_list {
31      match ip_str {
32        "localhost" => {
33          self
34            .blocked_ips
35            .insert(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)));
36        }
37        _ => {
38          if let Ok(ip) = ip_str.parse::<IpAddr>() {
39            self.blocked_ips.insert(ip.to_canonical());
40          } else if let Ok(ip_cidr) = ip_str.parse::<IpCidr>() {
41            self.blocked_cidrs.insert(ip_cidr);
42          }
43        }
44      }
45    }
46  }
47
48  /// Checks if an IP address is blocked
49  pub fn is_blocked(&self, ip: IpAddr) -> bool {
50    self.blocked_ips.contains(&ip.to_canonical())
51      || self.blocked_cidrs.iter().any(|cidr| cidr.contains(&ip.to_canonical()))
52  }
53}
54
#[cfg(test)]
mod tests {
  use super::*;

  /// Builds a block list populated from the given entries.
  fn block_list_from(entries: &[&str]) -> IpBlockList {
    let mut list = IpBlockList::new();
    list.load_from_vec(entries.to_vec());
    list
  }

  /// Reports whether the textual address `text` is blocked by `list`.
  fn blocks(list: &IpBlockList, text: &str) -> bool {
    list.is_blocked(text.parse().unwrap())
  }

  /// Individually listed addresses are blocked; unrelated ones are not.
  #[test]
  fn test_ip_block_list() {
    let list = block_list_from(&["192.168.1.1", "10.0.0.1"]);

    assert!(blocks(&list, "192.168.1.1"));
    assert!(blocks(&list, "10.0.0.1"));
    assert!(!blocks(&list, "8.8.8.8"));
  }

  /// Addresses inside a listed CIDR range are blocked; others are not.
  #[test]
  fn test_ip_cidr_block_list() {
    let list = block_list_from(&["192.168.1.0/24", "10.0.0.0/8"]);

    assert!(blocks(&list, "192.168.1.1"));
    assert!(blocks(&list, "10.0.0.1"));
    assert!(!blocks(&list, "8.8.8.8"));
  }
}