// tower/util/rng.rs

//! [PRNG] utilities for tower middleware.
//!
//! This module provides a generic [`Rng`] trait and a [`HasherRng`] that
//! implements the trait based on [`RandomState`] or any other [`BuildHasher`].
//!
//! These utilities replace tower's internal usage of `rand` with smaller,
//! more lightweight methods. Most of the implementations are extracted from
//! their corresponding `rand` implementations.
//!
//! [PRNG]: https://en.wikipedia.org/wiki/Pseudorandom_number_generator

use std::{
    collections::hash_map::RandomState,
    hash::{BuildHasher, Hasher},
    ops::Range,
};

/// A simple [PRNG] trait for use within tower middleware.
///
/// Implementors only have to provide [`Rng::next_u64`]; the remaining methods
/// have default implementations derived from it.
///
/// [PRNG]: https://en.wikipedia.org/wiki/Pseudorandom_number_generator
pub trait Rng {
    /// Generate a random [`u64`].
    fn next_u64(&mut self) -> u64;

    /// Generate a random [`f64`] in the half-open interval `[0, 1)`.
    fn next_f64(&mut self) -> f64 {
        // Borrowed from:
        // https://github.com/rust-random/rand/blob/master/src/distr/float.rs#L108
        //
        // Only the top 53 bits (the precision of an `f64` significand) are
        // kept, so the integer converts to `f64` exactly; scaling by 2^-53
        // then maps it into `[0, 1)`.
        let float_size = std::mem::size_of::<f64>() as u32 * 8;
        let precision = f64::MANTISSA_DIGITS;
        let scale = 1.0 / ((1u64 << precision) as f64);

        let value = self.next_u64() >> (float_size - precision);

        scale * value as f64
    }

    /// Randomly pick a value within the half-open `range`.
    ///
    /// The value is [`Rng::next_u64`] reduced modulo the length of the range
    /// and offset by `range.start`, so it is only perfectly uniform when that
    /// length divides 2^64.
    ///
    /// # Panics
    ///
    /// - If `range.start >= range.end` this will panic in debug mode.
    /// - If `range.start == range.end` and debug assertions are disabled, the
    ///   remainder by zero still panics.
    fn next_range(&mut self, range: Range<u64>) -> u64 {
        debug_assert!(
            range.start < range.end,
            "The range start must be smaller than the end"
        );
        let start = range.start;
        let end = range.end;

        // Number of candidate values in the range.
        let range = end - start;

        let n = self.next_u64();

        (n % range) + start
    }
}

60impl<R: Rng + ?Sized> Rng for Box<R> {
61    fn next_u64(&mut self) -> u64 {
62        (**self).next_u64()
63    }
64}
65
/// A [`Rng`] implementation that uses a [`BuildHasher`] to generate the random
/// values. The implementation uses an internal counter to pass to the hasher
/// for each iteration of [`Rng::next_u64`].
///
/// # Default
///
/// This hasher has a default type of [`RandomState`] which just uses the
/// libstd method of getting a random u64.
#[derive(Clone, Debug)]
pub struct HasherRng<H = RandomState> {
    /// Factory for the hasher that digests the counter on every draw.
    hasher: H,
    /// Value fed to the hasher on the next draw.
    counter: u64,
}

80impl HasherRng {
81    /// Create a new default [`HasherRng`].
82    pub fn new() -> Self {
83        HasherRng::default()
84    }
85}
86
87impl Default for HasherRng {
88    fn default() -> Self {
89        HasherRng::with_hasher(RandomState::default())
90    }
91}
92
93impl<H> HasherRng<H> {
94    /// Create a new [`HasherRng`] with the provided hasher.
95    pub fn with_hasher(hasher: H) -> Self {
96        HasherRng { hasher, counter: 0 }
97    }
98}
99
100impl<H> Rng for HasherRng<H>
101where
102    H: BuildHasher,
103{
104    fn next_u64(&mut self) -> u64 {
105        let mut hasher = self.hasher.build_hasher();
106        hasher.write_u64(self.counter);
107        self.counter = self.counter.wrapping_add(1);
108        hasher.finish()
109    }
110}
111
/// A sampler modified from the Rand implementation for use internally for the balance middleware.
///
/// It's an implementation of Floyd's combination algorithm with amount fixed at 2. This uses no allocated
/// memory and finishes in constant time (only 2 random calls).
///
/// Returns two distinct indices, each in `0..length`. `length` must be at
/// least 2; this is only checked when debug assertions are enabled.
///
/// The function is also compiled for test builds so that this module's unit
/// tests build without the `balance` feature.
///
/// ref: This was borrowed and modified from the following Rand implementation
/// <https://github.com/rust-random/rand/blob/b73640705d6714509f8ceccc49e8df996fa19f51/src/seq/index.rs#L375-L411>
#[cfg(any(test, feature = "balance"))]
pub(crate) fn sample_floyd2<R: Rng>(rng: &mut R, length: u64) -> [u64; 2] {
    debug_assert!(2 <= length);
    let aidx = rng.next_range(0..length - 1);
    let bidx = rng.next_range(0..length);
    // `aidx` was drawn from `0..length - 1`, so on a collision the one index
    // it could never produce, `length - 1`, is guaranteed to differ from `bidx`.
    let aidx = if aidx == bidx { length - 1 } else { aidx };
    [aidx, bidx]
}

#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::*;

    /// Build a [`HasherRng`] with a randomly keyed hasher whose counter starts
    /// at `counter`.
    fn rng_at(counter: u64) -> HasherRng {
        HasherRng {
            counter,
            ..HasherRng::default()
        }
    }

    quickcheck! {
        // `next_f64` must always land in `[0, 1)`.
        fn next_f64(counter: u64) -> TestResult {
            let n = rng_at(counter).next_f64();

            TestResult::from_bool((0.0..1.0).contains(&n))
        }

        // `next_range` must stay inside every non-empty range.
        fn next_range(counter: u64, range: Range<u64>) -> TestResult {
            // Empty or inverted ranges violate the documented precondition.
            if range.start >= range.end {
                return TestResult::discard();
            }

            let n = rng_at(counter).next_range(range.clone());

            TestResult::from_bool(range.contains(&n))
        }

        // `sample_floyd2` must return two distinct in-bounds indices.
        fn sample_floyd2(counter: u64, length: u64) -> TestResult {
            if !(2..=256).contains(&length) {
                return TestResult::discard();
            }

            let mut rng = rng_at(counter);

            let [a, b] = super::sample_floyd2(&mut rng, length);

            if a >= length || b >= length || a == b {
                return TestResult::failed();
            }

            TestResult::passed()
        }
    }

    #[test]
    fn sample_inplace_boundaries() {
        let mut r = HasherRng::default();
        match super::sample_floyd2(&mut r, 2) {
            [0, 1] | [1, 0] => (),
            array => panic!("unexpected inplace boundaries: {:?}", array),
        }
    }
}