ferron_common/util/send_rw_stream.rs

1use std::pin::Pin;
2use std::task::{Context, Poll};
3
4use async_channel::{Receiver, Sender};
5use bytes::{Bytes, BytesMut};
6use futures_util::{Sink, Stream};
7use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
8use tokio_util::sync::CancellationToken;
9
/// Capacity, in bytes, of the buffer allocated for each read from the wrapped stream (16 KiB).
const MAX_BUFFER_SIZE: usize = 16 * 1024;
/// How many read results may sit in the read channel before the reader task waits for the consumer.
const MAX_READ_CHANNEL_CAPACITY: usize = 2;
/// How many outgoing chunks may sit in the write channel before sending waits for the writer task.
const MAX_WRITE_CHANNEL_CAPACITY: usize = 2;
13
14/// A wrapper over struct implementing Tokio's `AsyncRead` and `AsyncWrite` (no need for struct to be `Send`) that implements `Stream` and `Sink` trait.
15pub struct SendRwStream {
16  rx: Pin<Box<Receiver<Result<Bytes, std::io::Error>>>>,
17  tx: Pin<Box<dyn Sink<Bytes, Error = std::io::Error> + Send>>,
18  read_cancel: CancellationToken,
19  write_cancel: CancellationToken,
20}
21
22impl SendRwStream {
23  /// Creates a new stream and sink from a struct implementing Tokio's `AsyncRead` and `AsyncWrite`
24  pub fn new(stream: impl AsyncRead + AsyncWrite + Unpin + 'static) -> Self {
25    let (inner_tx, rx) = async_channel::bounded(MAX_READ_CHANNEL_CAPACITY);
26    let (tx, inner_rx) = async_channel::bounded(MAX_WRITE_CHANNEL_CAPACITY);
27    let (mut reader, mut writer) = tokio::io::split(stream);
28    let read_cancel = CancellationToken::new();
29    let write_cancel = CancellationToken::new();
30    let read_cancel_clone = read_cancel.clone();
31    let write_cancel_clone = write_cancel.clone();
32    monoio::spawn(async move {
33      loop {
34        let buffer_sz = MAX_BUFFER_SIZE;
35        if buffer_sz == 0 {
36          break;
37        }
38        let mut buffer = BytesMut::with_capacity(buffer_sz);
39        let io_result = monoio::select! {
40          biased;
41
42          _ = read_cancel_clone.cancelled() => {
43            break;
44          }
45          result = reader.read_buf(&mut buffer) => {
46            result
47          }
48        };
49        if let Ok(n) = io_result.as_ref() {
50          if n == &0 {
51            break;
52          }
53        }
54        let is_err = io_result.is_err();
55        if inner_tx
56          .send(io_result.map(move |n| {
57            buffer.truncate(n);
58            buffer.freeze()
59          }))
60          .await
61          .is_err()
62        {
63          return;
64        }
65        if is_err {
66          break;
67        }
68      }
69    });
70    monoio::spawn(async move {
71      loop {
72        let rx_read_result = monoio::select! {
73          biased;
74
75          result = inner_rx.recv() => {
76            result
77          }
78          _ = write_cancel_clone.cancelled() => {
79            break;
80          }
81        };
82        let mut bytes = match rx_read_result {
83          Ok(bytes) => bytes,
84          Err(_) => {
85            // `inner_rx` is closed, but not dropped, shutting down the writer
86            writer.shutdown().await.unwrap_or_default();
87            break;
88          }
89        };
90        if writer.write_all_buf(&mut bytes).await.is_err() {
91          break;
92        }
93        if writer.flush().await.is_err() {
94          break;
95        }
96      }
97      inner_rx.close();
98    });
99    let tx = futures_util::sink::unfold(SenderWrap { inner: tx }, async move |tx, data: Bytes| {
100      tx.send(data).await.map_err(std::io::Error::other).map(|_| tx)
101    });
102    Self {
103      rx: Box::pin(rx),
104      tx: Box::pin(tx),
105      read_cancel,
106      write_cancel,
107    }
108  }
109}
110
111impl Sink<Bytes> for SendRwStream {
112  type Error = std::io::Error;
113
114  #[inline]
115  fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
116    Pin::new(&mut self.tx).poll_close(cx)
117  }
118
119  #[inline]
120  fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
121    Pin::new(&mut self.tx).poll_flush(cx)
122  }
123
124  #[inline]
125  fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
126    Pin::new(&mut self.tx).poll_ready(cx)
127  }
128
129  #[inline]
130  fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
131    Pin::new(&mut self.tx).start_send(item)
132  }
133}
134
135impl Stream for SendRwStream {
136  type Item = Result<Bytes, std::io::Error>;
137
138  #[inline]
139  fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
140    Pin::new(&mut self.rx).poll_next(cx)
141  }
142}
143
144impl Drop for SendRwStream {
145  #[inline]
146  fn drop(&mut self) {
147    self.rx.close();
148    self.read_cancel.cancel();
149    self.write_cancel.cancel();
150  }
151}
152
153struct SenderWrap<T> {
154  inner: Sender<T>,
155}
156
157impl<T> SenderWrap<T> {
158  #[inline]
159  fn send(&self, data: T) -> async_channel::Send<'_, T> {
160    self.inner.send(data)
161  }
162}
163
164impl<T> Drop for SenderWrap<T> {
165  #[inline]
166  fn drop(&mut self) {
167    self.inner.close();
168  }
169}