• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
//! Compatibility between the `tokio::io` and `futures-io` versions of the
//! `AsyncRead`, `AsyncWrite`, `AsyncBufRead`, and `AsyncSeek` traits.
3 use futures_core::ready;
4 use pin_project_lite::pin_project;
5 use std::io;
6 use std::pin::Pin;
7 use std::task::{Context, Poll};
8 
9 pin_project! {
10     /// A compatibility layer that allows conversion between the
11     /// `tokio::io` and `futures-io` `AsyncRead` and `AsyncWrite` traits.
12     #[derive(Copy, Clone, Debug)]
13     pub struct Compat<T> {
14         #[pin]
15         inner: T,
16         seek_pos: Option<io::SeekFrom>,
17     }
18 }
19 
20 /// Extension trait that allows converting a type implementing
21 /// `futures_io::AsyncRead` to implement `tokio::io::AsyncRead`.
22 pub trait FuturesAsyncReadCompatExt: futures_io::AsyncRead {
23     /// Wraps `self` with a compatibility layer that implements
24     /// `tokio_io::AsyncRead`.
compat(self) -> Compat<Self> where Self: Sized,25     fn compat(self) -> Compat<Self>
26     where
27         Self: Sized,
28     {
29         Compat::new(self)
30     }
31 }
32 
33 impl<T: futures_io::AsyncRead> FuturesAsyncReadCompatExt for T {}
34 
35 /// Extension trait that allows converting a type implementing
36 /// `futures_io::AsyncWrite` to implement `tokio::io::AsyncWrite`.
37 pub trait FuturesAsyncWriteCompatExt: futures_io::AsyncWrite {
38     /// Wraps `self` with a compatibility layer that implements
39     /// `tokio::io::AsyncWrite`.
compat_write(self) -> Compat<Self> where Self: Sized,40     fn compat_write(self) -> Compat<Self>
41     where
42         Self: Sized,
43     {
44         Compat::new(self)
45     }
46 }
47 
48 impl<T: futures_io::AsyncWrite> FuturesAsyncWriteCompatExt for T {}
49 
50 /// Extension trait that allows converting a type implementing
51 /// `tokio::io::AsyncRead` to implement `futures_io::AsyncRead`.
52 pub trait TokioAsyncReadCompatExt: tokio::io::AsyncRead {
53     /// Wraps `self` with a compatibility layer that implements
54     /// `futures_io::AsyncRead`.
compat(self) -> Compat<Self> where Self: Sized,55     fn compat(self) -> Compat<Self>
56     where
57         Self: Sized,
58     {
59         Compat::new(self)
60     }
61 }
62 
63 impl<T: tokio::io::AsyncRead> TokioAsyncReadCompatExt for T {}
64 
65 /// Extension trait that allows converting a type implementing
66 /// `tokio::io::AsyncWrite` to implement `futures_io::AsyncWrite`.
67 pub trait TokioAsyncWriteCompatExt: tokio::io::AsyncWrite {
68     /// Wraps `self` with a compatibility layer that implements
69     /// `futures_io::AsyncWrite`.
compat_write(self) -> Compat<Self> where Self: Sized,70     fn compat_write(self) -> Compat<Self>
71     where
72         Self: Sized,
73     {
74         Compat::new(self)
75     }
76 }
77 
78 impl<T: tokio::io::AsyncWrite> TokioAsyncWriteCompatExt for T {}
79 
80 // === impl Compat ===
81 
82 impl<T> Compat<T> {
new(inner: T) -> Self83     fn new(inner: T) -> Self {
84         Self {
85             inner,
86             seek_pos: None,
87         }
88     }
89 
90     /// Get a reference to the `Future`, `Stream`, `AsyncRead`, or `AsyncWrite` object
91     /// contained within.
get_ref(&self) -> &T92     pub fn get_ref(&self) -> &T {
93         &self.inner
94     }
95 
96     /// Get a mutable reference to the `Future`, `Stream`, `AsyncRead`, or `AsyncWrite` object
97     /// contained within.
get_mut(&mut self) -> &mut T98     pub fn get_mut(&mut self) -> &mut T {
99         &mut self.inner
100     }
101 
102     /// Returns the wrapped item.
into_inner(self) -> T103     pub fn into_inner(self) -> T {
104         self.inner
105     }
106 }
107 
108 impl<T> tokio::io::AsyncRead for Compat<T>
109 where
110     T: futures_io::AsyncRead,
111 {
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, ) -> Poll<io::Result<()>>112     fn poll_read(
113         self: Pin<&mut Self>,
114         cx: &mut Context<'_>,
115         buf: &mut tokio::io::ReadBuf<'_>,
116     ) -> Poll<io::Result<()>> {
117         // We can't trust the inner type to not peak at the bytes,
118         // so we must defensively initialize the buffer.
119         let slice = buf.initialize_unfilled();
120         let n = ready!(futures_io::AsyncRead::poll_read(
121             self.project().inner,
122             cx,
123             slice
124         ))?;
125         buf.advance(n);
126         Poll::Ready(Ok(()))
127     }
128 }
129 
130 impl<T> futures_io::AsyncRead for Compat<T>
131 where
132     T: tokio::io::AsyncRead,
133 {
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, slice: &mut [u8], ) -> Poll<io::Result<usize>>134     fn poll_read(
135         self: Pin<&mut Self>,
136         cx: &mut Context<'_>,
137         slice: &mut [u8],
138     ) -> Poll<io::Result<usize>> {
139         let mut buf = tokio::io::ReadBuf::new(slice);
140         ready!(tokio::io::AsyncRead::poll_read(
141             self.project().inner,
142             cx,
143             &mut buf
144         ))?;
145         Poll::Ready(Ok(buf.filled().len()))
146     }
147 }
148 
149 impl<T> tokio::io::AsyncBufRead for Compat<T>
150 where
151     T: futures_io::AsyncBufRead,
152 {
poll_fill_buf<'a>( self: Pin<&'a mut Self>, cx: &mut Context<'_>, ) -> Poll<io::Result<&'a [u8]>>153     fn poll_fill_buf<'a>(
154         self: Pin<&'a mut Self>,
155         cx: &mut Context<'_>,
156     ) -> Poll<io::Result<&'a [u8]>> {
157         futures_io::AsyncBufRead::poll_fill_buf(self.project().inner, cx)
158     }
159 
consume(self: Pin<&mut Self>, amt: usize)160     fn consume(self: Pin<&mut Self>, amt: usize) {
161         futures_io::AsyncBufRead::consume(self.project().inner, amt)
162     }
163 }
164 
165 impl<T> futures_io::AsyncBufRead for Compat<T>
166 where
167     T: tokio::io::AsyncBufRead,
168 {
poll_fill_buf<'a>( self: Pin<&'a mut Self>, cx: &mut Context<'_>, ) -> Poll<io::Result<&'a [u8]>>169     fn poll_fill_buf<'a>(
170         self: Pin<&'a mut Self>,
171         cx: &mut Context<'_>,
172     ) -> Poll<io::Result<&'a [u8]>> {
173         tokio::io::AsyncBufRead::poll_fill_buf(self.project().inner, cx)
174     }
175 
consume(self: Pin<&mut Self>, amt: usize)176     fn consume(self: Pin<&mut Self>, amt: usize) {
177         tokio::io::AsyncBufRead::consume(self.project().inner, amt)
178     }
179 }
180 
181 impl<T> tokio::io::AsyncWrite for Compat<T>
182 where
183     T: futures_io::AsyncWrite,
184 {
poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>185     fn poll_write(
186         self: Pin<&mut Self>,
187         cx: &mut Context<'_>,
188         buf: &[u8],
189     ) -> Poll<io::Result<usize>> {
190         futures_io::AsyncWrite::poll_write(self.project().inner, cx, buf)
191     }
192 
poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>193     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
194         futures_io::AsyncWrite::poll_flush(self.project().inner, cx)
195     }
196 
poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>197     fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
198         futures_io::AsyncWrite::poll_close(self.project().inner, cx)
199     }
200 }
201 
202 impl<T> futures_io::AsyncWrite for Compat<T>
203 where
204     T: tokio::io::AsyncWrite,
205 {
poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>206     fn poll_write(
207         self: Pin<&mut Self>,
208         cx: &mut Context<'_>,
209         buf: &[u8],
210     ) -> Poll<io::Result<usize>> {
211         tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf)
212     }
213 
poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>214     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
215         tokio::io::AsyncWrite::poll_flush(self.project().inner, cx)
216     }
217 
poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>218     fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
219         tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx)
220     }
221 }
222 
223 impl<T: tokio::io::AsyncSeek> futures_io::AsyncSeek for Compat<T> {
poll_seek( mut self: Pin<&mut Self>, cx: &mut Context<'_>, pos: io::SeekFrom, ) -> Poll<io::Result<u64>>224     fn poll_seek(
225         mut self: Pin<&mut Self>,
226         cx: &mut Context<'_>,
227         pos: io::SeekFrom,
228     ) -> Poll<io::Result<u64>> {
229         if self.seek_pos != Some(pos) {
230             // Ensure previous seeks have finished before starting a new one
231             ready!(self.as_mut().project().inner.poll_complete(cx))?;
232             self.as_mut().project().inner.start_seek(pos)?;
233             *self.as_mut().project().seek_pos = Some(pos);
234         }
235         let res = ready!(self.as_mut().project().inner.poll_complete(cx));
236         *self.as_mut().project().seek_pos = None;
237         Poll::Ready(res)
238     }
239 }
240 
241 impl<T: futures_io::AsyncSeek> tokio::io::AsyncSeek for Compat<T> {
start_seek(mut self: Pin<&mut Self>, pos: io::SeekFrom) -> io::Result<()>242     fn start_seek(mut self: Pin<&mut Self>, pos: io::SeekFrom) -> io::Result<()> {
243         *self.as_mut().project().seek_pos = Some(pos);
244         Ok(())
245     }
246 
poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>>247     fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
248         let pos = match self.seek_pos {
249             None => {
250                 // tokio 1.x AsyncSeek recommends calling poll_complete before start_seek.
251                 // We don't have to guarantee that the value returned by
252                 // poll_complete called without start_seek is correct,
253                 // so we'll return 0.
254                 return Poll::Ready(Ok(0));
255             }
256             Some(pos) => pos,
257         };
258         let res = ready!(self.as_mut().project().inner.poll_seek(cx, pos));
259         *self.as_mut().project().seek_pos = None;
260         Poll::Ready(res)
261     }
262 }
263 
264 #[cfg(unix)]
265 impl<T: std::os::unix::io::AsRawFd> std::os::unix::io::AsRawFd for Compat<T> {
as_raw_fd(&self) -> std::os::unix::io::RawFd266     fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
267         self.inner.as_raw_fd()
268     }
269 }
270 
#[cfg(windows)]
impl<T> std::os::windows::io::AsRawHandle for Compat<T>
where
    T: std::os::windows::io::AsRawHandle,
{
    /// Exposes the raw handle of the wrapped object.
    fn as_raw_handle(&self) -> std::os::windows::io::RawHandle {
        std::os::windows::io::AsRawHandle::as_raw_handle(&self.inner)
    }
}
277