1 use crate::codec::decoder::Decoder; 2 use crate::codec::encoder::Encoder; 3 4 use futures_core::Stream; 5 use tokio::io::{AsyncRead, AsyncWrite}; 6 7 use bytes::BytesMut; 8 use futures_core::ready; 9 use futures_sink::Sink; 10 use pin_project_lite::pin_project; 11 use std::borrow::{Borrow, BorrowMut}; 12 use std::io; 13 use std::pin::Pin; 14 use std::task::{Context, Poll}; 15 use tracing::trace; 16 17 pin_project! { 18 #[derive(Debug)] 19 pub(crate) struct FramedImpl<T, U, State> { 20 #[pin] 21 pub(crate) inner: T, 22 pub(crate) state: State, 23 pub(crate) codec: U, 24 } 25 } 26 27 const INITIAL_CAPACITY: usize = 8 * 1024; 28 29 #[derive(Debug)] 30 pub(crate) struct ReadFrame { 31 pub(crate) eof: bool, 32 pub(crate) is_readable: bool, 33 pub(crate) buffer: BytesMut, 34 pub(crate) has_errored: bool, 35 } 36 37 pub(crate) struct WriteFrame { 38 pub(crate) buffer: BytesMut, 39 pub(crate) backpressure_boundary: usize, 40 } 41 42 #[derive(Default)] 43 pub(crate) struct RWFrames { 44 pub(crate) read: ReadFrame, 45 pub(crate) write: WriteFrame, 46 } 47 48 impl Default for ReadFrame { default() -> Self49 fn default() -> Self { 50 Self { 51 eof: false, 52 is_readable: false, 53 buffer: BytesMut::with_capacity(INITIAL_CAPACITY), 54 has_errored: false, 55 } 56 } 57 } 58 59 impl Default for WriteFrame { default() -> Self60 fn default() -> Self { 61 Self { 62 buffer: BytesMut::with_capacity(INITIAL_CAPACITY), 63 backpressure_boundary: INITIAL_CAPACITY, 64 } 65 } 66 } 67 68 impl From<BytesMut> for ReadFrame { from(mut buffer: BytesMut) -> Self69 fn from(mut buffer: BytesMut) -> Self { 70 let size = buffer.capacity(); 71 if size < INITIAL_CAPACITY { 72 buffer.reserve(INITIAL_CAPACITY - size); 73 } 74 75 Self { 76 buffer, 77 is_readable: size > 0, 78 eof: false, 79 has_errored: false, 80 } 81 } 82 } 83 84 impl From<BytesMut> for WriteFrame { from(mut buffer: BytesMut) -> Self85 fn from(mut buffer: BytesMut) -> Self { 86 let size = buffer.capacity(); 87 if size < INITIAL_CAPACITY { 88 buffer.reserve(INITIAL_CAPACITY - size); 89 } 90 91 Self { 92 buffer, 93 backpressure_boundary: INITIAL_CAPACITY, 94 } 95 } 96 } 97 98 impl Borrow<ReadFrame> for RWFrames { borrow(&self) -> &ReadFrame99 fn borrow(&self) -> &ReadFrame { 100 &self.read 101 } 102 } 103 impl BorrowMut<ReadFrame> for RWFrames { borrow_mut(&mut self) -> &mut ReadFrame104 fn borrow_mut(&mut self) -> &mut ReadFrame { 105 &mut self.read 106 } 107 } 108 impl Borrow<WriteFrame> for RWFrames { borrow(&self) -> &WriteFrame109 fn borrow(&self) -> &WriteFrame { 110 &self.write 111 } 112 } 113 impl BorrowMut<WriteFrame> for RWFrames { borrow_mut(&mut self) -> &mut WriteFrame114 fn borrow_mut(&mut self) -> &mut WriteFrame { 115 &mut self.write 116 } 117 } 118 impl<T, U, R> Stream for FramedImpl<T, U, R> 119 where 120 T: AsyncRead, 121 U: Decoder, 122 R: BorrowMut<ReadFrame>, 123 { 124 type Item = Result<U::Item, U::Error>; 125 poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>126 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { 127 use crate::util::poll_read_buf; 128 129 let mut pinned = self.project(); 130 let state: &mut ReadFrame = pinned.state.borrow_mut(); 131 // The following loops implements a state machine with each state corresponding 132 // to a combination of the `is_readable` and `eof` flags. States persist across 133 // loop entries and most state transitions occur with a return. 134 // 135 // The initial state is `reading`. 136 // 137 // | state | eof | is_readable | has_errored | 138 // |---------|-------|-------------|-------------| 139 // | reading | false | false | false | 140 // | framing | false | true | false | 141 // | pausing | true | true | false | 142 // | paused | true | false | false | 143 // | errored | <any> | <any> | true | 144 // `decode_eof` returns Err 145 // ┌────────────────────────────────────────────────────────┐ 146 // `decode_eof` returns │ │ 147 // `Ok(Some)` │ │ 148 // ┌─────┐ │ `decode_eof` returns After returning │ 149 // Read 0 bytes ├─────▼──┴┐ `Ok(None)` ┌────────┐ ◄───┐ `None` ┌───▼─────┐ 150 // ┌────────────────►│ Pausing ├───────────────────────►│ Paused ├─┐ └───────────┤ Errored │ 151 // │ └─────────┘ └─┬──▲───┘ │ └───▲───▲─┘ 152 // Pending read │ │ │ │ │ │ 153 // ┌──────┐ │ `decode` returns `Some` │ └─────┘ │ │ 154 // │ │ │ ┌──────┐ │ Pending │ │ 155 // │ ┌────▼──┴─┐ Read n>0 bytes ┌┴──────▼─┐ read n>0 bytes │ read │ │ 156 // └─┤ Reading ├───────────────►│ Framing │◄────────────────────────┘ │ │ 157 // └──┬─▲────┘ └─────┬──┬┘ │ │ 158 // │ │ │ │ `decode` returns Err │ │ 159 // │ └───decode` returns `None`──┘ └───────────────────────────────────────────────────────┘ │ 160 // │ read returns Err │ 161 // └────────────────────────────────────────────────────────────────────────────────────────────┘ 162 loop { 163 // Return `None` if we have encountered an error from the underlying decoder 164 // See: https://github.com/tokio-rs/tokio/issues/3976 165 if state.has_errored { 166 // preparing has_errored -> paused 167 trace!("Returning None and setting paused"); 168 state.is_readable = false; 169 state.has_errored = false; 170 return Poll::Ready(None); 171 } 172 173 // Repeatedly call `decode` or `decode_eof` while the buffer is "readable", 174 // i.e. it _might_ contain data consumable as a frame or closing frame. 175 // Both signal that there is no such data by returning `None`. 176 // 177 // If `decode` couldn't read a frame and the upstream source has returned eof, 178 // `decode_eof` will attempt to decode the remaining bytes as closing frames. 179 // 180 // If the underlying AsyncRead is resumable, we may continue after an EOF, 181 // but must finish emitting all of it's associated `decode_eof` frames. 182 // Furthermore, we don't want to emit any `decode_eof` frames on retried 183 // reads after an EOF unless we've actually read more data. 184 if state.is_readable { 185 // pausing or framing 186 if state.eof { 187 // pausing 188 let frame = pinned.codec.decode_eof(&mut state.buffer).map_err(|err| { 189 trace!("Got an error, going to errored state"); 190 state.has_errored = true; 191 err 192 })?; 193 if frame.is_none() { 194 state.is_readable = false; // prepare pausing -> paused 195 } 196 // implicit pausing -> pausing or pausing -> paused 197 return Poll::Ready(frame.map(Ok)); 198 } 199 200 // framing 201 trace!("attempting to decode a frame"); 202 203 if let Some(frame) = pinned.codec.decode(&mut state.buffer).map_err(|op| { 204 trace!("Got an error, going to errored state"); 205 state.has_errored = true; 206 op 207 })? { 208 trace!("frame decoded from buffer"); 209 // implicit framing -> framing 210 return Poll::Ready(Some(Ok(frame))); 211 } 212 213 // framing -> reading 214 state.is_readable = false; 215 } 216 // reading or paused 217 // If we can't build a frame yet, try to read more data and try again. 218 // Make sure we've got room for at least one byte to read to ensure 219 // that we don't get a spurious 0 that looks like EOF. 220 state.buffer.reserve(1); 221 let bytect = match poll_read_buf(pinned.inner.as_mut(), cx, &mut state.buffer).map_err( 222 |err| { 223 trace!("Got an error, going to errored state"); 224 state.has_errored = true; 225 err 226 }, 227 )? { 228 Poll::Ready(ct) => ct, 229 // implicit reading -> reading or implicit paused -> paused 230 Poll::Pending => return Poll::Pending, 231 }; 232 if bytect == 0 { 233 if state.eof { 234 // We're already at an EOF, and since we've reached this path 235 // we're also not readable. This implies that we've already finished 236 // our `decode_eof` handling, so we can simply return `None`. 237 // implicit paused -> paused 238 return Poll::Ready(None); 239 } 240 // prepare reading -> paused 241 state.eof = true; 242 } else { 243 // prepare paused -> framing or noop reading -> framing 244 state.eof = false; 245 } 246 247 // paused -> framing or reading -> framing or reading -> pausing 248 state.is_readable = true; 249 } 250 } 251 } 252 253 impl<T, I, U, W> Sink<I> for FramedImpl<T, U, W> 254 where 255 T: AsyncWrite, 256 U: Encoder<I>, 257 U::Error: From<io::Error>, 258 W: BorrowMut<WriteFrame>, 259 { 260 type Error = U::Error; 261 poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>262 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { 263 if self.state.borrow().buffer.len() >= self.state.borrow().backpressure_boundary { 264 self.as_mut().poll_flush(cx) 265 } else { 266 Poll::Ready(Ok(())) 267 } 268 } 269 start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error>270 fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { 271 let pinned = self.project(); 272 pinned 273 .codec 274 .encode(item, &mut pinned.state.borrow_mut().buffer)?; 275 Ok(()) 276 } 277 poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>278 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { 279 use crate::util::poll_write_buf; 280 trace!("flushing framed transport"); 281 let mut pinned = self.project(); 282 283 while !pinned.state.borrow_mut().buffer.is_empty() { 284 let WriteFrame { buffer, .. } = pinned.state.borrow_mut(); 285 trace!(remaining = buffer.len(), "writing;"); 286 287 let n = ready!(poll_write_buf(pinned.inner.as_mut(), cx, buffer))?; 288 289 if n == 0 { 290 return Poll::Ready(Err(io::Error::new( 291 io::ErrorKind::WriteZero, 292 "failed to \ 293 write frame to transport", 294 ) 295 .into())); 296 } 297 } 298 299 // Try flushing the underlying IO 300 ready!(pinned.inner.poll_flush(cx))?; 301 302 trace!("framed transport flushed"); 303 Poll::Ready(Ok(())) 304 } 305 poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>306 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { 307 ready!(self.as_mut().poll_flush(cx))?; 308 ready!(self.project().inner.poll_shutdown(cx))?; 309 310 Poll::Ready(Ok(())) 311 } 312 } 313