ferron_common/util/
send_rw_stream.rs1use std::pin::Pin;
2use std::task::{Context, Poll};
3
4use async_channel::{Receiver, Sender};
5use bytes::{Bytes, BytesMut};
6use futures_util::{Sink, Stream};
7use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
8use tokio_util::sync::CancellationToken;
9
/// Capacity, in bytes, of the buffer allocated for each read from the wrapped stream.
const MAX_BUFFER_SIZE: usize = 16 * 1024;
/// How many read chunks may be queued in the read channel before the read task must wait.
const MAX_READ_CHANNEL_CAPACITY: usize = 2;
/// How many write buffers may be queued in the write channel before sink sends must wait.
const MAX_WRITE_CHANNEL_CAPACITY: usize = 2;
13
14#[allow(clippy::type_complexity)]
16pub struct SendRwStream {
17 rx: Pin<Box<Receiver<Result<Bytes, std::io::Error>>>>,
18 tx: Pin<Box<dyn Sink<Bytes, Error = std::io::Error> + Send>>,
19 read_cancel: CancellationToken,
20 write_cancel: CancellationToken,
21}
22
23impl SendRwStream {
24 pub fn new(stream: impl AsyncRead + AsyncWrite + Unpin + 'static) -> Self {
26 let (inner_tx, rx) = async_channel::bounded(MAX_READ_CHANNEL_CAPACITY);
27 let (tx, inner_rx) = async_channel::bounded(MAX_WRITE_CHANNEL_CAPACITY);
28 let (mut reader, mut writer) = tokio::io::split(stream);
29 let read_cancel = CancellationToken::new();
30 let write_cancel = CancellationToken::new();
31 let read_cancel_clone = read_cancel.clone();
32 let write_cancel_clone = write_cancel.clone();
33 monoio::spawn(async move {
34 loop {
35 let buffer_sz = MAX_BUFFER_SIZE;
36 if buffer_sz == 0 {
37 break;
38 }
39 let mut buffer = BytesMut::with_capacity(buffer_sz);
40 let io_result = monoio::select! {
41 biased;
42
43 _ = read_cancel_clone.cancelled() => {
44 break;
45 }
46 result = reader.read_buf(&mut buffer) => {
47 result
48 }
49 };
50 if let Ok(n) = io_result.as_ref() {
51 if n == &0 {
52 break;
53 }
54 }
55 let is_err = io_result.is_err();
56 if inner_tx
57 .send(io_result.map(move |n| {
58 buffer.truncate(n);
59 buffer.freeze()
60 }))
61 .await
62 .is_err()
63 {
64 return;
65 }
66 if is_err {
67 break;
68 }
69 }
70 });
71 monoio::spawn(async move {
72 loop {
73 let rx_read_result = monoio::select! {
74 biased;
75
76 result = inner_rx.recv() => {
77 result
78 }
79 _ = write_cancel_clone.cancelled() => {
80 break;
81 }
82 };
83 let mut bytes = match rx_read_result {
84 Ok(bytes) => bytes,
85 Err(_) => {
86 writer.shutdown().await.unwrap_or_default();
88 break;
89 }
90 };
91 if writer.write_all_buf(&mut bytes).await.is_err() {
92 break;
93 }
94 if writer.flush().await.is_err() {
95 break;
96 }
97 }
98 inner_rx.close();
99 });
100 let tx = futures_util::sink::unfold(SenderWrap { inner: tx }, async move |tx, data: Bytes| {
101 tx.send(data).await.map_err(std::io::Error::other).map(|_| tx)
102 });
103 Self {
104 rx: Box::pin(rx),
105 tx: Box::pin(tx),
106 read_cancel,
107 write_cancel,
108 }
109 }
110}
111
112impl Sink<Bytes> for SendRwStream {
113 type Error = std::io::Error;
114
115 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
116 Pin::new(&mut self.tx).poll_close(cx)
117 }
118
119 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
120 Pin::new(&mut self.tx).poll_flush(cx)
121 }
122
123 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
124 Pin::new(&mut self.tx).poll_ready(cx)
125 }
126
127 fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
128 Pin::new(&mut self.tx).start_send(item)
129 }
130}
131
132impl Stream for SendRwStream {
133 type Item = Result<Bytes, std::io::Error>;
134
135 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
136 Pin::new(&mut self.rx).poll_next(cx)
137 }
138}
139
140impl Drop for SendRwStream {
141 fn drop(&mut self) {
142 self.rx.close();
143 self.read_cancel.cancel();
144 self.write_cancel.cancel();
145 }
146}
147
/// Owner of a channel [`Sender`] that explicitly closes the channel when dropped
/// (see its `Drop` impl).
struct SenderWrap<T> {
    inner: Sender<T>,
}
151
152impl<T> SenderWrap<T> {
153 fn send(&self, data: T) -> async_channel::Send<'_, T> {
154 self.inner.send(data)
155 }
156}
157
158impl<T> Drop for SenderWrap<T> {
159 fn drop(&mut self) {
160 self.inner.close();
161 }
162}