monoio/io/
async_buf_read_ext.rs

1use std::{
2    future::Future,
3    io::{Error, ErrorKind, Result},
4    str::from_utf8,
5};
6
7use memchr::memchr;
8
9use crate::io::AsyncBufRead;
10
/// Drop guard that shrinks `buf` back to `len` bytes when it goes out of scope.
///
/// `read_line` appends raw bytes to a buffer and only raises `len` to the new length once the
/// appended bytes have passed UTF-8 validation; if validation fails (or the operation is
/// abandoned), dropping the guard removes the unvalidated tail again.
struct Guard<'a> {
    buf: &'a mut Vec<u8>,
    len: usize,
}

impl Drop for Guard<'_> {
    fn drop(&mut self) {
        // For `u8` this is exactly what `set_len` did whenever `len <= buf.len()` (the only
        // case callers produce), but it is safe: if `len` were ever larger it is a no-op
        // instead of exposing uninitialized memory.
        self.buf.truncate(self.len);
    }
}
23
24async fn read_until<A>(r: &mut A, delim: u8, buf: &mut Vec<u8>) -> Result<usize>
25where
26    A: AsyncBufRead + ?Sized,
27{
28    let mut read = 0;
29    loop {
30        let (done, used) = {
31            let available = match r.fill_buf().await {
32                Ok(n) => n,
33                Err(ref e) if e.kind() == ErrorKind::Interrupted => continue,
34                Err(e) => return Err(e),
35            };
36
37            match memchr(delim, available) {
38                Some(i) => {
39                    buf.extend_from_slice(&available[..=i]);
40                    (true, i + 1)
41                }
42                None => {
43                    buf.extend_from_slice(available);
44                    (false, available.len())
45                }
46            }
47        };
48        r.consume(used);
49        read += used;
50        if done || used == 0 {
51            return Ok(read);
52        }
53    }
54}
55
/// Extension methods for [`AsyncBufRead`] readers.
///
/// This trait is implemented for every `A: AsyncBufRead + ?Sized`, so bringing it into scope is
/// enough to call [`read_until`](AsyncBufReadExt::read_until) and
/// [`read_line`](AsyncBufReadExt::read_line) on any buffered reader.
pub trait AsyncBufReadExt {
    /// Reads bytes from the underlying stream until the delimiter `byte` or EOF is found.
    ///
    /// All bytes up to, and including, the delimiter (if found) are appended to `buf`; the
    /// existing contents of `buf` are left untouched.
    ///
    /// If successful, returns the total number of bytes read (equal to the number of bytes
    /// appended to `buf`). If this function returns `Ok(0)`, the stream has reached EOF.
    ///
    /// # Errors
    /// Instances of [`ErrorKind::Interrupted`] are ignored and the read is retried; any other
    /// error returned by `fill_buf` is returned as-is. Bytes appended before the error remain
    /// in `buf`.
    ///
    /// # Cancel safety
    /// The provided implementation appends bytes to `buf` in the same step in which it consumes
    /// them from the reader. If the returned future is dropped early, every byte already taken
    /// from the reader is therefore already in `buf`.
    fn read_until<'a>(
        &'a mut self,
        byte: u8,
        buf: &'a mut Vec<u8>,
    ) -> impl Future<Output = Result<usize>>;

    /// Reads bytes from the underlying stream until the newline delimiter (the `0xA` byte) or
    /// EOF is found. All bytes up to, and including, the delimiter (if found) are appended to
    /// `buf`.
    ///
    /// If successful, returns the total number of bytes read. If this function returns
    /// `Ok(0)`, the stream has reached EOF.
    ///
    /// # Errors
    /// Same error semantics as [`read_until`](AsyncBufReadExt::read_until). Additionally
    /// returns an [`ErrorKind::InvalidData`] error if the bytes read are not valid UTF-8; in
    /// that case nothing is appended to `buf`. If an I/O error is encountered, `buf` may
    /// contain some bytes already read, provided all data read so far was valid UTF-8.
    ///
    /// # Cancel safety
    /// Not cancellation safe. If the returned future is dropped before it completes, `buf`
    /// keeps its original contents and any bytes already consumed from the reader are lost.
    fn read_line<'a>(&'a mut self, buf: &'a mut String) -> impl Future<Output = Result<usize>>;
}
87
88impl<A> AsyncBufReadExt for A
89where
90    A: AsyncBufRead + ?Sized,
91{
92    fn read_until<'a>(
93        &'a mut self,
94        byte: u8,
95        buf: &'a mut Vec<u8>,
96    ) -> impl Future<Output = Result<usize>> {
97        read_until(self, byte, buf)
98    }
99
100    async fn read_line<'a>(&'a mut self, buf: &'a mut String) -> Result<usize> {
101        unsafe {
102            let mut g = Guard {
103                len: buf.len(),
104                buf: buf.as_mut_vec(),
105            };
106
107            let ret = read_until(self, b'\n', g.buf).await;
108            if from_utf8(&g.buf[g.len..]).is_err() {
109                ret.and_then(|_| {
110                    Err(Error::new(
111                        ErrorKind::InvalidData,
112                        "stream did not contain valid UTF-8",
113                    ))
114                })
115            } else {
116                g.len = g.buf.len();
117                ret
118            }
119        }
120    }
121}