ferron_common/http_proxy/
load_balancer.rs1use std::sync::atomic::Ordering;
2use std::sync::Arc;
3
4use tokio::sync::RwLock;
5
6use super::{LoadBalancerAlgorithmInner, ProxyToKey, ProxyToKeyInner, UpstreamInner};
7use crate::util::TtlCache;
8
9async fn select_backend_index(
11 load_balancer_algorithm: &LoadBalancerAlgorithmInner,
12 backends: &[ProxyToKeyInner],
13) -> usize {
14 match load_balancer_algorithm {
15 LoadBalancerAlgorithmInner::TwoRandomChoices(connection_track) => {
16 let random_choice1 = rand::random_range(..backends.len());
19 let mut random_choice2 = if backends.len() > 1 {
20 rand::random_range(..(backends.len() - 1))
21 } else {
22 0
23 };
24 if backends.len() > 1 && random_choice2 >= random_choice1 {
25 random_choice2 += 1;
26 }
27 let backend1 = &backends[random_choice1];
28 let backend2 = &backends[random_choice2];
29 let connection_track_read = connection_track.read().await;
30 let connection_count_option1 = connection_track_read
31 .get(&backend1.0)
32 .map(|connection_count| Arc::strong_count(connection_count) - 1);
33 let connection_count_option2 = connection_track_read
34 .get(&backend2.0)
35 .map(|connection_count| Arc::strong_count(connection_count) - 1);
36 drop(connection_track_read);
37 let connection_count1 = if let Some(count) = connection_count_option1 {
38 count
39 } else {
40 connection_track.write().await.insert(backend1.0.clone(), Arc::new(()));
41 0
42 };
43 let connection_count2 = if let Some(count) = connection_count_option2 {
44 count
45 } else {
46 connection_track.write().await.insert(backend2.0.clone(), Arc::new(()));
47 0
48 };
49 if connection_count2 >= connection_count1 {
50 random_choice1
51 } else {
52 random_choice2
53 }
54 }
55 LoadBalancerAlgorithmInner::LeastConnections(connection_track) => {
56 let mut min_indexes = Vec::new();
57 let mut min_connections = None;
58 for (index, (upstream, _, _)) in backends.iter().enumerate() {
59 let connection_track_read = connection_track.read().await;
60 let connection_count = if let Some(connection_count) = connection_track_read.get(upstream) {
61 Arc::strong_count(connection_count) - 1
62 } else {
63 drop(connection_track_read);
64 connection_track.write().await.insert((*upstream).clone(), Arc::new(()));
65 0
66 };
67 if min_connections.is_none_or(|min| connection_count < min) {
68 min_indexes = vec![index];
70 min_connections = Some(connection_count);
71 } else if min_connections == Some(connection_count) {
72 min_indexes.push(index);
74 }
75 }
76 match min_indexes.len() {
77 0 => 0, 1 => min_indexes[0],
79 _ => min_indexes[rand::random_range(0..min_indexes.len())],
80 }
81 }
82 LoadBalancerAlgorithmInner::RoundRobin(round_robin_index) => {
83 round_robin_index.fetch_add(1, Ordering::Relaxed) % backends.len()
85 }
86 LoadBalancerAlgorithmInner::Random => rand::random_range(..backends.len()),
87 }
88}
89
90#[inline]
92pub(super) async fn determine_proxy_to(
93 proxy_to_vector: &mut Vec<ProxyToKeyInner>,
94 failed_backends: &RwLock<TtlCache<UpstreamInner, u64>>,
95 enable_health_check: bool,
96 health_check_max_fails: u64,
97 load_balancer_algorithm: &LoadBalancerAlgorithmInner,
98) -> Option<ProxyToKeyInner> {
99 let mut proxy_to = None;
100
101 if proxy_to_vector.is_empty() {
102 return None;
103 } else if proxy_to_vector.len() == 1 {
104 let proxy_to_borrowed = proxy_to_vector.remove(0);
105 let upstream = proxy_to_borrowed.0;
106 let local_limit_index = proxy_to_borrowed.1;
107 let keepalive_idle_timeout = proxy_to_borrowed.2;
108 proxy_to = Some((upstream, local_limit_index, keepalive_idle_timeout));
109 } else if enable_health_check {
110 loop {
111 if !proxy_to_vector.is_empty() {
112 let index = select_backend_index(load_balancer_algorithm, proxy_to_vector).await;
113 let proxy_to_borrowed = proxy_to_vector.remove(index);
114 let upstream = proxy_to_borrowed.0;
115 let local_limit_index = proxy_to_borrowed.1;
116 let keepalive_idle_timeout = proxy_to_borrowed.2;
117 let failed_backends_read = failed_backends.read().await;
118 let failed_backend_fails_option = failed_backends_read.get(&upstream);
119 proxy_to = Some((upstream, local_limit_index, keepalive_idle_timeout));
120 let failed_backend_fails = if let Some(fails) = failed_backend_fails_option {
121 fails
122 } else {
123 break;
124 };
125 if failed_backend_fails <= health_check_max_fails {
126 break;
127 }
128 } else {
129 break;
130 }
131 }
132 } else if !proxy_to_vector.is_empty() {
133 let index = select_backend_index(load_balancer_algorithm, proxy_to_vector).await;
134 let proxy_to_borrowed = proxy_to_vector.remove(index);
135 let upstream = proxy_to_borrowed.0;
136 let local_limit_index = proxy_to_borrowed.1;
137 let keepalive_idle_timeout = proxy_to_borrowed.2;
138 proxy_to = Some((upstream, local_limit_index, keepalive_idle_timeout));
139 }
140
141 proxy_to
142}
143
144pub(super) async fn resolve_upstreams(
146 proxy_to: &[ProxyToKey],
147 failed_backends: Arc<RwLock<TtlCache<UpstreamInner, u64>>>,
148 health_check_max_fails: u64,
149) -> Vec<ProxyToKeyInner> {
150 let mut upstreams = Vec::new();
151 for proxy_to in proxy_to {
152 let upstream = proxy_to
153 .0
154 .resolve(failed_backends.clone(), health_check_max_fails)
155 .await;
156 for upstream in upstream {
157 upstreams.push((upstream, proxy_to.1, proxy_to.2));
158 }
159 }
160 upstreams
161}
162
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::future::Future;
    use std::sync::atomic::AtomicUsize;
    use std::time::Duration;

    use super::*;

    /// Runs `future` to completion on a fresh current-thread Tokio runtime.
    fn run_async<T>(future: impl Future<Output = T>) -> T {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("runtime should be created")
            .block_on(future)
    }

    /// Builds an upstream for `proxy_to` with no `proxy_unix` path.
    fn upstream(proxy_to: &str) -> UpstreamInner {
        UpstreamInner {
            proxy_to: proxy_to.to_string(),
            proxy_unix: None,
        }
    }

    /// Builds a round-robin algorithm whose counter starts at 0, so selections are deterministic.
    fn round_robin() -> LoadBalancerAlgorithmInner {
        LoadBalancerAlgorithmInner::RoundRobin(Arc::new(AtomicUsize::new(0)))
    }

    /// Builds a failed-backends cache pre-populated with `(upstream, fail count)` entries.
    async fn failed_backends_with(
        entries: Vec<(UpstreamInner, u64)>,
    ) -> RwLock<TtlCache<UpstreamInner, u64>> {
        let failed_backends = RwLock::new(TtlCache::new(Duration::from_secs(60)));
        {
            let mut failed_backends_write = failed_backends.write().await;
            for (failed_upstream, fails) in entries {
                failed_backends_write.insert(failed_upstream, fails);
            }
        }
        failed_backends
    }

    #[test]
    fn round_robin_cycles_through_backends() {
        run_async(async {
            let backends = vec![
                (upstream("http://backend-1"), None, None),
                (upstream("http://backend-2"), None, None),
                (upstream("http://backend-3"), None, None),
            ];
            let algorithm = round_robin();

            assert_eq!(select_backend_index(&algorithm, &backends).await, 0);
            assert_eq!(select_backend_index(&algorithm, &backends).await, 1);
            assert_eq!(select_backend_index(&algorithm, &backends).await, 2);
            assert_eq!(select_backend_index(&algorithm, &backends).await, 0);
        });
    }

    #[test]
    fn least_connections_picks_backend_with_lowest_connection_count() {
        run_async(async {
            let heavily_loaded = upstream("http://backend-1");
            let least_loaded = upstream("http://backend-2");
            let moderately_loaded = upstream("http://backend-3");

            // Each extra clone of a tracker stands in for one active connection.
            let heavily_loaded_tracker = Arc::new(());
            let _heavy_1 = heavily_loaded_tracker.clone();
            let _heavy_2 = heavily_loaded_tracker.clone();
            let moderately_loaded_tracker = Arc::new(());
            let _moderate_1 = moderately_loaded_tracker.clone();

            let connection_track = Arc::new(RwLock::new(HashMap::new()));
            {
                let mut connection_track_write = connection_track.write().await;
                connection_track_write.insert(heavily_loaded.clone(), heavily_loaded_tracker);
                connection_track_write.insert(least_loaded.clone(), Arc::new(()));
                connection_track_write.insert(moderately_loaded.clone(), moderately_loaded_tracker);
            }

            let backends = vec![
                (heavily_loaded, None, None),
                (least_loaded, None, None),
                (moderately_loaded, None, None),
            ];
            let algorithm = LoadBalancerAlgorithmInner::LeastConnections(connection_track);

            for _ in 0..32 {
                let selected_index = select_backend_index(&algorithm, &backends).await;
                assert_eq!(selected_index, 1);
            }
        });
    }

    #[test]
    fn least_connections_registers_untracked_backends() {
        run_async(async {
            let backends = vec![
                (upstream("http://backend-1"), None, None),
                (upstream("http://backend-2"), None, None),
            ];
            let connection_track = Arc::new(RwLock::new(HashMap::new()));
            let algorithm = LoadBalancerAlgorithmInner::LeastConnections(Arc::clone(&connection_track));

            let selected_index = select_backend_index(&algorithm, &backends).await;
            assert!(selected_index < backends.len());

            // Every backend gets exactly one tracker, referenced only by the map (zero connections).
            let connection_track_read = connection_track.read().await;
            assert_eq!(connection_track_read.len(), 2);
            for (backend, _, _) in &backends {
                let tracker = connection_track_read
                    .get(backend)
                    .expect("backend should be tracked");
                assert_eq!(Arc::strong_count(tracker), 1);
            }
        });
    }

    #[test]
    fn two_random_choices_prefers_less_loaded_backend() {
        run_async(async {
            let heavily_loaded = upstream("http://backend-1");
            let lightly_loaded = upstream("http://backend-2");

            let heavily_loaded_tracker = Arc::new(());
            let _heavy_1 = heavily_loaded_tracker.clone();
            let _heavy_2 = heavily_loaded_tracker.clone();

            let connection_track = Arc::new(RwLock::new(HashMap::new()));
            connection_track
                .write()
                .await
                .insert(heavily_loaded.clone(), heavily_loaded_tracker);

            let backends = vec![(heavily_loaded, None, None), (lightly_loaded, None, None)];
            let algorithm = LoadBalancerAlgorithmInner::TwoRandomChoices(connection_track);

            // With exactly two backends both are always sampled, so the less loaded one must win.
            for _ in 0..32 {
                assert_eq!(select_backend_index(&algorithm, &backends).await, 1);
            }
        });
    }

    #[test]
    fn determine_proxy_to_returns_none_for_empty_vector() {
        run_async(async {
            let mut proxy_to_vector: Vec<ProxyToKeyInner> = Vec::new();
            let failed_backends = failed_backends_with(Vec::new()).await;

            let selected = determine_proxy_to(&mut proxy_to_vector, &failed_backends, true, 3, &round_robin()).await;
            assert!(selected.is_none());
        });
    }

    #[test]
    fn determine_proxy_to_returns_single_backend_even_if_unhealthy() {
        run_async(async {
            let only = upstream("http://backend-only");
            let mut proxy_to_vector = vec![(only.clone(), None, None)];
            let failed_backends = failed_backends_with(vec![(only.clone(), 4)]).await;

            let selected = determine_proxy_to(&mut proxy_to_vector, &failed_backends, true, 3, &round_robin()).await;
            let (selected_upstream, _, _) = selected.expect("a backend should be selected");
            assert!(selected_upstream == only);
            assert!(proxy_to_vector.is_empty());
        });
    }

    #[test]
    fn determine_proxy_to_skips_unhealthy_backend_when_alternatives_exist() {
        run_async(async {
            let unhealthy = upstream("http://backend-unhealthy");
            let healthy = upstream("http://backend-healthy");
            let mut proxy_to_vector = vec![(unhealthy.clone(), None, None), (healthy.clone(), None, None)];
            let failed_backends = failed_backends_with(vec![(unhealthy, 4)]).await;

            let selected = determine_proxy_to(&mut proxy_to_vector, &failed_backends, true, 3, &round_robin()).await;
            let (selected_upstream, _, _) = selected.expect("a backend should be selected");
            assert!(selected_upstream == healthy);
        });
    }

    #[test]
    fn determine_proxy_to_falls_back_to_last_backend_when_all_are_unhealthy() {
        run_async(async {
            let first = upstream("http://backend-1");
            let second = upstream("http://backend-2");
            let mut proxy_to_vector = vec![(first.clone(), None, None), (second.clone(), None, None)];
            let failed_backends = failed_backends_with(vec![(first, 4), (second.clone(), 4)]).await;

            let selected = determine_proxy_to(&mut proxy_to_vector, &failed_backends, true, 3, &round_robin()).await;
            // Round robin draws `first`, then `second`; the last one drawn is the fallback.
            let (selected_upstream, _, _) = selected.expect("a backend should be selected");
            assert!(selected_upstream == second);
            assert!(proxy_to_vector.is_empty());
        });
    }

    #[test]
    fn determine_proxy_to_ignores_health_when_disabled() {
        run_async(async {
            let unhealthy = upstream("http://backend-unhealthy");
            let healthy = upstream("http://backend-healthy");
            let mut proxy_to_vector = vec![(unhealthy.clone(), None, None), (healthy, None, None)];
            let failed_backends = failed_backends_with(vec![(unhealthy.clone(), 4)]).await;

            let selected = determine_proxy_to(&mut proxy_to_vector, &failed_backends, false, 3, &round_robin()).await;
            // Round robin picks index 0 and no health data is consulted.
            let (selected_upstream, _, _) = selected.expect("a backend should be selected");
            assert!(selected_upstream == unhealthy);
            assert_eq!(proxy_to_vector.len(), 1);
        });
    }
}