1use core::num::NonZeroU32;
33use core::time::Duration;
34
35use spin::Mutex;
36
37use super::LimitError;
38
39#[cfg(test)]
40use std::sync::{Mutex as StdMutex, MutexGuard as StdMutexGuard};
41
/// Settings that bound how long an execution may run.
///
/// [`ExecutionTimer`] compares elapsed time against `limit` and uses
/// `check_interval` to decide how often that comparison happens.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct ExecutionTimerConfig {
    /// Maximum elapsed time tolerated. A check fails only when the elapsed
    /// time is strictly greater than this value.
    pub limit: Duration,
    /// Number of accumulated work units after which [`ExecutionTimer::tick`]
    /// compares the elapsed time against `limit`.
    pub check_interval: NonZeroU32,
}
57
58#[derive(Debug)]
60pub struct ExecutionTimer {
61 config: Option<ExecutionTimerConfig>,
62 start: Option<Duration>,
63 accumulated_units: u32,
64 last_elapsed: Duration,
65}
66
/// A provider of timestamps for [`monotonic_now`].
///
/// Implementations must be shareable across threads because a single
/// `&'static dyn TimeSource` is stored in a global slot.
pub trait TimeSource: Sync + Send {
    /// Returns the current reading as a [`Duration`] measured from an origin
    /// chosen by the implementation, or `None` when the source has no
    /// reading to offer.
    fn now(&self) -> Option<Duration>;
}
72
/// Time source backed by the standard library clock.
///
/// Only compiled when the `std` feature is enabled.
#[cfg(feature = "std")]
#[derive(Debug)]
struct StdTimeSource;

#[cfg(feature = "std")]
impl StdTimeSource {
    /// Builds the (stateless) source; `const` so it can initialise a `static`.
    const fn new() -> Self {
        StdTimeSource
    }
}
83
#[cfg(feature = "std")]
impl TimeSource for StdTimeSource {
    /// Reports the time elapsed since the first call to this method in the
    /// current process. Always returns `Some`.
    fn now(&self) -> Option<Duration> {
        use std::sync::OnceLock;
        use std::time::Instant;

        // Process-wide reference point, captured lazily on the first call so
        // every later reading is measured from the same instant.
        static ANCHOR: OnceLock<Instant> = OnceLock::new();

        Some(ANCHOR.get_or_init(Instant::now).elapsed())
    }
}
94
/// Shared instance of the std-backed time source that [`monotonic_now`]
/// falls back to when the `std` feature is enabled.
#[cfg(feature = "std")]
static STD_TIME_SOURCE: StdTimeSource = StdTimeSource::new();
97
98#[cfg(any(test, not(feature = "std")))]
99static TIME_SOURCE_OVERRIDE: Mutex<Option<&'static dyn TimeSource>> = Mutex::new(None);
100
/// Reasons a time source registration can be refused.
#[cfg(any(test, not(feature = "std")))]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TimeSourceRegistrationError {
    /// A time source has already been registered.
    AlreadySet,
}

#[cfg(any(test, not(feature = "std")))]
impl core::fmt::Display for TimeSourceRegistrationError {
    /// Writes a short, human-readable description of the failure.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match *self {
            Self::AlreadySet => "time source already configured",
        };
        f.write_str(message)
    }
}

// No underlying cause to expose, so the default `Error` methods suffice.
#[cfg(any(test, not(feature = "std")))]
impl core::error::Error for TimeSourceRegistrationError {}
118
119static FALLBACK_EXECUTION_TIMER_CONFIG: Mutex<Option<ExecutionTimerConfig>> = Mutex::new(None);
120
/// Lock handed out by [`acquire_limits_test_lock`] so that tests touching
/// shared global state can run one at a time.
#[cfg(test)]
static LIMITS_TEST_LOCK: StdMutex<()> = StdMutex::new(());

/// Takes the shared test lock and returns its guard.
///
/// A poisoned lock (a previous holder panicked) is recovered rather than
/// propagated: the protected value is `()`, so there is no state that a
/// panicking test could have left half-updated.
#[cfg(test)]
pub fn acquire_limits_test_lock() -> StdMutexGuard<'static, ()> {
    match LIMITS_TEST_LOCK.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
130
131pub fn monotonic_now() -> Option<Duration> {
133 #[cfg(any(test, not(feature = "std")))]
134 if let Some(source) = {
137 let guard = TIME_SOURCE_OVERRIDE.lock();
138 *guard
139 } {
140 if let Some(duration) = source.now() {
141 return Some(duration);
142 }
143 }
144
145 #[cfg(feature = "std")]
146 {
147 STD_TIME_SOURCE.now()
148 }
149
150 #[cfg(not(feature = "std"))]
151 {
152 None
153 }
154}
155
156#[cfg(any(test, not(feature = "std")))]
157pub fn set_time_source(source: &'static dyn TimeSource) -> Result<(), TimeSourceRegistrationError> {
158 let mut slot = TIME_SOURCE_OVERRIDE.lock();
159 if slot.is_some() {
160 Err(TimeSourceRegistrationError::AlreadySet)
161 } else {
162 *slot = Some(source);
163 Ok(())
164 }
165}
166
167pub fn set_fallback_execution_timer_config(config: Option<ExecutionTimerConfig>) {
189 *FALLBACK_EXECUTION_TIMER_CONFIG.lock() = config;
190}
191
192pub fn fallback_execution_timer_config() -> Option<ExecutionTimerConfig> {
203 let guard = FALLBACK_EXECUTION_TIMER_CONFIG.lock();
204 guard.as_ref().copied()
205}
206
207impl ExecutionTimer {
208 pub const fn new(config: Option<ExecutionTimerConfig>) -> Self {
210 Self {
211 config,
212 start: None,
213 accumulated_units: 0,
214 last_elapsed: Duration::ZERO,
215 }
216 }
217
218 pub const fn reset(&mut self) {
220 self.start = None;
221 self.accumulated_units = 0;
222 self.last_elapsed = Duration::ZERO;
223 }
224
225 pub const fn start(&mut self, now: Duration) {
227 self.start = Some(now);
228 self.accumulated_units = 0;
229 self.last_elapsed = Duration::ZERO;
230 }
231
232 pub const fn config(&self) -> Option<ExecutionTimerConfig> {
234 self.config
235 }
236
237 pub const fn limit(&self) -> Option<Duration> {
239 match self.config {
240 Some(config) => Some(config.limit),
241 None => None,
242 }
243 }
244
245 pub const fn last_elapsed(&self) -> Duration {
247 self.last_elapsed
248 }
249
250 pub fn tick(&mut self, work_units: u32, now: Duration) -> Result<(), LimitError> {
252 let Some(config) = self.config else {
253 return Ok(());
254 };
255 self.accumulated_units = self.accumulated_units.saturating_add(work_units);
256 if self.accumulated_units < config.check_interval.get() {
257 return Ok(());
258 }
259
260 let interval = config.check_interval.get();
262 self.accumulated_units %= interval;
263 self.check_now(now)
264 }
265
266 pub fn check_now(&mut self, now: Duration) -> Result<(), LimitError> {
268 let Some(config) = self.config else {
269 return Ok(());
270 };
271 let Some(start) = self.start else {
272 return Ok(());
273 };
274
275 let elapsed = now.checked_sub(start).unwrap_or(Duration::ZERO);
276 self.last_elapsed = elapsed;
277 if elapsed > config.limit {
278 return Err(LimitError::TimeLimitExceeded {
279 elapsed,
280 limit: config.limit,
281 });
282 }
283 Ok(())
284 }
285
286 pub fn elapsed(&self, now: Duration) -> Option<Duration> {
288 let start = self.start?;
289 Some(now.checked_sub(start).unwrap_or(Duration::ZERO))
290 }
291
292 pub const fn resume_from_elapsed(&mut self, now: Duration, elapsed: Duration) {
295 if self.config.is_none() {
296 return;
297 }
298
299 self.start = Some(now.saturating_sub(elapsed));
300 self.last_elapsed = elapsed;
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use core::num::NonZeroU32;
    use core::sync::atomic::{AtomicU64, Ordering};
    use core::time::Duration;

    /// Builds a `NonZeroU32`, substituting the minimum value when given zero.
    fn nz(value: u32) -> NonZeroU32 {
        NonZeroU32::new(value).unwrap_or(NonZeroU32::MIN)
    }

    /// The limit is only evaluated on the tick that reaches the interval.
    #[test]
    fn tick_defers_checks_until_interval_is_reached() {
        let mut timer = ExecutionTimer::new(Some(ExecutionTimerConfig {
            limit: Duration::from_millis(100),
            check_interval: nz(4),
        }));

        timer.start(Duration::from_millis(0));

        // Typed range so the timestamp needs no `as` cast.
        for step in 1..4u64 {
            let now = Duration::from_millis(step * 10);
            let result = timer.tick(1, now);
            assert_eq!(result, Ok(()), "tick before reaching interval must succeed");
            assert_eq!(timer.last_elapsed(), Duration::ZERO);
        }

        let result = timer.tick(1, Duration::from_millis(40));
        assert_eq!(result, Ok(()), "tick at interval boundary must succeed");
        assert_eq!(timer.last_elapsed(), Duration::from_millis(40));
    }

    /// An explicit check past the limit reports the breach.
    #[test]
    fn check_now_reports_limit_exceeded() {
        let mut timer = ExecutionTimer::new(Some(ExecutionTimerConfig {
            limit: Duration::from_millis(25),
            check_interval: nz(1),
        }));

        timer.start(Duration::from_millis(0));
        assert_eq!(
            timer.tick(1, Duration::from_millis(10)),
            Ok(()),
            "tick before limit breach must succeed"
        );

        let result = timer.check_now(Duration::from_millis(30));
        assert!(matches!(&result, Err(LimitError::TimeLimitExceeded { .. })));

        if let Err(LimitError::TimeLimitExceeded { elapsed, limit }) = result {
            assert!(elapsed > limit);
            assert_eq!(limit, Duration::from_millis(25));
        }
    }

    /// A tick that triggers a check past the limit reports the breach and
    /// records the elapsed time.
    #[test]
    fn tick_reports_limit_exceeded() {
        let mut timer = ExecutionTimer::new(Some(ExecutionTimerConfig {
            limit: Duration::from_millis(30),
            check_interval: nz(2),
        }));

        timer.start(Duration::from_millis(0));
        assert_eq!(
            timer.tick(1, Duration::from_millis(10)),
            Ok(()),
            "initial tick must succeed"
        );

        let result = timer.tick(1, Duration::from_millis(35));
        assert!(matches!(&result, Err(LimitError::TimeLimitExceeded { .. })));

        if let Err(LimitError::TimeLimitExceeded { elapsed, limit }) = result {
            assert!(elapsed > limit);
            assert_eq!(limit, Duration::from_millis(30));
            assert_eq!(timer.last_elapsed(), elapsed);
        }
    }

    /// Ticking a configured timer that was never started does nothing.
    #[test]
    fn tick_before_start_is_noop() {
        let mut timer = ExecutionTimer::new(Some(ExecutionTimerConfig {
            limit: Duration::from_secs(1),
            check_interval: nz(1),
        }));

        let result = timer.tick(1, Duration::from_millis(100));
        assert_eq!(result, Ok(()), "tick before start should be ignored");
        assert_eq!(timer.last_elapsed(), Duration::ZERO);
        assert!(timer.elapsed(Duration::from_millis(200)).is_none());
    }

    /// The limit is exclusive: elapsed == limit is still allowed.
    #[test]
    fn check_now_allows_elapsed_equal_to_limit() {
        let mut timer = ExecutionTimer::new(Some(ExecutionTimerConfig {
            limit: Duration::from_millis(50),
            check_interval: nz(1),
        }));

        timer.start(Duration::from_millis(0));
        assert_eq!(
            timer.tick(1, Duration::from_millis(30)),
            Ok(()),
            "tick prior to equality check must succeed"
        );
        let result = timer.check_now(Duration::from_millis(50));
        assert_eq!(result, Ok(()), "elapsed equal to limit must not fail");
        assert_eq!(timer.last_elapsed(), Duration::from_millis(50));
    }

    /// Without a configuration, ticks never check and never record elapsed time.
    #[test]
    fn tick_is_noop_when_limit_disabled() {
        let mut timer = ExecutionTimer::new(None);

        timer.start(Duration::from_millis(0));

        for step in 0..8u64 {
            let now = Duration::from_millis(step + 1);
            assert_eq!(
                timer.tick(1, now),
                Ok(()),
                "ticks with disabled limit must succeed"
            );
        }

        assert_eq!(timer.last_elapsed(), Duration::ZERO);
    }

    /// An explicit check on a configured but never-started timer is ignored.
    ///
    /// The timer must carry a configuration here: with `None` the check would
    /// return at the "no configuration" guard and never reach the
    /// "not started" path this test is about.
    #[test]
    fn check_now_is_noop_before_start() {
        let mut timer = ExecutionTimer::new(Some(ExecutionTimerConfig {
            limit: Duration::from_millis(1),
            check_interval: nz(1),
        }));
        let result = timer.check_now(Duration::from_secs(1));
        assert_eq!(result, Ok(()), "check before start must be ignored");
        assert_eq!(timer.last_elapsed(), Duration::ZERO);
        assert!(timer.elapsed(Duration::from_secs(2)).is_none());
    }

    /// `elapsed` works from the start timestamp even without a configuration.
    #[test]
    fn elapsed_reports_offset_from_start() {
        let mut timer = ExecutionTimer::new(None);
        timer.start(Duration::from_millis(5));
        let elapsed = timer.elapsed(Duration::from_millis(20));
        assert_eq!(elapsed, Some(Duration::from_millis(15)));
    }

    /// A registered override takes precedence in `monotonic_now`.
    #[test]
    fn monotonic_now_uses_override_when_present() {
        static TEST_TIME: AtomicU64 = AtomicU64::new(0);

        struct TestSource;

        impl TimeSource for TestSource {
            fn now(&self) -> Option<Duration> {
                Some(Duration::from_nanos(TEST_TIME.load(Ordering::Relaxed)))
            }
        }

        static SOURCE: TestSource = TestSource;

        /// Puts the previous override back when dropped, so a failing
        /// assertion cannot leak the test source into other tests.
        struct OverrideRestore(Option<&'static dyn TimeSource>);

        impl Drop for OverrideRestore {
            fn drop(&mut self) {
                *super::TIME_SOURCE_OVERRIDE.lock() = self.0;
            }
        }

        // Declared first so it is dropped last: the override is restored
        // while the suite lock is still held.
        let _suite_guard = super::acquire_limits_test_lock();

        let previous = super::TIME_SOURCE_OVERRIDE.lock().replace(&SOURCE);
        let _restore = OverrideRestore(previous);

        TEST_TIME.store(123_000_000, Ordering::Relaxed);
        assert_eq!(monotonic_now(), Some(Duration::from_nanos(123_000_000)));
    }
}