• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
2 
3 use std::future::Future;
4 use std::io;
5 use std::pin::Pin;
6 use std::task::{Context, Poll};
7 
8 #[derive(Debug)]
9 pub(super) struct CopyBuffer {
10     read_done: bool,
11     need_flush: bool,
12     pos: usize,
13     cap: usize,
14     amt: u64,
15     buf: Box<[u8]>,
16 }
17 
18 impl CopyBuffer {
new() -> Self19     pub(super) fn new() -> Self {
20         Self {
21             read_done: false,
22             need_flush: false,
23             pos: 0,
24             cap: 0,
25             amt: 0,
26             buf: vec![0; super::DEFAULT_BUF_SIZE].into_boxed_slice(),
27         }
28     }
29 
poll_fill_buf<R>( &mut self, cx: &mut Context<'_>, reader: Pin<&mut R>, ) -> Poll<io::Result<()>> where R: AsyncRead + ?Sized,30     fn poll_fill_buf<R>(
31         &mut self,
32         cx: &mut Context<'_>,
33         reader: Pin<&mut R>,
34     ) -> Poll<io::Result<()>>
35     where
36         R: AsyncRead + ?Sized,
37     {
38         let me = &mut *self;
39         let mut buf = ReadBuf::new(&mut me.buf);
40         buf.set_filled(me.cap);
41 
42         let res = reader.poll_read(cx, &mut buf);
43         if let Poll::Ready(Ok(_)) = res {
44             let filled_len = buf.filled().len();
45             me.read_done = me.cap == filled_len;
46             me.cap = filled_len;
47         }
48         res
49     }
50 
poll_write_buf<R, W>( &mut self, cx: &mut Context<'_>, mut reader: Pin<&mut R>, mut writer: Pin<&mut W>, ) -> Poll<io::Result<usize>> where R: AsyncRead + ?Sized, W: AsyncWrite + ?Sized,51     fn poll_write_buf<R, W>(
52         &mut self,
53         cx: &mut Context<'_>,
54         mut reader: Pin<&mut R>,
55         mut writer: Pin<&mut W>,
56     ) -> Poll<io::Result<usize>>
57     where
58         R: AsyncRead + ?Sized,
59         W: AsyncWrite + ?Sized,
60     {
61         let me = &mut *self;
62         match writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]) {
63             Poll::Pending => {
64                 // Top up the buffer towards full if we can read a bit more
65                 // data - this should improve the chances of a large write
66                 if !me.read_done && me.cap < me.buf.len() {
67                     ready!(me.poll_fill_buf(cx, reader.as_mut()))?;
68                 }
69                 Poll::Pending
70             }
71             res => res,
72         }
73     }
74 
poll_copy<R, W>( &mut self, cx: &mut Context<'_>, mut reader: Pin<&mut R>, mut writer: Pin<&mut W>, ) -> Poll<io::Result<u64>> where R: AsyncRead + ?Sized, W: AsyncWrite + ?Sized,75     pub(super) fn poll_copy<R, W>(
76         &mut self,
77         cx: &mut Context<'_>,
78         mut reader: Pin<&mut R>,
79         mut writer: Pin<&mut W>,
80     ) -> Poll<io::Result<u64>>
81     where
82         R: AsyncRead + ?Sized,
83         W: AsyncWrite + ?Sized,
84     {
85         loop {
86             // If our buffer is empty, then we need to read some data to
87             // continue.
88             if self.pos == self.cap && !self.read_done {
89                 self.pos = 0;
90                 self.cap = 0;
91 
92                 match self.poll_fill_buf(cx, reader.as_mut()) {
93                     Poll::Ready(Ok(_)) => (),
94                     Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
95                     Poll::Pending => {
96                         // Try flushing when the reader has no progress to avoid deadlock
97                         // when the reader depends on buffered writer.
98                         if self.need_flush {
99                             ready!(writer.as_mut().poll_flush(cx))?;
100                             self.need_flush = false;
101                         }
102 
103                         return Poll::Pending;
104                     }
105                 }
106             }
107 
108             // If our buffer has some data, let's write it out!
109             while self.pos < self.cap {
110                 let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
111                 if i == 0 {
112                     return Poll::Ready(Err(io::Error::new(
113                         io::ErrorKind::WriteZero,
114                         "write zero byte into writer",
115                     )));
116                 } else {
117                     self.pos += i;
118                     self.amt += i as u64;
119                     self.need_flush = true;
120                 }
121             }
122 
123             // If pos larger than cap, this loop will never stop.
124             // In particular, user's wrong poll_write implementation returning
125             // incorrect written length may lead to thread blocking.
126             debug_assert!(
127                 self.pos <= self.cap,
128                 "writer returned length larger than input slice"
129             );
130 
131             // If we've written all the data and we've seen EOF, flush out the
132             // data and finish the transfer.
133             if self.pos == self.cap && self.read_done {
134                 ready!(writer.as_mut().poll_flush(cx))?;
135                 return Poll::Ready(Ok(self.amt));
136             }
137         }
138     }
139 }
140 
141 /// A future that asynchronously copies the entire contents of a reader into a
142 /// writer.
143 #[derive(Debug)]
144 #[must_use = "futures do nothing unless you `.await` or poll them"]
145 struct Copy<'a, R: ?Sized, W: ?Sized> {
146     reader: &'a mut R,
147     writer: &'a mut W,
148     buf: CopyBuffer,
149 }
150 
151 cfg_io_util! {
152     /// Asynchronously copies the entire contents of a reader into a writer.
153     ///
154     /// This function returns a future that will continuously read data from
155     /// `reader` and then write it into `writer` in a streaming fashion until
156     /// `reader` returns EOF.
157     ///
158     /// On success, the total number of bytes that were copied from `reader` to
159     /// `writer` is returned.
160     ///
161     /// This is an asynchronous version of [`std::io::copy`][std].
162     ///
163     /// [std]: std::io::copy
164     ///
165     /// # Errors
166     ///
167     /// The returned future will return an error immediately if any call to
168     /// `poll_read` or `poll_write` returns an error.
169     ///
170     /// # Examples
171     ///
172     /// ```
173     /// use tokio::io;
174     ///
175     /// # async fn dox() -> std::io::Result<()> {
176     /// let mut reader: &[u8] = b"hello";
177     /// let mut writer: Vec<u8> = vec![];
178     ///
179     /// io::copy(&mut reader, &mut writer).await?;
180     ///
181     /// assert_eq!(&b"hello"[..], &writer[..]);
182     /// # Ok(())
183     /// # }
184     /// ```
185     pub async fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64>
186     where
187         R: AsyncRead + Unpin + ?Sized,
188         W: AsyncWrite + Unpin + ?Sized,
189     {
190         Copy {
191             reader,
192             writer,
193             buf: CopyBuffer::new()
194         }.await
195     }
196 }
197 
198 impl<R, W> Future for Copy<'_, R, W>
199 where
200     R: AsyncRead + Unpin + ?Sized,
201     W: AsyncWrite + Unpin + ?Sized,
202 {
203     type Output = io::Result<u64>;
204 
poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>>205     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
206         let me = &mut *self;
207 
208         me.buf
209             .poll_copy(cx, Pin::new(&mut *me.reader), Pin::new(&mut *me.writer))
210     }
211 }
212