• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use crate::codec::decoder::Decoder;
2 use crate::codec::encoder::Encoder;
3 
4 use futures_core::Stream;
5 use tokio::io::{AsyncRead, AsyncWrite};
6 
7 use bytes::BytesMut;
8 use futures_core::ready;
9 use futures_sink::Sink;
10 use pin_project_lite::pin_project;
11 use std::borrow::{Borrow, BorrowMut};
12 use std::io;
13 use std::pin::Pin;
14 use std::task::{Context, Poll};
15 use tracing::trace;
16 
pin_project! {
    /// Core of the framed transports: an I/O object, a codec, and buffer state.
    ///
    /// This file implements `Stream` (frame decoding, requires
    /// `State: BorrowMut<ReadFrame>`) and `Sink` (frame encoding, requires
    /// `State: BorrowMut<WriteFrame>`) for it, so `State` is typically a
    /// `ReadFrame`, a `WriteFrame`, or an `RWFrames` holding both.
    #[derive(Debug)]
    pub(crate) struct FramedImpl<T, U, State> {
        // The underlying I/O object. It is structurally pinned, so the trait
        // impls reach it as `Pin<&mut T>` through `project()`.
        #[pin]
        pub(crate) inner: T,
        // Read and/or write buffers plus their bookkeeping flags.
        pub(crate) state: State,
        // Decoder and/or encoder that converts between bytes and frames.
        pub(crate) codec: U,
    }
}
26 
/// Size in bytes (8 KiB) used for freshly allocated read/write buffers, as the
/// minimum capacity of buffers supplied via `From<BytesMut>`, and as the write
/// side's backpressure boundary.
const INITIAL_CAPACITY: usize = 8_192;
28 
/// Buffer and state flags for the read (decoding) half of a framed transport.
///
/// The combination of `eof`, `is_readable` and `has_errored` encodes the state
/// of the decoding state machine driven by the `Stream` impl.
#[derive(Debug)]
pub(crate) struct ReadFrame {
    /// Set after a read from the underlying reader returned 0 bytes; cleared
    /// again when a later read returns data.
    pub(crate) eof: bool,
    /// Set when `buffer` might contain data the decoder has not yet consumed.
    pub(crate) is_readable: bool,
    /// Bytes read from the underlying reader that have not been decoded yet.
    pub(crate) buffer: BytesMut,
    /// Set when decoding or reading failed; the next poll yields `None` once
    /// and clears the flag.
    pub(crate) has_errored: bool,
}
36 
37 pub(crate) struct WriteFrame {
38     pub(crate) buffer: BytesMut,
39     pub(crate) backpressure_boundary: usize,
40 }
41 
42 #[derive(Default)]
43 pub(crate) struct RWFrames {
44     pub(crate) read: ReadFrame,
45     pub(crate) write: WriteFrame,
46 }
47 
48 impl Default for ReadFrame {
default() -> Self49     fn default() -> Self {
50         Self {
51             eof: false,
52             is_readable: false,
53             buffer: BytesMut::with_capacity(INITIAL_CAPACITY),
54             has_errored: false,
55         }
56     }
57 }
58 
59 impl Default for WriteFrame {
default() -> Self60     fn default() -> Self {
61         Self {
62             buffer: BytesMut::with_capacity(INITIAL_CAPACITY),
63             backpressure_boundary: INITIAL_CAPACITY,
64         }
65     }
66 }
67 
68 impl From<BytesMut> for ReadFrame {
from(mut buffer: BytesMut) -> Self69     fn from(mut buffer: BytesMut) -> Self {
70         let size = buffer.capacity();
71         if size < INITIAL_CAPACITY {
72             buffer.reserve(INITIAL_CAPACITY - size);
73         }
74 
75         Self {
76             buffer,
77             is_readable: size > 0,
78             eof: false,
79             has_errored: false,
80         }
81     }
82 }
83 
84 impl From<BytesMut> for WriteFrame {
from(mut buffer: BytesMut) -> Self85     fn from(mut buffer: BytesMut) -> Self {
86         let size = buffer.capacity();
87         if size < INITIAL_CAPACITY {
88             buffer.reserve(INITIAL_CAPACITY - size);
89         }
90 
91         Self {
92             buffer,
93             backpressure_boundary: INITIAL_CAPACITY,
94         }
95     }
96 }
97 
98 impl Borrow<ReadFrame> for RWFrames {
borrow(&self) -> &ReadFrame99     fn borrow(&self) -> &ReadFrame {
100         &self.read
101     }
102 }
103 impl BorrowMut<ReadFrame> for RWFrames {
borrow_mut(&mut self) -> &mut ReadFrame104     fn borrow_mut(&mut self) -> &mut ReadFrame {
105         &mut self.read
106     }
107 }
108 impl Borrow<WriteFrame> for RWFrames {
borrow(&self) -> &WriteFrame109     fn borrow(&self) -> &WriteFrame {
110         &self.write
111     }
112 }
113 impl BorrowMut<WriteFrame> for RWFrames {
borrow_mut(&mut self) -> &mut WriteFrame114     fn borrow_mut(&mut self) -> &mut WriteFrame {
115         &mut self.write
116     }
117 }
/// Decodes frames from the wrapped `AsyncRead` using the codec's `Decoder`.
///
/// Read-side state is reached through `R: BorrowMut<ReadFrame>`, so the same
/// implementation serves both a bare `ReadFrame` and the combined `RWFrames`.
impl<T, U, R> Stream for FramedImpl<T, U, R>
where
    T: AsyncRead,
    U: Decoder,
    R: BorrowMut<ReadFrame>,
{
    type Item = Result<U::Item, U::Error>;

    /// Polls for the next decoded frame.
    ///
    /// Returns:
    /// - `Ready(Some(Ok(frame)))` when `decode` (or, after EOF, `decode_eof`)
    ///   produces a frame;
    /// - `Ready(Some(Err(e)))` when decoding fails or the underlying read fails
    ///   (read errors are converted into `U::Error` by `?`). The next poll then
    ///   returns `Ready(None)` once and clears the error flag, after which
    ///   polling resumes reading from the underlying reader
    ///   (see https://github.com/tokio-rs/tokio/issues/3976);
    /// - `Ready(None)` when EOF was reached and `decode_eof` has no more frames,
    ///   or when the reader reports EOF again while already at EOF;
    /// - `Pending` when the underlying reader is not ready.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        use crate::util::poll_read_buf;

        let mut pinned = self.project();
        let state: &mut ReadFrame = pinned.state.borrow_mut();
        // The following loop implements a state machine in which each state
        // corresponds to a combination of the `is_readable`, `eof` and
        // `has_errored` flags. States persist across calls (the flags live in
        // `ReadFrame`), and most state transitions occur together with a return.
        //
        // The initial state is `reading`.
        //
        // | state   | eof   | is_readable | has_errored |
        // |---------|-------|-------------|-------------|
        // | reading | false | false       | false       |
        // | framing | false | true        | false       |
        // | pausing | true  | true        | false       |
        // | paused  | true  | false       | false       |
        // | errored | <any> | <any>       | true        |
        //                                                       `decode_eof` returns Err
        //                                          ┌────────────────────────────────────────────────────────┐
        //                   `decode_eof` returns   │                                                        │
        //                             `Ok(Some)`   │                                                        │
        //                                 ┌─────┐  │     `decode_eof` returns               After returning │
        //                Read 0 bytes     ├─────▼──┴┐    `Ok(None)`          ┌────────┐ ◄───┐ `None`    ┌───▼─────┐
        //               ┌────────────────►│ Pausing ├───────────────────────►│ Paused ├─┐   └───────────┤ Errored │
        //               │                 └─────────┘                        └─┬──▲───┘ │               └───▲───▲─┘
        // Pending read  │                                                      │  │     │                   │   │
        //     ┌──────┐  │            `decode` returns `Some`                   │  └─────┘                   │   │
        //     │      │  │                   ┌──────┐                           │  Pending                   │   │
        //     │ ┌────▼──┴─┐ Read n>0 bytes ┌┴──────▼─┐     read n>0 bytes      │  read                      │   │
        //     └─┤ Reading ├───────────────►│ Framing │◄────────────────────────┘                            │   │
        //       └──┬─▲────┘                └─────┬──┬┘                                                      │   │
        //          │ │                           │  │                 `decode` returns Err                  │   │
        //          │ └──`decode` returns `None`──┘  └───────────────────────────────────────────────────────┘   │
        //          │                             read returns Err                                               │
        //          └────────────────────────────────────────────────────────────────────────────────────────────┘
        loop {
            // Return `None` if we have encountered an error from the underlying decoder
            // See: https://github.com/tokio-rs/tokio/issues/3976
            if state.has_errored {
                // preparing has_errored -> paused
                trace!("Returning None and setting paused");
                state.is_readable = false;
                state.has_errored = false;
                return Poll::Ready(None);
            }

            // Repeatedly call `decode` or `decode_eof` while the buffer is "readable",
            // i.e. it _might_ contain data consumable as a frame or closing frame.
            // Both signal that there is no such data by returning `None`.
            //
            // If `decode` couldn't read a frame and the upstream source has returned eof,
            // `decode_eof` will attempt to decode the remaining bytes as closing frames.
            //
            // If the underlying AsyncRead is resumable, we may continue after an EOF,
            // but must finish emitting all of its associated `decode_eof` frames.
            // Furthermore, we don't want to emit any `decode_eof` frames on retried
            // reads after an EOF unless we've actually read more data.
            if state.is_readable {
                // pausing or framing
                if state.eof {
                    // pausing
                    let frame = pinned.codec.decode_eof(&mut state.buffer).map_err(|err| {
                        trace!("Got an error, going to errored state");
                        state.has_errored = true;
                        err
                    })?;
                    if frame.is_none() {
                        state.is_readable = false; // prepare pausing -> paused
                    }
                    // implicit pausing -> pausing or pausing -> paused
                    return Poll::Ready(frame.map(Ok));
                }

                // framing
                trace!("attempting to decode a frame");

                if let Some(frame) = pinned.codec.decode(&mut state.buffer).map_err(|op| {
                    trace!("Got an error, going to errored state");
                    state.has_errored = true;
                    op
                })? {
                    trace!("frame decoded from buffer");
                    // implicit framing -> framing
                    return Poll::Ready(Some(Ok(frame)));
                }

                // framing -> reading
                state.is_readable = false;
            }
            // reading or paused
            // If we can't build a frame yet, try to read more data and try again.
            // Make sure we've got room for at least one byte to read to ensure
            // that we don't get a spurious 0 that looks like EOF.
            state.buffer.reserve(1);
            let bytect = match poll_read_buf(pinned.inner.as_mut(), cx, &mut state.buffer).map_err(
                |err| {
                    trace!("Got an error, going to errored state");
                    state.has_errored = true;
                    err
                },
            )? {
                Poll::Ready(ct) => ct,
                // implicit reading -> reading or implicit paused -> paused
                Poll::Pending => return Poll::Pending,
            };
            if bytect == 0 {
                if state.eof {
                    // We're already at an EOF, and since we've reached this path
                    // we're also not readable. This implies that we've already finished
                    // our `decode_eof` handling, so we can simply return `None`.
                    // implicit paused -> paused
                    return Poll::Ready(None);
                }
                // prepare reading -> paused
                state.eof = true;
            } else {
                // prepare paused -> framing or noop reading -> framing
                state.eof = false;
            }

            // paused -> framing or reading -> framing or reading -> pausing
            state.is_readable = true;
        }
    }
}
252 
253 impl<T, I, U, W> Sink<I> for FramedImpl<T, U, W>
254 where
255     T: AsyncWrite,
256     U: Encoder<I>,
257     U::Error: From<io::Error>,
258     W: BorrowMut<WriteFrame>,
259 {
260     type Error = U::Error;
261 
poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>262     fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
263         if self.state.borrow().buffer.len() >= self.state.borrow().backpressure_boundary {
264             self.as_mut().poll_flush(cx)
265         } else {
266             Poll::Ready(Ok(()))
267         }
268     }
269 
start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error>270     fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
271         let pinned = self.project();
272         pinned
273             .codec
274             .encode(item, &mut pinned.state.borrow_mut().buffer)?;
275         Ok(())
276     }
277 
poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>278     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
279         use crate::util::poll_write_buf;
280         trace!("flushing framed transport");
281         let mut pinned = self.project();
282 
283         while !pinned.state.borrow_mut().buffer.is_empty() {
284             let WriteFrame { buffer, .. } = pinned.state.borrow_mut();
285             trace!(remaining = buffer.len(), "writing;");
286 
287             let n = ready!(poll_write_buf(pinned.inner.as_mut(), cx, buffer))?;
288 
289             if n == 0 {
290                 return Poll::Ready(Err(io::Error::new(
291                     io::ErrorKind::WriteZero,
292                     "failed to \
293                      write frame to transport",
294                 )
295                 .into()));
296             }
297         }
298 
299         // Try flushing the underlying IO
300         ready!(pinned.inner.poll_flush(cx))?;
301 
302         trace!("framed transport flushed");
303         Poll::Ready(Ok(()))
304     }
305 
poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>306     fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
307         ready!(self.as_mut().poll_flush(cx))?;
308         ready!(self.project().inner.poll_shutdown(cx))?;
309 
310         Poll::Ready(Ok(()))
311     }
312 }
313