//! monoio/io/util/buf_writer.rs

1use std::{future::Future, io};
2
3use crate::{
4    buf::{IoBuf, IoBufMut, IoVecBuf, IoVecBufMut, IoVecWrapper, Slice},
5    io::{AsyncBufRead, AsyncReadRent, AsyncWriteRent, AsyncWriteRentExt},
6    BufResult,
7};
8
/// BufWriter is a struct with a buffer. BufWriter implements AsyncWriteRent,
/// and if the inner io implements AsyncReadRent, it will delegate the
/// implementation.
pub struct BufWriter<W> {
    // The wrapped writer; all data eventually goes here.
    inner: W,
    // Internal write buffer. Wrapped in `Option` so it can be `take`n
    // (moved out) while an async write on the inner io is in flight and
    // put back when the write completes (see `flush_buf`).
    buf: Option<Box<[u8]>>,
    // Start of the not-yet-flushed data inside `buf`; the valid buffered
    // bytes live in `pos..cap` (see `buffer()` and `flush_buf`).
    pos: usize,
    // End of the buffered data; incoming writes are appended at `cap`.
    cap: usize,
}
18
/// Default capacity of the internal write buffer (8 KiB).
const DEFAULT_BUF_SIZE: usize = 8 * 1024;
20
21impl<W> BufWriter<W> {
22    /// Create BufWriter with default buffer size
23    #[inline]
24    pub fn new(inner: W) -> Self {
25        Self::with_capacity(DEFAULT_BUF_SIZE, inner)
26    }
27
28    /// Create BufWriter with given buffer size
29    #[inline]
30    pub fn with_capacity(capacity: usize, inner: W) -> Self {
31        let buffer = vec![0; capacity];
32        Self {
33            inner,
34            buf: Some(buffer.into_boxed_slice()),
35            pos: 0,
36            cap: 0,
37        }
38    }
39
40    /// Gets a reference to the underlying writer.
41    #[inline]
42    pub fn get_ref(&self) -> &W {
43        &self.inner
44    }
45
46    /// Gets a mutable reference to the underlying writer.
47    #[inline]
48    pub fn get_mut(&mut self) -> &mut W {
49        &mut self.inner
50    }
51
52    /// Consumes this `BufWriter`, returning the underlying writer.
53    ///
54    /// Note that any leftover data in the internal buffer is lost.
55    #[inline]
56    pub fn into_inner(self) -> W {
57        self.inner
58    }
59
60    /// Returns a reference to the internally buffered data.
61    #[inline]
62    pub fn buffer(&self) -> &[u8] {
63        &self.buf.as_ref().expect("unable to take buffer")[self.pos..self.cap]
64    }
65
66    /// Invalidates all data in the internal buffer.
67    #[inline]
68    fn discard_buffer(&mut self) {
69        self.pos = 0;
70        self.cap = 0;
71    }
72}
73
74impl<W: AsyncWriteRent> BufWriter<W> {
75    async fn flush_buf(&mut self) -> io::Result<()> {
76        if self.pos != self.cap {
77            // there is some data left inside internal buf
78            let buf = self
79                .buf
80                .take()
81                .expect("no buffer available, generated future must be awaited");
82            // move buf to slice and write_all
83            let slice = Slice::new(buf, self.pos, self.cap);
84            let (ret, slice) = self.inner.write_all(slice).await;
85            // move it back and return
86            self.buf = Some(slice.into_inner());
87            ret?;
88            self.discard_buffer();
89        }
90        Ok(())
91    }
92}
93
94impl<W: AsyncWriteRent> AsyncWriteRent for BufWriter<W> {
95    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
96        let owned_buf = self.buf.as_ref().unwrap();
97        let owned_len = owned_buf.len();
98        let amt = buf.bytes_init();
99
100        if self.pos + amt > owned_len {
101            // Buf can not be copied directly into OwnedBuf,
102            // we must flush OwnedBuf first.
103            match self.flush_buf().await {
104                Ok(_) => (),
105                Err(e) => {
106                    return (Err(e), buf);
107                }
108            }
109        }
110
111        // Now there are two situations here:
112        // 1. OwnedBuf has data, and self.pos + amt <= owned_len,
113        // which means the data can be copied into OwnedBuf.
114        // 2. OwnedBuf is empty. If we can copy buf into OwnedBuf,
115        // we will copy it, otherwise we will send it directly(in
116        // this situation, the OwnedBuf must be already empty).
117        if amt > owned_len {
118            self.inner.write(buf).await
119        } else {
120            unsafe {
121                let owned_buf = self.buf.as_mut().unwrap();
122                owned_buf
123                    .as_mut_ptr()
124                    .add(self.cap)
125                    .copy_from_nonoverlapping(buf.read_ptr(), amt);
126            }
127            self.cap += amt;
128            (Ok(amt), buf)
129        }
130    }
131
132    // TODO: implement it as real io_vec
133    async fn writev<T: IoVecBuf>(&mut self, buf: T) -> BufResult<usize, T> {
134        let slice = match IoVecWrapper::new(buf) {
135            Ok(slice) => slice,
136            Err(buf) => return (Ok(0), buf),
137        };
138
139        let (result, slice) = self.write(slice).await;
140        (result, slice.into_inner())
141    }
142
143    async fn flush(&mut self) -> std::io::Result<()> {
144        self.flush_buf().await?;
145        self.inner.flush().await
146    }
147
148    async fn shutdown(&mut self) -> std::io::Result<()> {
149        self.flush_buf().await?;
150        self.inner.shutdown().await
151    }
152}
153
impl<W: AsyncWriteRent + AsyncReadRent> AsyncReadRent for BufWriter<W> {
    // Reads are not buffered by this type: `BufWriter` only buffers the
    // write side, so both read methods forward directly to the inner io.
    #[inline]
    fn read<T: IoBufMut>(&mut self, buf: T) -> impl Future<Output = BufResult<usize, T>> {
        self.inner.read(buf)
    }

    #[inline]
    fn readv<T: IoVecBufMut>(&mut self, buf: T) -> impl Future<Output = BufResult<usize, T>> {
        self.inner.readv(buf)
    }
}
165
impl<W: AsyncWriteRent + AsyncBufRead> AsyncBufRead for BufWriter<W> {
    // Buffered reading is delegated entirely to the inner io; this type's
    // own buffer is used only for writes. Note the returned `fill_buf`
    // future (and the consumed amount) refer to the INNER io's buffer.
    #[inline]
    fn fill_buf(&mut self) -> impl Future<Output = std::io::Result<&[u8]>> {
        self.inner.fill_buf()
    }

    #[inline]
    fn consume(&mut self, amt: usize) {
        self.inner.consume(amt)
    }
}