ferron_common/util/send_rw_stream.rs

1use std::pin::Pin;
2use std::task::{Context, Poll};
3
4use async_channel::{Receiver, Sender};
5use bytes::{Bytes, BytesMut};
6use futures_util::{Sink, Stream};
7use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
8use tokio_util::sync::CancellationToken;
9
/// Capacity, in bytes, of the buffer allocated for each read from the wrapped I/O object (16 KiB).
const MAX_BUFFER_SIZE: usize = 16 * 1024;
/// How many read results may be queued before the read task waits for the consumer.
const MAX_READ_CHANNEL_CAPACITY: usize = 2;
/// How many outgoing chunks may be queued before sends into the sink wait for the write task.
const MAX_WRITE_CHANNEL_CAPACITY: usize = 2;
13
14/// A wrapper over struct implementing Tokio's `AsyncRead` and `AsyncWrite` (no need for struct to be `Send`) that implements `Stream` and `Sink` trait.
15#[allow(clippy::type_complexity)]
16pub struct SendRwStream {
17  rx: Pin<Box<Receiver<Result<Bytes, std::io::Error>>>>,
18  tx: Pin<Box<dyn Sink<Bytes, Error = std::io::Error> + Send>>,
19  read_cancel: CancellationToken,
20  write_cancel: CancellationToken,
21}
22
23impl SendRwStream {
24  /// Creates a new stream and sink from a struct implementing Tokio's `AsyncRead` and `AsyncWrite`
25  pub fn new(stream: impl AsyncRead + AsyncWrite + Unpin + 'static) -> Self {
26    let (inner_tx, rx) = async_channel::bounded(MAX_READ_CHANNEL_CAPACITY);
27    let (tx, inner_rx) = async_channel::bounded(MAX_WRITE_CHANNEL_CAPACITY);
28    let (mut reader, mut writer) = tokio::io::split(stream);
29    let read_cancel = CancellationToken::new();
30    let write_cancel = CancellationToken::new();
31    let read_cancel_clone = read_cancel.clone();
32    let write_cancel_clone = write_cancel.clone();
33    monoio::spawn(async move {
34      loop {
35        let buffer_sz = MAX_BUFFER_SIZE;
36        if buffer_sz == 0 {
37          break;
38        }
39        let mut buffer = BytesMut::with_capacity(buffer_sz);
40        let io_result = monoio::select! {
41          biased;
42
43          _ = read_cancel_clone.cancelled() => {
44            break;
45          }
46          result = reader.read_buf(&mut buffer) => {
47            result
48          }
49        };
50        if let Ok(n) = io_result.as_ref() {
51          if n == &0 {
52            break;
53          }
54        }
55        let is_err = io_result.is_err();
56        if inner_tx
57          .send(io_result.map(move |n| {
58            buffer.truncate(n);
59            buffer.freeze()
60          }))
61          .await
62          .is_err()
63        {
64          return;
65        }
66        if is_err {
67          break;
68        }
69      }
70    });
71    monoio::spawn(async move {
72      loop {
73        let rx_read_result = monoio::select! {
74          biased;
75
76          result = inner_rx.recv() => {
77            result
78          }
79          _ = write_cancel_clone.cancelled() => {
80            break;
81          }
82        };
83        let mut bytes = match rx_read_result {
84          Ok(bytes) => bytes,
85          Err(_) => {
86            // `inner_rx` is closed, but not dropped, shutting down the writer
87            writer.shutdown().await.unwrap_or_default();
88            break;
89          }
90        };
91        if writer.write_all_buf(&mut bytes).await.is_err() {
92          break;
93        }
94        if writer.flush().await.is_err() {
95          break;
96        }
97      }
98      inner_rx.close();
99    });
100    let tx = futures_util::sink::unfold(SenderWrap { inner: tx }, async move |tx, data: Bytes| {
101      tx.send(data).await.map_err(std::io::Error::other).map(|_| tx)
102    });
103    Self {
104      rx: Box::pin(rx),
105      tx: Box::pin(tx),
106      read_cancel,
107      write_cancel,
108    }
109  }
110}
111
112impl Sink<Bytes> for SendRwStream {
113  type Error = std::io::Error;
114
115  fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
116    Pin::new(&mut self.tx).poll_close(cx)
117  }
118
119  fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
120    Pin::new(&mut self.tx).poll_flush(cx)
121  }
122
123  fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
124    Pin::new(&mut self.tx).poll_ready(cx)
125  }
126
127  fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
128    Pin::new(&mut self.tx).start_send(item)
129  }
130}
131
132impl Stream for SendRwStream {
133  type Item = Result<Bytes, std::io::Error>;
134
135  fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
136    Pin::new(&mut self.rx).poll_next(cx)
137  }
138}
139
140impl Drop for SendRwStream {
141  fn drop(&mut self) {
142    self.rx.close();
143    self.read_cancel.cancel();
144    self.write_cancel.cancel();
145  }
146}
147
/// Owns the sending half of the write channel used by the sink built in `SendRwStream::new`.
///
/// Its `Drop` impl closes the channel explicitly. The background write task then sees the channel
/// as closed (once drained) and shuts the writer down.
struct SenderWrap<T> {
  /// The wrapped channel sender.
  inner: Sender<T>,
}
151
152impl<T> SenderWrap<T> {
153  fn send(&self, data: T) -> async_channel::Send<'_, T> {
154    self.inner.send(data)
155  }
156}
157
158impl<T> Drop for SenderWrap<T> {
159  fn drop(&mut self) {
160    self.inner.close();
161  }
162}