• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use futures_core::ready;
2 use futures_sink::Sink;
3 
4 use futures_core::stream::Stream;
5 use pin_project_lite::pin_project;
6 use std::io;
7 use std::pin::Pin;
8 use std::task::{Context, Poll};
9 use tokio::io::{AsyncRead, AsyncWrite};
10 
11 pin_project! {
12     /// Convert a [`Sink`] of byte chunks into an [`AsyncWrite`].
13     ///
14     /// Whenever you write to this [`SinkWriter`], the supplied bytes are
15     /// forwarded to the inner [`Sink`]. When `shutdown` is called on this
16     /// [`SinkWriter`], the inner sink is closed.
17     ///
18     /// This adapter takes a `Sink<&[u8]>` and provides an [`AsyncWrite`] impl
19     /// for it. Because of the lifetime, this trait is relatively rarely
20     /// implemented. The main ways to get a `Sink<&[u8]>` that you can use with
21     /// this type are:
22     ///
23     ///  * With the codec module by implementing the [`Encoder`]`<&[u8]>` trait.
24     ///  * By wrapping a `Sink<Bytes>` in a [`CopyToBytes`].
25     ///  * Manually implementing `Sink<&[u8]>` directly.
26     ///
27     /// The opposite conversion of implementing `Sink<_>` for an [`AsyncWrite`]
28     /// is done using the [`codec`] module.
29     ///
30     /// # Example
31     ///
32     /// ```
33     /// use bytes::Bytes;
34     /// use futures_util::SinkExt;
35     /// use std::io::{Error, ErrorKind};
36     /// use tokio::io::AsyncWriteExt;
37     /// use tokio_util::io::{SinkWriter, CopyToBytes};
38     /// use tokio_util::sync::PollSender;
39     ///
40     /// # #[tokio::main(flavor = "current_thread")]
41     /// # async fn main() -> Result<(), Error> {
42     /// // We use an mpsc channel as an example of a `Sink<Bytes>`.
43     /// let (tx, mut rx) = tokio::sync::mpsc::channel::<Bytes>(1);
44     /// let sink = PollSender::new(tx).sink_map_err(|_| Error::from(ErrorKind::BrokenPipe));
45     ///
46     /// // Wrap it in `CopyToBytes` to get a `Sink<&[u8]>`.
47     /// let mut writer = SinkWriter::new(CopyToBytes::new(sink));
48     ///
49     /// // Write data to our interface...
50     /// let data: [u8; 4] = [1, 2, 3, 4];
51     /// let _ = writer.write(&data).await?;
52     ///
53     /// // ... and receive it.
54     /// assert_eq!(data.as_slice(), &*rx.recv().await.unwrap());
55     /// # Ok(())
56     /// # }
57     /// ```
58     ///
59     /// [`AsyncWrite`]: tokio::io::AsyncWrite
60     /// [`CopyToBytes`]: crate::io::CopyToBytes
61     /// [`Encoder`]: crate::codec::Encoder
62     /// [`Sink`]: futures_sink::Sink
63     /// [`codec`]: crate::codec
64     #[derive(Debug)]
65     pub struct SinkWriter<S> {
66         #[pin]
67         inner: S,
68     }
69 }
70 
71 impl<S> SinkWriter<S> {
72     /// Creates a new [`SinkWriter`].
new(sink: S) -> Self73     pub fn new(sink: S) -> Self {
74         Self { inner: sink }
75     }
76 
77     /// Gets a reference to the underlying sink.
get_ref(&self) -> &S78     pub fn get_ref(&self) -> &S {
79         &self.inner
80     }
81 
82     /// Gets a mutable reference to the underlying sink.
get_mut(&mut self) -> &mut S83     pub fn get_mut(&mut self) -> &mut S {
84         &mut self.inner
85     }
86 
87     /// Consumes this [`SinkWriter`], returning the underlying sink.
into_inner(self) -> S88     pub fn into_inner(self) -> S {
89         self.inner
90     }
91 }
92 impl<S, E> AsyncWrite for SinkWriter<S>
93 where
94     for<'a> S: Sink<&'a [u8], Error = E>,
95     E: Into<io::Error>,
96 {
poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<Result<usize, io::Error>>97     fn poll_write(
98         self: Pin<&mut Self>,
99         cx: &mut Context<'_>,
100         buf: &[u8],
101     ) -> Poll<Result<usize, io::Error>> {
102         let mut this = self.project();
103 
104         ready!(this.inner.as_mut().poll_ready(cx).map_err(Into::into))?;
105         match this.inner.as_mut().start_send(buf) {
106             Ok(()) => Poll::Ready(Ok(buf.len())),
107             Err(e) => Poll::Ready(Err(e.into())),
108         }
109     }
110 
poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>111     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
112         self.project().inner.poll_flush(cx).map_err(Into::into)
113     }
114 
poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>115     fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
116         self.project().inner.poll_close(cx).map_err(Into::into)
117     }
118 }
119 
120 impl<S: Stream> Stream for SinkWriter<S> {
121     type Item = S::Item;
poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>122     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
123         self.project().inner.poll_next(cx)
124     }
125 }
126 
127 impl<S: AsyncRead> AsyncRead for SinkWriter<S> {
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, ) -> Poll<io::Result<()>>128     fn poll_read(
129         self: Pin<&mut Self>,
130         cx: &mut Context<'_>,
131         buf: &mut tokio::io::ReadBuf<'_>,
132     ) -> Poll<io::Result<()>> {
133         self.project().inner.poll_read(cx, buf)
134     }
135 }
136