ferron_common/util/
send_rw_stream.rs1use std::pin::Pin;
2use std::task::{Context, Poll};
3
4use async_channel::{Receiver, Sender};
5use bytes::{Bytes, BytesMut};
6use futures_util::{Sink, Stream};
7use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
8use tokio_util::sync::CancellationToken;
9
/// Capacity, in bytes, of the buffer allocated for each read from the wrapped stream.
const MAX_BUFFER_SIZE: usize = 16 * 1024;
/// Number of inbound chunks that can sit in the read channel before the read task waits.
const MAX_READ_CHANNEL_CAPACITY: usize = 2;
/// Number of outbound chunks that can sit in the write channel before senders wait.
const MAX_WRITE_CHANNEL_CAPACITY: usize = 2;
13
14pub struct SendRwStream {
16 rx: Pin<Box<Receiver<Result<Bytes, std::io::Error>>>>,
17 tx: Pin<Box<dyn Sink<Bytes, Error = std::io::Error> + Send>>,
18 read_cancel: CancellationToken,
19 write_cancel: CancellationToken,
20}
21
22impl SendRwStream {
23 pub fn new(stream: impl AsyncRead + AsyncWrite + Unpin + 'static) -> Self {
25 let (inner_tx, rx) = async_channel::bounded(MAX_READ_CHANNEL_CAPACITY);
26 let (tx, inner_rx) = async_channel::bounded(MAX_WRITE_CHANNEL_CAPACITY);
27 let (mut reader, mut writer) = tokio::io::split(stream);
28 let read_cancel = CancellationToken::new();
29 let write_cancel = CancellationToken::new();
30 let read_cancel_clone = read_cancel.clone();
31 let write_cancel_clone = write_cancel.clone();
32 monoio::spawn(async move {
33 loop {
34 let buffer_sz = MAX_BUFFER_SIZE;
35 if buffer_sz == 0 {
36 break;
37 }
38 let mut buffer = BytesMut::with_capacity(buffer_sz);
39 let io_result = monoio::select! {
40 biased;
41
42 _ = read_cancel_clone.cancelled() => {
43 break;
44 }
45 result = reader.read_buf(&mut buffer) => {
46 result
47 }
48 };
49 if let Ok(n) = io_result.as_ref() {
50 if n == &0 {
51 break;
52 }
53 }
54 let is_err = io_result.is_err();
55 if inner_tx
56 .send(io_result.map(move |n| {
57 buffer.truncate(n);
58 buffer.freeze()
59 }))
60 .await
61 .is_err()
62 {
63 return;
64 }
65 if is_err {
66 break;
67 }
68 }
69 });
70 monoio::spawn(async move {
71 loop {
72 let rx_read_result = monoio::select! {
73 biased;
74
75 result = inner_rx.recv() => {
76 result
77 }
78 _ = write_cancel_clone.cancelled() => {
79 break;
80 }
81 };
82 let mut bytes = match rx_read_result {
83 Ok(bytes) => bytes,
84 Err(_) => {
85 writer.shutdown().await.unwrap_or_default();
87 break;
88 }
89 };
90 if writer.write_all_buf(&mut bytes).await.is_err() {
91 break;
92 }
93 if writer.flush().await.is_err() {
94 break;
95 }
96 }
97 inner_rx.close();
98 });
99 let tx = futures_util::sink::unfold(SenderWrap { inner: tx }, async move |tx, data: Bytes| {
100 tx.send(data).await.map_err(std::io::Error::other).map(|_| tx)
101 });
102 Self {
103 rx: Box::pin(rx),
104 tx: Box::pin(tx),
105 read_cancel,
106 write_cancel,
107 }
108 }
109}
110
111impl Sink<Bytes> for SendRwStream {
112 type Error = std::io::Error;
113
114 #[inline]
115 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
116 Pin::new(&mut self.tx).poll_close(cx)
117 }
118
119 #[inline]
120 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
121 Pin::new(&mut self.tx).poll_flush(cx)
122 }
123
124 #[inline]
125 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
126 Pin::new(&mut self.tx).poll_ready(cx)
127 }
128
129 #[inline]
130 fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
131 Pin::new(&mut self.tx).start_send(item)
132 }
133}
134
135impl Stream for SendRwStream {
136 type Item = Result<Bytes, std::io::Error>;
137
138 #[inline]
139 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
140 Pin::new(&mut self.rx).poll_next(cx)
141 }
142}
143
144impl Drop for SendRwStream {
145 #[inline]
146 fn drop(&mut self) {
147 self.rx.close();
148 self.read_cancel.cancel();
149 self.write_cancel.cancel();
150 }
151}
152
153struct SenderWrap<T> {
154 inner: Sender<T>,
155}
156
157impl<T> SenderWrap<T> {
158 #[inline]
159 fn send(&self, data: T) -> async_channel::Send<'_, T> {
160 self.inner.send(data)
161 }
162}
163
164impl<T> Drop for SenderWrap<T> {
165 #[inline]
166 fn drop(&mut self) {
167 self.inner.close();
168 }
169}