// ferron_common/http_proxy/load_balancer.rs

use std::sync::atomic::Ordering;
use std::sync::Arc;

use tokio::sync::RwLock;

use super::{LoadBalancerAlgorithmInner, ProxyToKey, ProxyToKeyInner, UpstreamInner};
use crate::util::TtlCache;

9/// Selects an index for a backend server based on the load balancing algorithm.
10async fn select_backend_index(
11  load_balancer_algorithm: &LoadBalancerAlgorithmInner,
12  backends: &[ProxyToKeyInner],
13) -> usize {
14  match load_balancer_algorithm {
15    LoadBalancerAlgorithmInner::TwoRandomChoices(connection_track) => {
16      // *1 - random choice #1
17      // *2 - random choice #2
18      let random_choice1 = rand::random_range(..backends.len());
19      let mut random_choice2 = if backends.len() > 1 {
20        rand::random_range(..(backends.len() - 1))
21      } else {
22        0
23      };
24      if backends.len() > 1 && random_choice2 >= random_choice1 {
25        random_choice2 += 1;
26      }
27      let backend1 = &backends[random_choice1];
28      let backend2 = &backends[random_choice2];
29      let connection_track_read = connection_track.read().await;
30      let connection_count_option1 = connection_track_read
31        .get(&backend1.0)
32        .map(|connection_count| Arc::strong_count(connection_count) - 1);
33      let connection_count_option2 = connection_track_read
34        .get(&backend2.0)
35        .map(|connection_count| Arc::strong_count(connection_count) - 1);
36      drop(connection_track_read);
37      let connection_count1 = if let Some(count) = connection_count_option1 {
38        count
39      } else {
40        connection_track.write().await.insert(backend1.0.clone(), Arc::new(()));
41        0
42      };
43      let connection_count2 = if let Some(count) = connection_count_option2 {
44        count
45      } else {
46        connection_track.write().await.insert(backend2.0.clone(), Arc::new(()));
47        0
48      };
49      if connection_count2 >= connection_count1 {
50        random_choice1
51      } else {
52        random_choice2
53      }
54    }
55    LoadBalancerAlgorithmInner::LeastConnections(connection_track) => {
56      let mut min_indexes = Vec::new();
57      let mut min_connections = None;
58      for (index, (upstream, _, _)) in backends.iter().enumerate() {
59        let connection_track_read = connection_track.read().await;
60        let connection_count = if let Some(connection_count) = connection_track_read.get(upstream) {
61          Arc::strong_count(connection_count) - 1
62        } else {
63          drop(connection_track_read);
64          connection_track.write().await.insert((*upstream).clone(), Arc::new(()));
65          0
66        };
67        if min_connections.is_none_or(|min| connection_count < min) {
68          // Less connections than minimum
69          min_indexes = vec![index];
70          min_connections = Some(connection_count);
71        } else if min_connections == Some(connection_count) {
72          // Same amount of connections
73          min_indexes.push(index);
74        }
75      }
76      match min_indexes.len() {
77        0 => 0, // Possible edge case
78        1 => min_indexes[0],
79        _ => min_indexes[rand::random_range(0..min_indexes.len())],
80      }
81    }
82    LoadBalancerAlgorithmInner::RoundRobin(round_robin_index) => {
83      // Add to round robin index, then modulo the length of backends to prevent overflow
84      round_robin_index.fetch_add(1, Ordering::Relaxed) % backends.len()
85    }
86    LoadBalancerAlgorithmInner::Random => rand::random_range(..backends.len()),
87  }
88}
89
90/// Determines which backend server to proxy the request to.
91#[inline]
92pub(super) async fn determine_proxy_to(
93  proxy_to_vector: &mut Vec<ProxyToKeyInner>,
94  failed_backends: &RwLock<TtlCache<UpstreamInner, u64>>,
95  enable_health_check: bool,
96  health_check_max_fails: u64,
97  load_balancer_algorithm: &LoadBalancerAlgorithmInner,
98) -> Option<ProxyToKeyInner> {
99  let mut proxy_to = None;
100
101  if proxy_to_vector.is_empty() {
102    return None;
103  } else if proxy_to_vector.len() == 1 {
104    let proxy_to_borrowed = proxy_to_vector.remove(0);
105    let upstream = proxy_to_borrowed.0;
106    let local_limit_index = proxy_to_borrowed.1;
107    let keepalive_idle_timeout = proxy_to_borrowed.2;
108    proxy_to = Some((upstream, local_limit_index, keepalive_idle_timeout));
109  } else if enable_health_check {
110    loop {
111      if !proxy_to_vector.is_empty() {
112        let index = select_backend_index(load_balancer_algorithm, proxy_to_vector).await;
113        let proxy_to_borrowed = proxy_to_vector.remove(index);
114        let upstream = proxy_to_borrowed.0;
115        let local_limit_index = proxy_to_borrowed.1;
116        let keepalive_idle_timeout = proxy_to_borrowed.2;
117        let failed_backends_read = failed_backends.read().await;
118        let failed_backend_fails_option = failed_backends_read.get(&upstream);
119        proxy_to = Some((upstream, local_limit_index, keepalive_idle_timeout));
120        let failed_backend_fails = if let Some(fails) = failed_backend_fails_option {
121          fails
122        } else {
123          break;
124        };
125        if failed_backend_fails <= health_check_max_fails {
126          break;
127        }
128      } else {
129        break;
130      }
131    }
132  } else if !proxy_to_vector.is_empty() {
133    let index = select_backend_index(load_balancer_algorithm, proxy_to_vector).await;
134    let proxy_to_borrowed = proxy_to_vector.remove(index);
135    let upstream = proxy_to_borrowed.0;
136    let local_limit_index = proxy_to_borrowed.1;
137    let keepalive_idle_timeout = proxy_to_borrowed.2;
138    proxy_to = Some((upstream, local_limit_index, keepalive_idle_timeout));
139  }
140
141  proxy_to
142}
143
144/// Resolves inner upstreams from a list of upstreams.
145pub(super) async fn resolve_upstreams(
146  proxy_to: &[ProxyToKey],
147  failed_backends: Arc<RwLock<TtlCache<UpstreamInner, u64>>>,
148  health_check_max_fails: u64,
149) -> Vec<ProxyToKeyInner> {
150  let mut upstreams = Vec::new();
151  for proxy_to in proxy_to {
152    let upstream = proxy_to
153      .0
154      .resolve(failed_backends.clone(), health_check_max_fails)
155      .await;
156    for upstream in upstream {
157      upstreams.push((upstream, proxy_to.1, proxy_to.2));
158    }
159  }
160  upstreams
161}
162
#[cfg(test)]
mod tests {
  use std::collections::HashMap;
  use std::future::Future;
  use std::sync::atomic::AtomicUsize;
  use std::time::Duration;

  use super::*;

  /// Drives a future to completion on a fresh single-threaded Tokio runtime.
  fn run_async<T>(future: impl Future<Output = T>) -> T {
    tokio::runtime::Builder::new_current_thread()
      .enable_all()
      .build()
      .expect("runtime should be created")
      .block_on(future)
  }

  /// Builds a TCP upstream pointing at the given target.
  fn upstream(proxy_to: &str) -> UpstreamInner {
    UpstreamInner {
      proxy_to: proxy_to.to_string(),
      proxy_unix: None,
    }
  }

  #[test]
  fn round_robin_cycles_through_backends() {
    run_async(async {
      let backends = vec![
        (upstream("http://backend-1"), None, None),
        (upstream("http://backend-2"), None, None),
        (upstream("http://backend-3"), None, None),
      ];
      let algorithm = LoadBalancerAlgorithmInner::RoundRobin(Arc::new(AtomicUsize::new(0)));

      // Four consecutive selections must wrap around after the third backend.
      for expected_index in [0, 1, 2, 0] {
        assert_eq!(
          select_backend_index(&algorithm, &backends).await,
          expected_index
        );
      }
    });
  }

  #[test]
  fn least_connections_picks_backend_with_lowest_connection_count() {
    run_async(async {
      let busy = upstream("http://backend-1");
      let idle = upstream("http://backend-2");
      let moderate = upstream("http://backend-3");

      // Every clone of a tracker Arc represents one live connection.
      let busy_tracker = Arc::new(());
      let _busy_conn_1 = busy_tracker.clone();
      let _busy_conn_2 = busy_tracker.clone();
      let moderate_tracker = Arc::new(());
      let _moderate_conn_1 = moderate_tracker.clone();

      let connection_track = Arc::new(RwLock::new(HashMap::new()));
      {
        let mut track = connection_track.write().await;
        track.insert(busy.clone(), busy_tracker);
        track.insert(idle.clone(), Arc::new(()));
        track.insert(moderate.clone(), moderate_tracker);
      }

      let backends = vec![
        (busy, None, None),
        (idle, None, None),
        (moderate, None, None),
      ];
      let algorithm = LoadBalancerAlgorithmInner::LeastConnections(connection_track);

      // The idle backend (index 1) must win every round, regardless of the
      // random tie-breaking path.
      for _ in 0..32 {
        assert_eq!(select_backend_index(&algorithm, &backends).await, 1);
      }
    });
  }

  #[test]
  fn determine_proxy_to_skips_unhealthy_backend_when_alternatives_exist() {
    run_async(async {
      let failing = upstream("http://backend-unhealthy");
      let healthy = upstream("http://backend-healthy");
      let mut candidates = vec![(failing.clone(), None, None), (healthy.clone(), None, None)];

      // Record more failures for the first backend than the allowed maximum.
      let failed_backends = RwLock::new(TtlCache::new(Duration::from_secs(60)));
      {
        let mut failed_backends_write = failed_backends.write().await;
        failed_backends_write.insert(failing, 4);
      }

      let algorithm = LoadBalancerAlgorithmInner::RoundRobin(Arc::new(AtomicUsize::new(0)));
      let selected = determine_proxy_to(&mut candidates, &failed_backends, true, 3, &algorithm).await;

      assert!(selected.is_some());
      let (selected_upstream, _, _) = selected.expect("a backend should be selected");
      assert!(selected_upstream == healthy);
    });
  }
}