1use std::collections::{BTreeMap, VecDeque};
2
3use ferron_common::config::Conditional;
4
5use crate::config::lookup::conditionals::{match_conditional, ConditionMatchData};
6
/// A single element of a configuration filter key.
///
/// A full key is a sequence of these elements (for example a port, followed by host
/// domain levels, followed by location segments). The derived ordering follows the
/// declaration order of the variants, and the tree's `BTreeMap`s rely on it, so the
/// variants must not be reordered casually.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum ConfigFilterTreeSingleKey {
  /// Marker element; it is not interpreted specially in this file.
  /// NOTE(review): presumably marks host-level configuration - confirm against callers.
  IsHostConfiguration,
  /// A port number.
  Port(u16),
  /// One `u8` component of an address.
  /// NOTE(review): presumably one octet of an IPv4 address - confirm against callers.
  IPv4Octet(u8),
  /// One `u8` component of an address.
  /// NOTE(review): presumably one octet of an IPv6 address - confirm against callers.
  IPv6Octet(u8),
  /// Marker element; it is not interpreted specially in this file.
  /// NOTE(review): presumably marks a localhost-only configuration - confirm against callers.
  IsLocalhost,
  /// One label of a host name. The tests in this file order labels from the top-level
  /// domain inwards (`com`, then `example`, then `www`).
  HostDomainLevel(String),
  /// Predicate key: during lookup it matches a `HostDomainLevel` element and consumes
  /// every `HostDomainLevel` element that directly follows it.
  HostDomainLevelWildcard,
  /// One segment of a location. The tests in this file use an empty first segment
  /// followed by named segments.
  /// NOTE(review): presumably URL path segments with `""` standing for the root - confirm.
  LocationSegment(String),
  /// Predicate key: a conditional that is evaluated with `match_conditional` during lookup.
  Conditional(Conditional),
}
31
32impl ConfigFilterTreeSingleKey {
33 pub fn is_predicate(&self) -> bool {
35 matches!(self, Self::HostDomainLevelWildcard | Self::Conditional(_))
36 }
37}
38
/// A sequence of single keys used as the key type of a node's `children_fixed` map.
///
/// `insert_node` only ever stores runs of non-predicate keys in it, so one map entry
/// can stand for several consecutive fixed key elements.
///
/// `PartialEq`/`Eq` are derived (element-wise equality), while the ordering is
/// implemented by hand below and is deliberately looser than equality.
#[derive(Clone, Debug, Eq, PartialEq)]
struct ConfigFilterTreeMultiKey(Vec<ConfigFilterTreeSingleKey>);
42
43#[allow(clippy::non_canonical_partial_ord_impl)]
44impl PartialOrd for ConfigFilterTreeMultiKey {
45 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
46 for i in 0..self.0.len().max(other.0.len()) {
47 let self_element = self.0.get(i);
48 let other_element = other.0.get(i);
49 match (self_element, other_element) {
50 (Some(a), Some(b)) => {
51 let cmp = a.cmp(b);
52 if cmp != std::cmp::Ordering::Equal {
53 return Some(cmp);
54 }
55 }
56 _ => return None,
57 }
58 }
59 Some(std::cmp::Ordering::Equal)
60 }
61}
62
63impl Ord for ConfigFilterTreeMultiKey {
64 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
65 self.partial_cmp(other).unwrap_or(std::cmp::Ordering::Equal)
72 }
73}
74
/// A node of the configuration filter tree.
#[derive(Debug)]
struct ConfigFilterTreeNode<T> {
  /// Value stored at this node, if any. Nodes created only as intermediate steps of a
  /// longer key have no value.
  pub value: Option<T>,
  /// Children reached through runs of fixed (non-predicate) key elements.
  pub children_fixed: BTreeMap<ConfigFilterTreeMultiKey, ConfigFilterTreeNode<T>>,
  /// Children reached through a single predicate key element
  /// (see `ConfigFilterTreeSingleKey::is_predicate`).
  pub children_predicate: BTreeMap<ConfigFilterTreeSingleKey, ConfigFilterTreeNode<T>>,
}
85
/// A tree that maps sequences of `ConfigFilterTreeSingleKey` to values of type `T`.
///
/// Values are added with `insert`/`insert_node` and looked up with `get`.
#[derive(Debug)]
pub struct ConfigFilterTree<T> {
  /// Root node; it corresponds to the empty key.
  root: ConfigFilterTreeNode<T>,
}
94
95impl<T> ConfigFilterTree<T> {
96 pub fn new() -> Self {
98 Self {
99 root: ConfigFilterTreeNode {
100 value: None,
101 children_fixed: BTreeMap::new(),
102 children_predicate: BTreeMap::new(),
103 },
104 }
105 }
106
107 #[allow(dead_code)]
109 pub fn insert(&mut self, key: Vec<ConfigFilterTreeSingleKey>, value: T) {
110 self.insert_node(key).replace(value);
111 }
112
113 pub fn insert_node(&mut self, key: Vec<ConfigFilterTreeSingleKey>) -> &mut Option<T> {
115 let mut current_node = &mut self.root;
116 let mut key_iter = key.into_iter();
117 let mut key_option = key_iter.next();
118 while let Some(key) = key_option.take() {
119 if key.is_predicate() {
120 match current_node.children_predicate.entry(key) {
122 std::collections::btree_map::Entry::Occupied(entry) => {
123 current_node = entry.into_mut();
124 }
125 std::collections::btree_map::Entry::Vacant(entry) => {
126 current_node = entry.insert(ConfigFilterTreeNode {
127 value: None,
128 children_fixed: BTreeMap::new(),
129 children_predicate: BTreeMap::new(),
130 });
131 }
132 }
133 key_option = key_iter.next();
134 } else {
135 let mut multi_key = ConfigFilterTreeMultiKey(vec![key]);
137 match current_node.children_fixed.entry(multi_key) {
138 std::collections::btree_map::Entry::Occupied(mut entry) => {
139 let entry_key = entry.key();
140 for i in 1..=entry_key.0.len() {
141 if i == entry_key.0.len() {
142 key_option = key_iter.next();
145 current_node = unsafe {
149 std::mem::transmute::<&mut ConfigFilterTreeNode<T>, &mut ConfigFilterTreeNode<T>>(entry.get_mut())
150 };
151 break;
152 }
153 key_option = key_iter.next();
154 let mut break_multi_key = false;
155 if let Some(key) = &key_option {
156 if key != &entry_key.0[i] {
157 break_multi_key = true;
159 }
160 } else {
161 break_multi_key = true;
163 }
164 if break_multi_key {
165 let (mut entry_key, entry_value) = entry.remove_entry();
167 let entry_key_right = ConfigFilterTreeMultiKey(entry_key.0.split_off(i));
168 #[allow(clippy::mutable_key_type)]
169 let mut new_children_fixed = BTreeMap::new();
170 new_children_fixed.insert(entry_key_right, entry_value);
171 match current_node.children_fixed.entry(entry_key) {
172 std::collections::btree_map::Entry::Occupied(entry) => {
173 current_node = entry.into_mut();
174 }
175 std::collections::btree_map::Entry::Vacant(entry) => {
176 current_node = entry.insert(ConfigFilterTreeNode {
177 value: None,
178 children_fixed: new_children_fixed,
179 children_predicate: BTreeMap::new(),
180 });
181 }
182 }
183 break;
184 }
185 }
186 }
187 std::collections::btree_map::Entry::Vacant(entry) => {
188 multi_key = entry.into_key();
189
190 key_option = key_iter.next();
191 while let Some(key) = &key_option {
192 if !key.is_predicate() {
193 let key = key_option.take().expect("key_option should be Some here");
194 multi_key.0.push(key);
195 key_option = key_iter.next();
196 } else {
197 break;
198 }
199 }
200
201 match current_node.children_fixed.entry(multi_key) {
202 std::collections::btree_map::Entry::Occupied(entry) => {
203 current_node = entry.into_mut();
204 }
205 std::collections::btree_map::Entry::Vacant(entry) => {
206 current_node = entry.insert(ConfigFilterTreeNode {
207 value: None,
208 children_fixed: BTreeMap::new(),
209 children_predicate: BTreeMap::new(),
210 });
211 }
212 };
213 }
214 }
215 }
216 }
217
218 &mut current_node.value
219 }
220
221 pub fn get<'a, 'b>(
223 &'a self,
224 key: Vec<ConfigFilterTreeSingleKey>,
225 condition_match_data: Option<ConditionMatchData<'b>>,
226 ) -> Result<Option<&'a T>, Box<dyn std::error::Error + Send + Sync>> {
227 let mut current_nodes = VecDeque::new();
228 current_nodes.push_back((&self.root, ConfigFilterTreeMultiKey(key)));
229 let mut value = current_nodes[0].0.value.as_ref();
230 while let Some((mut current_node, mut key)) = current_nodes.pop_front() {
231 while !key.0.is_empty() {
232 let key_end = key.0.split_off(1);
233 let mut partial_key = ConfigFilterTreeMultiKey(key.0);
234 key.0 = key_end;
235 if let Some((child_key, child)) = current_node.children_fixed.get_key_value(&partial_key) {
236 current_node = child;
238 let mut index = 0;
239 let mut secondary_index = 1;
240 let mut is_matching = true;
241 while let (Some(key_single), Some(child_key_single)) = (key.0.get(index), child_key.0.get(secondary_index)) {
242 if std::mem::discriminant(key_single) == std::mem::discriminant(child_key_single) {
243 if key_single != child_key_single {
245 is_matching = false;
247 break;
248 } else {
249 secondary_index += 1;
251 }
252 }
253 index += 1;
254 }
255 if !is_matching || (index >= key.0.len() && secondary_index < child_key.0.len()) {
256 break;
258 }
259 key.0 = key.0.split_off(index);
260 if current_node.value.is_some() {
261 value = current_node.value.as_ref();
262 }
263 continue;
264 }
265
266 let partial_key = partial_key.0.remove(0);
267 if partial_key.is_predicate() {
268 if let Some(child) = current_node.children_predicate.get(&partial_key) {
270 current_node = child;
272 if current_node.value.is_some() {
273 value = current_node.value.as_ref();
274 }
275 continue;
276 }
277 }
278
279 let mut conditional_predicate_matched = false;
281 let mut wildcard_matched = false;
282 for (predicate_key, child) in ¤t_node.children_predicate {
283 if predicate_key.is_predicate() {
284 if predicate_key == &ConfigFilterTreeSingleKey::HostDomainLevelWildcard
286 && matches!(partial_key, ConfigFilterTreeSingleKey::HostDomainLevel(_))
287 {
288 current_node = child;
290 let mut index = 0;
291 while let Some(ConfigFilterTreeSingleKey::HostDomainLevel(_)) = key.0.get(index) {
292 index += 1;
293 }
294 key.0 = key.0.split_off(index);
295 if current_node.value.is_some() {
296 value = current_node.value.as_ref();
297 }
298 wildcard_matched = predicate_key == &ConfigFilterTreeSingleKey::HostDomainLevelWildcard;
299 break;
304 } else if let ConfigFilterTreeSingleKey::Conditional(conditional) = predicate_key {
305 if let Some(condition_match_data) = condition_match_data.as_ref() {
307 if match_conditional(conditional, condition_match_data)? {
308 let current_node = child;
309 if current_node.value.is_some() {
310 value = current_node.value.as_ref();
311 }
312 current_nodes.push_back((current_node, key.clone()));
315 conditional_predicate_matched = true;
316 }
317 }
318 }
319 }
320 }
321 if conditional_predicate_matched {
322 break;
323 }
324 let have_path_end =
325 matches!(key.0.first(), Some(&ConfigFilterTreeSingleKey::LocationSegment(_))) && key.0.get(1).is_none();
326
327 if !wildcard_matched && have_path_end {
328 break;
331 }
332
333 }
335 }
336
337 Ok(value)
338 }
339}
340
#[cfg(test)]
mod tests {
  use super::*;

  type Key = ConfigFilterTreeSingleKey;

  /// Builds a `Port` key element.
  fn port(number: u16) -> Key {
    Key::Port(number)
  }

  /// Builds an `IPv4Octet` key element.
  fn octet(value: u8) -> Key {
    Key::IPv4Octet(value)
  }

  /// Builds a `HostDomainLevel` key element.
  fn host(label: &str) -> Key {
    Key::HostDomainLevel(label.to_string())
  }

  /// Builds the `HostDomainLevelWildcard` key element.
  fn any_host() -> Key {
    Key::HostDomainLevelWildcard
  }

  /// Builds a `LocationSegment` key element.
  fn segment(name: &str) -> Key {
    Key::LocationSegment(name.to_string())
  }

  /// Lookup key for port 80, address octets 127.0.0.1, host levels `com` and `example`,
  /// followed by the given location segments.
  fn loopback_example_key(segments: &[&str]) -> Vec<Key> {
    let mut key = vec![
      port(80),
      octet(127),
      octet(0),
      octet(0),
      octet(1),
      host("com"),
      host("example"),
    ];
    key.extend(segments.iter().copied().map(segment));
    key
  }

  /// Performs a lookup without condition match data and unwraps the result.
  fn lookup<'a>(tree: &'a ConfigFilterTree<&'static str>, key: Vec<Key>) -> Option<&'a &'static str> {
    tree.get(key, None).unwrap()
  }

  #[test]
  fn test_basic_get() {
    let mut tree = ConfigFilterTree::new();
    tree.insert(vec![port(80), host("com"), host("example"), any_host()], "Example");
    tree.insert(vec![port(80), host("com"), host("example2"), host("www")], "Example 2");

    assert_eq!(
      lookup(&tree, vec![port(80), host("com"), host("example"), host("www")]),
      Some(&"Example")
    );
    assert_eq!(
      lookup(&tree, vec![port(80), host("com"), host("example"), host("subsite")]),
      Some(&"Example")
    );
    assert_eq!(
      lookup(&tree, vec![port(80), host("com"), host("example2"), host("www")]),
      Some(&"Example 2")
    );
    assert_eq!(
      lookup(&tree, vec![port(80), host("com"), host("example3"), host("www")]),
      None
    );
    assert_eq!(lookup(&tree, vec![port(80), host("com"), host("example")]), None);
  }

  #[test]
  fn test_empty_tree() {
    let tree: ConfigFilterTree<&str> = ConfigFilterTree::new();
    assert_eq!(lookup(&tree, vec![]), None);
  }

  #[test]
  fn test_insert_and_get_single_key() {
    let mut tree = ConfigFilterTree::new();
    tree.insert(vec![port(80)], "Port 80");
    assert_eq!(lookup(&tree, vec![port(80)]), Some(&"Port 80"));
  }

  #[test]
  fn test_insert_and_get_multi_key() {
    let mut tree = ConfigFilterTree::new();
    tree.insert(vec![port(80), host("com")], "Port 80, com");
    assert_eq!(lookup(&tree, vec![port(80), host("com")]), Some(&"Port 80, com"));
  }

  #[test]
  fn test_wildcard_matching() {
    let mut tree = ConfigFilterTree::new();
    tree.insert(vec![port(80), any_host()], "Wildcard");
    assert_eq!(lookup(&tree, vec![port(80), host("example")]), Some(&"Wildcard"));
  }

  #[test]
  fn test_partial_key_matching() {
    let mut tree = ConfigFilterTree::new();
    tree.insert(vec![port(80), host("com"), host("example")], "Partial");
    assert_eq!(lookup(&tree, vec![port(80), host("com")]), None);
  }

  #[test]
  fn test_overlapping_keys() {
    let mut tree = ConfigFilterTree::new();
    tree.insert(vec![port(80), host("com")], "First");
    tree.insert(vec![port(80), host("com"), host("example")], "Second");
    assert_eq!(lookup(&tree, vec![port(80), host("com")]), Some(&"First"));
    assert_eq!(
      lookup(&tree, vec![port(80), host("com"), host("example")]),
      Some(&"Second")
    );
  }

  #[test]
  fn test_mixed_predicate_and_fixed_keys() {
    let mut tree = ConfigFilterTree::new();
    tree.insert(vec![port(80), host("com"), any_host()], "Mixed");
    assert_eq!(
      lookup(&tree, vec![port(80), host("com"), host("example")]),
      Some(&"Mixed")
    );
  }

  #[test]
  fn test_keys_with_redundant_in_between() {
    let mut tree = ConfigFilterTree::new();
    tree.insert(vec![port(80), host("com"), host("example")], "Value");
    // The address octets have no counterpart in the stored key and must be skipped.
    assert_eq!(lookup(&tree, loopback_example_key(&[])), Some(&"Value"));
  }

  #[test]
  fn test_wildcard_domain_location() {
    let mut tree = ConfigFilterTree::new();
    tree.insert(
      vec![port(80), host("com"), any_host(), segment(""), segment("test")],
      "Value1",
    );
    tree.insert(vec![port(80), host("com"), any_host(), segment("")], "Value2");

    assert_eq!(lookup(&tree, loopback_example_key(&["", "test"])), Some(&"Value1"));
    assert_eq!(lookup(&tree, loopback_example_key(&[""])), Some(&"Value2"));
    assert_eq!(lookup(&tree, loopback_example_key(&["", "a", "test"])), Some(&"Value2"));
    assert_eq!(lookup(&tree, loopback_example_key(&["", "a"])), Some(&"Value2"));
  }
}