• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use futures_core::ready;
2 use pin_project_lite::pin_project;
3 use std::io::{IoSlice, Result};
4 use std::pin::Pin;
5 use std::task::{Context, Poll};
6 
7 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
8 
9 pin_project! {
10     /// An adapter that lets you inspect the data that's being read.
11     ///
12     /// This is useful for things like hashing data as it's read in.
13     pub struct InspectReader<R, F> {
14         #[pin]
15         reader: R,
16         f: F,
17     }
18 }
19 
20 impl<R, F> InspectReader<R, F> {
21     /// Create a new InspectReader, wrapping `reader` and calling `f` for the
22     /// new data supplied by each read call.
23     ///
24     /// The closure will only be called with an empty slice if the inner reader
25     /// returns without reading data into the buffer. This happens at EOF, or if
26     /// `poll_read` is called with a zero-size buffer.
new(reader: R, f: F) -> InspectReader<R, F> where R: AsyncRead, F: FnMut(&[u8]),27     pub fn new(reader: R, f: F) -> InspectReader<R, F>
28     where
29         R: AsyncRead,
30         F: FnMut(&[u8]),
31     {
32         InspectReader { reader, f }
33     }
34 
35     /// Consumes the `InspectReader`, returning the wrapped reader
into_inner(self) -> R36     pub fn into_inner(self) -> R {
37         self.reader
38     }
39 }
40 
41 impl<R: AsyncRead, F: FnMut(&[u8])> AsyncRead for InspectReader<R, F> {
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<Result<()>>42     fn poll_read(
43         self: Pin<&mut Self>,
44         cx: &mut Context<'_>,
45         buf: &mut ReadBuf<'_>,
46     ) -> Poll<Result<()>> {
47         let me = self.project();
48         let filled_length = buf.filled().len();
49         ready!(me.reader.poll_read(cx, buf))?;
50         (me.f)(&buf.filled()[filled_length..]);
51         Poll::Ready(Ok(()))
52     }
53 }
54 
55 impl<R: AsyncWrite, F> AsyncWrite for InspectReader<R, F> {
poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<std::result::Result<usize, std::io::Error>>56     fn poll_write(
57         self: Pin<&mut Self>,
58         cx: &mut Context<'_>,
59         buf: &[u8],
60     ) -> Poll<std::result::Result<usize, std::io::Error>> {
61         self.project().reader.poll_write(cx, buf)
62     }
63 
poll_flush( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<std::result::Result<(), std::io::Error>>64     fn poll_flush(
65         self: Pin<&mut Self>,
66         cx: &mut Context<'_>,
67     ) -> Poll<std::result::Result<(), std::io::Error>> {
68         self.project().reader.poll_flush(cx)
69     }
70 
poll_shutdown( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<std::result::Result<(), std::io::Error>>71     fn poll_shutdown(
72         self: Pin<&mut Self>,
73         cx: &mut Context<'_>,
74     ) -> Poll<std::result::Result<(), std::io::Error>> {
75         self.project().reader.poll_shutdown(cx)
76     }
77 
poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>], ) -> Poll<Result<usize>>78     fn poll_write_vectored(
79         self: Pin<&mut Self>,
80         cx: &mut Context<'_>,
81         bufs: &[IoSlice<'_>],
82     ) -> Poll<Result<usize>> {
83         self.project().reader.poll_write_vectored(cx, bufs)
84     }
85 
is_write_vectored(&self) -> bool86     fn is_write_vectored(&self) -> bool {
87         self.reader.is_write_vectored()
88     }
89 }
90 
91 pin_project! {
92     /// An adapter that lets you inspect the data that's being written.
93     ///
94     /// This is useful for things like hashing data as it's written out.
95     pub struct InspectWriter<W, F> {
96         #[pin]
97         writer: W,
98         f: F,
99     }
100 }
101 
102 impl<W, F> InspectWriter<W, F> {
103     /// Create a new InspectWriter, wrapping `write` and calling `f` for the
104     /// data successfully written by each write call.
105     ///
106     /// The closure `f` will never be called with an empty slice. A vectored
107     /// write can result in multiple calls to `f` - at most one call to `f` per
108     /// buffer supplied to `poll_write_vectored`.
new(writer: W, f: F) -> InspectWriter<W, F> where W: AsyncWrite, F: FnMut(&[u8]),109     pub fn new(writer: W, f: F) -> InspectWriter<W, F>
110     where
111         W: AsyncWrite,
112         F: FnMut(&[u8]),
113     {
114         InspectWriter { writer, f }
115     }
116 
117     /// Consumes the `InspectWriter`, returning the wrapped writer
into_inner(self) -> W118     pub fn into_inner(self) -> W {
119         self.writer
120     }
121 }
122 
123 impl<W: AsyncWrite, F: FnMut(&[u8])> AsyncWrite for InspectWriter<W, F> {
poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>>124     fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
125         let me = self.project();
126         let res = me.writer.poll_write(cx, buf);
127         if let Poll::Ready(Ok(count)) = res {
128             if count != 0 {
129                 (me.f)(&buf[..count]);
130             }
131         }
132         res
133     }
134 
poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>>135     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
136         let me = self.project();
137         me.writer.poll_flush(cx)
138     }
139 
poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>>140     fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
141         let me = self.project();
142         me.writer.poll_shutdown(cx)
143     }
144 
poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>], ) -> Poll<Result<usize>>145     fn poll_write_vectored(
146         self: Pin<&mut Self>,
147         cx: &mut Context<'_>,
148         bufs: &[IoSlice<'_>],
149     ) -> Poll<Result<usize>> {
150         let me = self.project();
151         let res = me.writer.poll_write_vectored(cx, bufs);
152         if let Poll::Ready(Ok(mut count)) = res {
153             for buf in bufs {
154                 if count == 0 {
155                     break;
156                 }
157                 let size = count.min(buf.len());
158                 if size != 0 {
159                     (me.f)(&buf[..size]);
160                     count -= size;
161                 }
162             }
163         }
164         res
165     }
166 
is_write_vectored(&self) -> bool167     fn is_write_vectored(&self) -> bool {
168         self.writer.is_write_vectored()
169     }
170 }
171 
172 impl<W: AsyncRead, F> AsyncRead for InspectWriter<W, F> {
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<std::io::Result<()>>173     fn poll_read(
174         self: Pin<&mut Self>,
175         cx: &mut Context<'_>,
176         buf: &mut ReadBuf<'_>,
177     ) -> Poll<std::io::Result<()>> {
178         self.project().writer.poll_read(cx, buf)
179     }
180 }
181