tokio/task/task_local.rs
1use pin_project_lite::pin_project;
2use std::cell::RefCell;
3use std::error::Error;
4use std::future::Future;
5use std::marker::PhantomPinned;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use std::{fmt, mem, thread};
9
/// Declares a new task-local key of type [`tokio::task::LocalKey`].
///
/// # Syntax
///
/// The macro wraps any number of static declarations and makes them local to the current task.
/// The visibility and attributes of each static are preserved.
///
/// # Examples
///
/// ```
/// # use tokio::task_local;
/// task_local! {
///     pub static ONE: u32;
///
///     #[allow(unused)]
///     static TWO: f32;
/// }
/// # fn main() {}
/// ```
///
/// See [`LocalKey` documentation][`tokio::task::LocalKey`] for more
/// information.
///
/// [`tokio::task::LocalKey`]: struct@crate::task::LocalKey
#[macro_export]
#[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
macro_rules! task_local {
    // empty (base case for the recursion)
    () => {};

    // One declaration followed by more tokens: emit it, then recurse on the rest.
    ($(#[$attr:meta])* $vis:vis static $name:ident: $t:ty; $($rest:tt)*) => {
        $crate::__task_local_inner!($(#[$attr])* $vis $name, $t);
        $crate::task_local!($($rest)*);
    };

    // A final declaration written without a trailing semicolon.
    ($(#[$attr:meta])* $vis:vis static $name:ident: $t:ty) => {
        $crate::__task_local_inner!($(#[$attr])* $vis $name, $t);
    }
}
48}
49
50#[doc(hidden)]
51#[macro_export]
52macro_rules! __task_local_inner {
53 ($(#[$attr:meta])* $vis:vis $name:ident, $t:ty) => {
54 $(#[$attr])*
55 $vis static $name: $crate::task::LocalKey<$t> = {
56 std::thread_local! {
57 static __KEY: std::cell::RefCell<Option<$t>> = const { std::cell::RefCell::new(None) };
58 }
59
60 $crate::task::LocalKey { inner: __KEY }
61 };
62 };
63}
64
65/// A key for task-local data.
66///
67/// This type is generated by the [`task_local!`] macro.
68///
69/// Unlike [`std::thread::LocalKey`], `tokio::task::LocalKey` will
70/// _not_ lazily initialize the value on first access. Instead, the
71/// value is first initialized when the future containing
72/// the task-local is first polled by a futures executor, like Tokio.
73///
74/// # Examples
75///
76/// ```
77/// # async fn dox() {
78/// tokio::task_local! {
79/// static NUMBER: u32;
80/// }
81///
82/// NUMBER.scope(1, async move {
83/// assert_eq!(NUMBER.get(), 1);
84/// }).await;
85///
86/// NUMBER.scope(2, async move {
87/// assert_eq!(NUMBER.get(), 2);
88///
89/// NUMBER.scope(3, async move {
90/// assert_eq!(NUMBER.get(), 3);
91/// }).await;
92/// }).await;
93/// # }
94/// ```
95///
96/// [`std::thread::LocalKey`]: struct@std::thread::LocalKey
97/// [`task_local!`]: ../macro.task_local.html
98#[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
99pub struct LocalKey<T: 'static> {
100 #[doc(hidden)]
101 pub inner: thread::LocalKey<RefCell<Option<T>>>,
102}
103
104impl<T: 'static> LocalKey<T> {
105 /// Sets a value `T` as the task-local value for the future `F`.
106 ///
107 /// On completion of `scope`, the task-local will be dropped.
108 ///
109 /// ### Panics
110 ///
111 /// If you poll the returned future inside a call to [`with`] or
112 /// [`try_with`] on the same `LocalKey`, then the call to `poll` will panic.
113 ///
114 /// ### Examples
115 ///
116 /// ```
117 /// # async fn dox() {
118 /// tokio::task_local! {
119 /// static NUMBER: u32;
120 /// }
121 ///
122 /// NUMBER.scope(1, async move {
123 /// println!("task local value: {}", NUMBER.get());
124 /// }).await;
125 /// # }
126 /// ```
127 ///
128 /// [`with`]: fn@Self::with
129 /// [`try_with`]: fn@Self::try_with
130 pub fn scope<F>(&'static self, value: T, f: F) -> TaskLocalFuture<T, F>
131 where
132 F: Future,
133 {
134 TaskLocalFuture {
135 local: self,
136 slot: Some(value),
137 future: Some(f),
138 _pinned: PhantomPinned,
139 }
140 }
141
142 /// Sets a value `T` as the task-local value for the closure `F`.
143 ///
144 /// On completion of `sync_scope`, the task-local will be dropped.
145 ///
146 /// ### Panics
147 ///
148 /// This method panics if called inside a call to [`with`] or [`try_with`]
149 /// on the same `LocalKey`.
150 ///
151 /// ### Examples
152 ///
153 /// ```
154 /// # async fn dox() {
155 /// tokio::task_local! {
156 /// static NUMBER: u32;
157 /// }
158 ///
159 /// NUMBER.sync_scope(1, || {
160 /// println!("task local value: {}", NUMBER.get());
161 /// });
162 /// # }
163 /// ```
164 ///
165 /// [`with`]: fn@Self::with
166 /// [`try_with`]: fn@Self::try_with
167 #[track_caller]
168 pub fn sync_scope<F, R>(&'static self, value: T, f: F) -> R
169 where
170 F: FnOnce() -> R,
171 {
172 let mut value = Some(value);
173 match self.scope_inner(&mut value, f) {
174 Ok(res) => res,
175 Err(err) => err.panic(),
176 }
177 }
178
179 fn scope_inner<F, R>(&'static self, slot: &mut Option<T>, f: F) -> Result<R, ScopeInnerErr>
180 where
181 F: FnOnce() -> R,
182 {
183 struct Guard<'a, T: 'static> {
184 local: &'static LocalKey<T>,
185 slot: &'a mut Option<T>,
186 }
187
188 impl<'a, T: 'static> Drop for Guard<'a, T> {
189 fn drop(&mut self) {
190 // This should not panic.
191 //
192 // We know that the RefCell was not borrowed before the call to
193 // `scope_inner`, so the only way for this to panic is if the
194 // closure has created but not destroyed a RefCell guard.
195 // However, we never give user-code access to the guards, so
196 // there's no way for user-code to forget to destroy a guard.
197 //
198 // The call to `with` also should not panic, since the
199 // thread-local wasn't destroyed when we first called
200 // `scope_inner`, and it shouldn't have gotten destroyed since
201 // then.
202 self.local.inner.with(|inner| {
203 let mut ref_mut = inner.borrow_mut();
204 mem::swap(self.slot, &mut *ref_mut);
205 });
206 }
207 }
208
209 self.inner.try_with(|inner| {
210 inner
211 .try_borrow_mut()
212 .map(|mut ref_mut| mem::swap(slot, &mut *ref_mut))
213 })??;
214
215 let guard = Guard { local: self, slot };
216
217 let res = f();
218
219 drop(guard);
220
221 Ok(res)
222 }
223
224 /// Accesses the current task-local and runs the provided closure.
225 ///
226 /// # Panics
227 ///
228 /// This function will panic if the task local doesn't have a value set.
229 #[track_caller]
230 pub fn with<F, R>(&'static self, f: F) -> R
231 where
232 F: FnOnce(&T) -> R,
233 {
234 match self.try_with(f) {
235 Ok(res) => res,
236 Err(_) => panic!("cannot access a task-local storage value without setting it first"),
237 }
238 }
239
240 /// Accesses the current task-local and runs the provided closure.
241 ///
242 /// If the task-local with the associated key is not present, this
243 /// method will return an `AccessError`. For a panicking variant,
244 /// see `with`.
245 pub fn try_with<F, R>(&'static self, f: F) -> Result<R, AccessError>
246 where
247 F: FnOnce(&T) -> R,
248 {
249 // If called after the thread-local storing the task-local is destroyed,
250 // then we are outside of a closure where the task-local is set.
251 //
252 // Therefore, it is correct to return an AccessError if `try_with`
253 // returns an error.
254 let try_with_res = self.inner.try_with(|v| {
255 // This call to `borrow` cannot panic because no user-defined code
256 // runs while a `borrow_mut` call is active.
257 v.borrow().as_ref().map(f)
258 });
259
260 match try_with_res {
261 Ok(Some(res)) => Ok(res),
262 Ok(None) | Err(_) => Err(AccessError { _private: () }),
263 }
264 }
265}
266
267impl<T: Clone + 'static> LocalKey<T> {
268 /// Returns a copy of the task-local value
269 /// if the task-local value implements `Clone`.
270 ///
271 /// # Panics
272 ///
273 /// This function will panic if the task local doesn't have a value set.
274 #[track_caller]
275 pub fn get(&'static self) -> T {
276 self.with(|v| v.clone())
277 }
278
279 /// Returns a copy of the task-local value
280 /// if the task-local value implements `Clone`.
281 ///
282 /// If the task-local with the associated key is not present, this
283 /// method will return an `AccessError`. For a panicking variant,
284 /// see `get`.
285 pub fn try_get(&'static self) -> Result<T, AccessError> {
286 self.try_with(|v| v.clone())
287 }
288}
289
290impl<T: 'static> fmt::Debug for LocalKey<T> {
291 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
292 f.pad("LocalKey { .. }")
293 }
294}
295
296pin_project! {
297 /// A future that sets a value `T` of a task local for the future `F` during
298 /// its execution.
299 ///
300 /// The value of the task-local must be `'static` and will be dropped on the
301 /// completion of the future.
302 ///
303 /// Created by the function [`LocalKey::scope`](self::LocalKey::scope).
304 ///
305 /// ### Examples
306 ///
307 /// ```
308 /// # async fn dox() {
309 /// tokio::task_local! {
310 /// static NUMBER: u32;
311 /// }
312 ///
313 /// NUMBER.scope(1, async move {
314 /// println!("task local value: {}", NUMBER.get());
315 /// }).await;
316 /// # }
317 /// ```
318 pub struct TaskLocalFuture<T, F>
319 where
320 T: 'static,
321 {
322 local: &'static LocalKey<T>,
323 slot: Option<T>,
324 #[pin]
325 future: Option<F>,
326 #[pin]
327 _pinned: PhantomPinned,
328 }
329
330 impl<T: 'static, F> PinnedDrop for TaskLocalFuture<T, F> {
331 fn drop(this: Pin<&mut Self>) {
332 let this = this.project();
333 if mem::needs_drop::<F>() && this.future.is_some() {
334 // Drop the future while the task-local is set, if possible. Otherwise
335 // the future is dropped normally when the `Option<F>` field drops.
336 let mut future = this.future;
337 let _ = this.local.scope_inner(this.slot, || {
338 future.set(None);
339 });
340 }
341 }
342 }
343}
344
345impl<T, F> TaskLocalFuture<T, F>
346where
347 T: 'static,
348{
349 /// Returns the value stored in the task local by this `TaskLocalFuture`.
350 ///
351 /// The function returns:
352 ///
353 /// * `Some(T)` if the task local value exists.
354 /// * `None` if the task local value has already been taken.
355 ///
356 /// Note that this function attempts to take the task local value even if
357 /// the future has not yet completed. In that case, the value will no longer
358 /// be available via the task local after the call to `take_value`.
359 ///
360 /// # Examples
361 ///
362 /// ```
363 /// # async fn dox() {
364 /// tokio::task_local! {
365 /// static KEY: u32;
366 /// }
367 ///
368 /// let fut = KEY.scope(42, async {
369 /// // Do some async work
370 /// });
371 ///
372 /// let mut pinned = Box::pin(fut);
373 ///
374 /// // Complete the TaskLocalFuture
375 /// let _ = pinned.as_mut().await;
376 ///
377 /// // And here, we can take task local value
378 /// let value = pinned.as_mut().take_value();
379 ///
380 /// assert_eq!(value, Some(42));
381 /// # }
382 /// ```
383 pub fn take_value(self: Pin<&mut Self>) -> Option<T> {
384 let this = self.project();
385 this.slot.take()
386 }
387}
388
389impl<T: 'static, F: Future> Future for TaskLocalFuture<T, F> {
390 type Output = F::Output;
391
392 #[track_caller]
393 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
394 let this = self.project();
395 let mut future_opt = this.future;
396
397 let res = this
398 .local
399 .scope_inner(this.slot, || match future_opt.as_mut().as_pin_mut() {
400 Some(fut) => {
401 let res = fut.poll(cx);
402 if res.is_ready() {
403 future_opt.set(None);
404 }
405 Some(res)
406 }
407 None => None,
408 });
409
410 match res {
411 Ok(Some(res)) => res,
412 Ok(None) => panic!("`TaskLocalFuture` polled after completion"),
413 Err(err) => err.panic(),
414 }
415 }
416}
417
418impl<T: 'static, F> fmt::Debug for TaskLocalFuture<T, F>
419where
420 T: fmt::Debug,
421{
422 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
423 /// Format the Option without Some.
424 struct TransparentOption<'a, T> {
425 value: &'a Option<T>,
426 }
427 impl<'a, T: fmt::Debug> fmt::Debug for TransparentOption<'a, T> {
428 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
429 match self.value.as_ref() {
430 Some(value) => value.fmt(f),
431 // Hitting the None branch should not be possible.
432 None => f.pad("<missing>"),
433 }
434 }
435 }
436
437 f.debug_struct("TaskLocalFuture")
438 .field("value", &TransparentOption { value: &self.slot })
439 .finish()
440 }
441}
442
443/// An error returned by [`LocalKey::try_with`](method@LocalKey::try_with).
444#[derive(Clone, Copy, Eq, PartialEq)]
445pub struct AccessError {
446 _private: (),
447}
448
449impl fmt::Debug for AccessError {
450 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
451 f.debug_struct("AccessError").finish()
452 }
453}
454
455impl fmt::Display for AccessError {
456 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
457 fmt::Display::fmt("task-local value not set", f)
458 }
459}
460
461impl Error for AccessError {}
462
463enum ScopeInnerErr {
464 BorrowError,
465 AccessError,
466}
467
468impl ScopeInnerErr {
469 #[track_caller]
470 fn panic(&self) -> ! {
471 match self {
472 Self::BorrowError => panic!("cannot enter a task-local scope while the task-local storage is borrowed"),
473 Self::AccessError => panic!("cannot enter a task-local scope during or after destruction of the underlying thread-local"),
474 }
475 }
476}
477
478impl From<std::cell::BorrowMutError> for ScopeInnerErr {
479 fn from(_: std::cell::BorrowMutError) -> Self {
480 Self::BorrowError
481 }
482}
483
484impl From<std::thread::AccessError> for ScopeInnerErr {
485 fn from(_: std::thread::AccessError) -> Self {
486 Self::AccessError
487 }
488}