// ferron_common/util/ip_blocklist.rs

use std::cmp::Ordering;
use std::collections::HashSet;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

use cidr::IpCidr;
6
/// A set of IP addresses and CIDR networks that are blocked.
///
/// Populate it with [`IpBlockList::load_from_vec`] and query it with
/// [`IpBlockList::is_blocked`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct IpBlockList {
  /// Individually blocked addresses. `load_from_vec` stores them in canonical form
  /// (IPv4-mapped IPv6 addresses are kept as plain IPv4).
  blocked_ips: HashSet<IpAddr>,
  /// Blocked networks; any address contained in one of them is considered blocked.
  blocked_cidrs: HashSet<IpCidr>,
}
13
14impl Default for IpBlockList {
15  fn default() -> Self {
16    Self::new()
17  }
18}
19
20impl Ord for IpBlockList {
21  fn cmp(&self, other: &Self) -> Ordering {
22    self
23      .blocked_ips
24      .iter()
25      .cmp(other.blocked_ips.iter())
26      .then(self.blocked_cidrs.iter().cmp(other.blocked_cidrs.iter()))
27  }
28}
29
30impl PartialOrd for IpBlockList {
31  fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
32    Some(self.cmp(other))
33  }
34}
35
36impl IpBlockList {
37  /// Creates a new empty block list
38  pub fn new() -> Self {
39    Self {
40      blocked_ips: HashSet::new(),
41      blocked_cidrs: HashSet::new(),
42    }
43  }
44
45  /// Loads the block list from a vector of IP address strings
46  pub fn load_from_vec(&mut self, ip_list: Vec<&str>) {
47    for ip_str in ip_list {
48      match ip_str {
49        "localhost" => {
50          self
51            .blocked_ips
52            .insert(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)));
53        }
54        _ => {
55          if let Ok(ip) = ip_str.parse::<IpAddr>() {
56            self.blocked_ips.insert(ip.to_canonical());
57          } else if let Ok(ip_cidr) = ip_str.parse::<IpCidr>() {
58            self.blocked_cidrs.insert(ip_cidr);
59          }
60        }
61      }
62    }
63  }
64
65  /// Checks if an IP address is blocked
66  pub fn is_blocked(&self, ip: IpAddr) -> bool {
67    self.blocked_ips.contains(&ip.to_canonical())
68      || self.blocked_cidrs.iter().any(|cidr| cidr.contains(&ip.to_canonical()))
69  }
70}
71
#[cfg(test)]
mod tests {
  use super::*;

  #[test]
  fn test_ip_block_list() {
    let mut block_list = IpBlockList::new();
    block_list.load_from_vec(vec!["192.168.1.1", "10.0.0.1"]);

    assert!(block_list.is_blocked("192.168.1.1".parse().unwrap()));
    assert!(block_list.is_blocked("10.0.0.1".parse().unwrap()));
    assert!(!block_list.is_blocked("8.8.8.8".parse().unwrap()));
  }

  #[test]
  fn test_ip_cidr_block_list() {
    let mut block_list = IpBlockList::new();
    block_list.load_from_vec(vec!["192.168.1.0/24", "10.0.0.0/8"]);

    assert!(block_list.is_blocked("192.168.1.1".parse().unwrap()));
    assert!(block_list.is_blocked("10.0.0.1".parse().unwrap()));
    assert!(!block_list.is_blocked("8.8.8.8".parse().unwrap()));
  }

  #[test]
  fn test_ipv6_cidr_block_list() {
    let mut block_list = IpBlockList::new();
    block_list.load_from_vec(vec!["2001:db8::/32"]);

    assert!(block_list.is_blocked("2001:db8::1".parse().unwrap()));
    assert!(!block_list.is_blocked("2001:db9::1".parse().unwrap()));
  }

  #[test]
  fn test_ipv4_mapped_addresses_are_canonicalised() {
    let mut block_list = IpBlockList::new();
    block_list.load_from_vec(vec!["192.168.1.1", "::ffff:10.0.0.1", "172.16.0.0/12"]);

    // Mapped client address vs. plain IPv4 entry.
    assert!(block_list.is_blocked("::ffff:192.168.1.1".parse().unwrap()));
    // Plain IPv4 client address vs. mapped entry.
    assert!(block_list.is_blocked("10.0.0.1".parse().unwrap()));
    // Mapped client address vs. IPv4 CIDR entry.
    assert!(block_list.is_blocked("::ffff:172.16.5.4".parse().unwrap()));
    assert!(!block_list.is_blocked("::ffff:8.8.8.8".parse().unwrap()));
  }

  #[test]
  fn test_localhost_blocks_ipv6_loopback() {
    let mut block_list = IpBlockList::new();
    block_list.load_from_vec(vec!["localhost"]);

    assert!(block_list.is_blocked("::1".parse().unwrap()));
  }

  #[test]
  fn test_unparseable_entries_are_ignored() {
    let mut block_list = IpBlockList::new();
    block_list.load_from_vec(vec!["not-an-ip", ""]);

    assert_eq!(block_list, IpBlockList::new());
  }
}