use std::cmp;
use std::fmt;
#[cfg(all(feature = "server", feature = "runtime"))]
use std::future::Future;
use std::io::{self, IoSlice};
use std::marker::Unpin;
use std::mem::MaybeUninit;
use std::pin::Pin;
use std::task::{Context, Poll};
#[cfg(all(feature = "server", feature = "runtime"))]
use std::time::Duration;

use bytes::{Buf, BufMut, Bytes, BytesMut};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
#[cfg(all(feature = "server", feature = "runtime"))]
use tokio::time::Instant;
use tracing::{debug, trace};

use super::{Http1Transaction, ParseContext, ParsedMessage};
use crate::common::buf::BufList;

/// The initial buffer size allocated before trying to read from IO.
pub(crate) const INIT_BUFFER_SIZE: usize = 8192;

/// The minimum value that can be set to max buffer size.
pub(crate) const MINIMUM_MAX_BUFFER_SIZE: usize = INIT_BUFFER_SIZE;

/// The default maximum read buffer size. If the buffer gets this big and
/// a message is still not complete, a `TooLarge` error is triggered.
// Note: if this changes, update server::conn::Http::max_buf_size docs.
pub(crate) const DEFAULT_MAX_BUFFER_SIZE: usize = 8192 + 4096 * 100;
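// For reference: 8192 + 409_600 = 417_792 bytes (408 KiB).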

/// The maximum number of distinct `Buf`s to hold in a list before requiring
/// a flush. Only affects when the buffer strategy is to queue buffers.
///
/// Note that a flush can happen before reaching the maximum. This simply
/// forces a flush if the queue gets this big.
const MAX_BUF_LIST_BUFFERS: usize = 16;

pub(crate) struct Buffered<T, B> {
    flush_pipeline: bool,
    io: T,
    partial_len: Option<usize>,
    read_blocked: bool,
    read_buf: BytesMut,
    read_buf_strategy: ReadStrategy,
    write_buf: WriteBuf<B>,
}

impl<T, B> fmt::Debug for Buffered<T, B>
where
    B: Buf,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Buffered")
            .field("read_buf", &self.read_buf)
            .field("write_buf", &self.write_buf)
            .finish()
    }
}

impl<T, B> Buffered<T, B>
where
    T: AsyncRead + AsyncWrite + Unpin,
    B: Buf,
{
    pub(crate) fn new(io: T) -> Buffered<T, B> {
        let strategy = if io.is_write_vectored() {
            WriteStrategy::Queue
        } else {
            WriteStrategy::Flatten
        };
        let write_buf = WriteBuf::new(strategy);
        Buffered {
            flush_pipeline: false,
            io,
            partial_len: None,
            read_blocked: false,
            read_buf: BytesMut::with_capacity(0),
            read_buf_strategy: ReadStrategy::default(),
            write_buf,
        }
    }

    #[cfg(feature = "server")]
    pub(crate) fn set_flush_pipeline(&mut self, enabled: bool) {
        debug_assert!(!self.write_buf.has_remaining());
        self.flush_pipeline = enabled;
        if enabled {
            self.set_write_strategy_flatten();
        }
    }

    pub(crate) fn set_max_buf_size(&mut self, max: usize) {
        assert!(
            max >= MINIMUM_MAX_BUFFER_SIZE,
            "The max_buf_size cannot be smaller than {}.",
            MINIMUM_MAX_BUFFER_SIZE,
        );
        self.read_buf_strategy = ReadStrategy::with_max(max);
        self.write_buf.max_buf_size = max;
    }

    #[cfg(feature = "client")]
    pub(crate) fn set_read_buf_exact_size(&mut self, sz: usize) {
        self.read_buf_strategy = ReadStrategy::Exact(sz);
    }

    pub(crate) fn set_write_strategy_flatten(&mut self) {
        // this should always be called only at construction time,
        // so this assert is here to catch myself
        debug_assert!(self.write_buf.queue.bufs_cnt() == 0);
        self.write_buf.set_strategy(WriteStrategy::Flatten);
    }

    pub(crate) fn set_write_strategy_queue(&mut self) {
        // this should always be called only at construction time,
        // so this assert is here to catch myself
        debug_assert!(self.write_buf.queue.bufs_cnt() == 0);
        self.write_buf.set_strategy(WriteStrategy::Queue);
    }

    pub(crate) fn read_buf(&self) -> &[u8] {
        self.read_buf.as_ref()
    }

    #[cfg(test)]
    #[cfg(feature = "nightly")]
    pub(super) fn read_buf_mut(&mut self) -> &mut BytesMut {
        &mut self.read_buf
    }

    /// Return the "allocated" available space, not the potential space
    /// that could be allocated in the future.
    fn read_buf_remaining_mut(&self) -> usize {
        self.read_buf.capacity() - self.read_buf.len()
    }

    /// Return whether we can append to the headers buffer.
    ///
    /// Reasons we can't:
    /// - The write buf is in queue mode, and some of the past body still
    ///   needs to be flushed.
    pub(crate) fn can_headers_buf(&self) -> bool {
        !self.write_buf.queue.has_remaining()
    }

    pub(crate) fn headers_buf(&mut self) -> &mut Vec<u8> {
        let buf = self.write_buf.headers_mut();
        &mut buf.bytes
    }

    pub(super) fn write_buf(&mut self) -> &mut WriteBuf<B> {
        &mut self.write_buf
    }

    pub(crate) fn buffer<BB: Buf + Into<B>>(&mut self, buf: BB) {
        self.write_buf.buffer(buf)
    }

    pub(crate) fn can_buffer(&self) -> bool {
        self.flush_pipeline || self.write_buf.can_buffer()
    }

    pub(crate) fn consume_leading_lines(&mut self) {
        if !self.read_buf.is_empty() {
            let mut i = 0;
            while i < self.read_buf.len() {
                match self.read_buf[i] {
                    b'\r' | b'\n' => i += 1,
                    _ => break,
                }
            }
            self.read_buf.advance(i);
        }
    }

    pub(super) fn parse<S>(
        &mut self,
        cx: &mut Context<'_>,
        parse_ctx: ParseContext<'_>,
    ) -> Poll<crate::Result<ParsedMessage<S::Incoming>>>
    where
        S: Http1Transaction,
    {
        loop {
            match super::role::parse_headers::<S>(
                &mut self.read_buf,
                self.partial_len,
                ParseContext {
                    cached_headers: parse_ctx.cached_headers,
                    req_method: parse_ctx.req_method,
                    h1_parser_config: parse_ctx.h1_parser_config.clone(),
                    #[cfg(all(feature = "server", feature = "runtime"))]
                    h1_header_read_timeout: parse_ctx.h1_header_read_timeout,
                    #[cfg(all(feature = "server", feature = "runtime"))]
                    h1_header_read_timeout_fut: parse_ctx.h1_header_read_timeout_fut,
                    #[cfg(all(feature = "server", feature = "runtime"))]
                    h1_header_read_timeout_running: parse_ctx.h1_header_read_timeout_running,
                    preserve_header_case: parse_ctx.preserve_header_case,
                    #[cfg(feature = "ffi")]
                    preserve_header_order: parse_ctx.preserve_header_order,
                    h09_responses: parse_ctx.h09_responses,
                    #[cfg(feature = "ffi")]
                    on_informational: parse_ctx.on_informational,
                    #[cfg(feature = "ffi")]
                    raw_headers: parse_ctx.raw_headers,
                },
            )? {
                Some(msg) => {
                    debug!("parsed {} headers", msg.head.headers.len());

                    #[cfg(all(feature = "server", feature = "runtime"))]
                    {
                        *parse_ctx.h1_header_read_timeout_running = false;

                        if let Some(h1_header_read_timeout_fut) =
                            parse_ctx.h1_header_read_timeout_fut
                        {
                            // Reset the timer so we aren't woken up when the timeout finishes
                            h1_header_read_timeout_fut
                                .as_mut()
                                .reset(Instant::now() + Duration::from_secs(30 * 24 * 60 * 60));
                        }
                    }
                    self.partial_len = None;
                    return Poll::Ready(Ok(msg));
                }
                None => {
                    let max = self.read_buf_strategy.max();
                    let curr_len = self.read_buf.len();
                    if curr_len >= max {
                        debug!("max_buf_size ({}) reached, closing", max);
                        return Poll::Ready(Err(crate::Error::new_too_large()));
                    }

                    #[cfg(all(feature = "server", feature = "runtime"))]
                    if *parse_ctx.h1_header_read_timeout_running {
                        if let Some(h1_header_read_timeout_fut) =
                            parse_ctx.h1_header_read_timeout_fut
                        {
                            if Pin::new(h1_header_read_timeout_fut).poll(cx).is_ready() {
                                *parse_ctx.h1_header_read_timeout_running = false;

                                tracing::warn!("read header from client timeout");
                                return Poll::Ready(Err(crate::Error::new_header_timeout()));
                            }
                        }
                    }
                    if curr_len > 0 {
                        trace!("partial headers; {} bytes so far", curr_len);
                        self.partial_len = Some(curr_len);
                    } else {
                        // 1xx gobbled some bytes
                        self.partial_len = None;
                    }
                }
            }
            if ready!(self.poll_read_from_io(cx)).map_err(crate::Error::new_io)? == 0 {
                trace!("parse eof");
                return Poll::Ready(Err(crate::Error::new_incomplete()));
            }
        }
    }
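    // In brief, `parse` above loops: try to parse a message from `read_buf`;
    // on success, clear `partial_len` and return it; on incomplete input,
    // enforce `max_buf_size` and the optional header read timeout, then poll
    // the underlying IO for more bytes (EOF mid-message is an error).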

    pub(crate) fn poll_read_from_io(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
        self.read_blocked = false;
        let next = self.read_buf_strategy.next();
        if self.read_buf_remaining_mut() < next {
            self.read_buf.reserve(next);
        }

        let dst = self.read_buf.chunk_mut();
        let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit<u8>]) };
        let mut buf = ReadBuf::uninit(dst);
        match Pin::new(&mut self.io).poll_read(cx, &mut buf) {
            Poll::Ready(Ok(_)) => {
                let n = buf.filled().len();
                trace!("received {} bytes", n);
                unsafe {
                    // Safety: we just read that many bytes into the
                    // uninitialized part of the buffer, so this is okay.
                    // @tokio pls give me back `poll_read_buf` thanks
                    self.read_buf.advance_mut(n);
                }
                self.read_buf_strategy.record(n);
                Poll::Ready(Ok(n))
            }
            Poll::Pending => {
                self.read_blocked = true;
                Poll::Pending
            }
            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
        }
    }

    pub(crate) fn into_inner(self) -> (T, Bytes) {
        (self.io, self.read_buf.freeze())
    }

    pub(crate) fn io_mut(&mut self) -> &mut T {
        &mut self.io
    }

    pub(crate) fn is_read_blocked(&self) -> bool {
        self.read_blocked
    }

    pub(crate) fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        if self.flush_pipeline && !self.read_buf.is_empty() {
            Poll::Ready(Ok(()))
        } else if self.write_buf.remaining() == 0 {
            Pin::new(&mut self.io).poll_flush(cx)
        } else {
            if let WriteStrategy::Flatten = self.write_buf.strategy {
                return self.poll_flush_flattened(cx);
            }

            const MAX_WRITEV_BUFS: usize = 64;
            loop {
                let n = {
                    let mut iovs = [IoSlice::new(&[]); MAX_WRITEV_BUFS];
                    let len = self.write_buf.chunks_vectored(&mut iovs);
                    ready!(Pin::new(&mut self.io).poll_write_vectored(cx, &iovs[..len]))?
                };
                // TODO(eliza): we have to do this manually because
                // `poll_write_buf` doesn't exist in Tokio 0.3 yet...when
                // `poll_write_buf` comes back, the manual advance will need to leave!
                self.write_buf.advance(n);
                debug!("flushed {} bytes", n);
                if self.write_buf.remaining() == 0 {
                    break;
                } else if n == 0 {
                    trace!(
                        "write returned zero, but {} bytes remaining",
                        self.write_buf.remaining()
                    );
                    return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
                }
            }
            Pin::new(&mut self.io).poll_flush(cx)
        }
    }

    /// Specialized version of `flush` when strategy is Flatten.
    ///
    /// Since all buffered bytes are flattened into the single headers buffer,
    /// this skips some bookkeeping around using multiple buffers.
    fn poll_flush_flattened(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        loop {
            let n = ready!(Pin::new(&mut self.io).poll_write(cx, self.write_buf.headers.chunk()))?;
            debug!("flushed {} bytes", n);
            self.write_buf.headers.advance(n);
            if self.write_buf.headers.remaining() == 0 {
                self.write_buf.headers.reset();
                break;
            } else if n == 0 {
                trace!(
                    "write returned zero, but {} bytes remaining",
                    self.write_buf.remaining()
                );
                return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
            }
        }
        Pin::new(&mut self.io).poll_flush(cx)
    }

    #[cfg(test)]
    fn flush<'a>(&'a mut self) -> impl std::future::Future<Output = io::Result<()>> + 'a {
        futures_util::future::poll_fn(move |cx| self.poll_flush(cx))
    }
}

// The `B` is a `Buf`, we never project a pin to it
impl<T: Unpin, B> Unpin for Buffered<T, B> {}

// TODO: This trait is old... at least rename to PollBytes or something...
pub(crate) trait MemRead {
    fn read_mem(&mut self, cx: &mut Context<'_>, len: usize) -> Poll<io::Result<Bytes>>;
}

impl<T, B> MemRead for Buffered<T, B>
where
    T: AsyncRead + AsyncWrite + Unpin,
    B: Buf,
{
    fn read_mem(&mut self, cx: &mut Context<'_>, len: usize) -> Poll<io::Result<Bytes>> {
        if !self.read_buf.is_empty() {
            let n = std::cmp::min(len, self.read_buf.len());
            Poll::Ready(Ok(self.read_buf.split_to(n).freeze()))
        } else {
            let n = ready!(self.poll_read_from_io(cx))?;
            Poll::Ready(Ok(self.read_buf.split_to(::std::cmp::min(len, n)).freeze()))
        }
    }
}

#[derive(Clone, Copy, Debug)]
enum ReadStrategy {
    Adaptive {
        decrease_now: bool,
        next: usize,
        max: usize,
    },
    #[cfg(feature = "client")]
    Exact(usize),
}

impl ReadStrategy {
    fn with_max(max: usize) -> ReadStrategy {
        ReadStrategy::Adaptive {
            decrease_now: false,
            next: INIT_BUFFER_SIZE,
            max,
        }
    }

    fn next(&self) -> usize {
        match *self {
            ReadStrategy::Adaptive { next, .. } => next,
            #[cfg(feature = "client")]
            ReadStrategy::Exact(exact) => exact,
        }
    }

    fn max(&self) -> usize {
        match *self {
            ReadStrategy::Adaptive { max, .. } => max,
            #[cfg(feature = "client")]
            ReadStrategy::Exact(exact) => exact,
        }
    }

    fn record(&mut self, bytes_read: usize) {
        match *self {
            ReadStrategy::Adaptive {
                ref mut decrease_now,
                ref mut next,
                max,
                ..
            } => {
                if bytes_read >= *next {
                    *next = cmp::min(incr_power_of_two(*next), max);
                    *decrease_now = false;
                } else {
                    let decr_to = prev_power_of_two(*next);
                    if bytes_read < decr_to {
                        if *decrease_now {
                            *next = cmp::max(decr_to, INIT_BUFFER_SIZE);
                            *decrease_now = false;
                        } else {
                            // Decreasing is a two "record" process.
                            *decrease_now = true;
                        }
                    } else {
                        // A read within the current range should cancel
                        // a potential decrease, since we just saw proof
                        // that we still need this size.
                        *decrease_now = false;
                    }
                }
            }
            #[cfg(feature = "client")]
            ReadStrategy::Exact(_) => (),
        }
    }
}
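// A sketch (mirroring the unit tests below) of how the adaptive strategy
// evolves, assuming the default starting size of 8192:
//
//   record(8192) -> next = 16384  (the read filled the buffer; double)
//   record(1)    -> next = 16384  (a small read only arms `decrease_now`)
//   record(1)    -> next = 8192   (second small read in a row; shrink)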

fn incr_power_of_two(n: usize) -> usize {
    n.saturating_mul(2)
}

fn prev_power_of_two(n: usize) -> usize {
    // The only way this shift can underflow is if n is less than 4.
    // (That would mean `usize::MAX >> 64`, an invalid shift!)
    debug_assert!(n >= 4);
    (::std::usize::MAX >> (n.leading_zeros() + 2)) + 1
}
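// Worked examples (the results are the same for any pointer width):
//   incr_power_of_two(8192)  == 16384
//   prev_power_of_two(8192)  == 4096
//   prev_power_of_two(16384) == 8192
//   prev_power_of_two(20000) == 8192  (one power of two below the leading bit)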

impl Default for ReadStrategy {
    fn default() -> ReadStrategy {
        ReadStrategy::with_max(DEFAULT_MAX_BUFFER_SIZE)
    }
}

#[derive(Clone)]
pub(crate) struct Cursor<T> {
    bytes: T,
    pos: usize,
}

impl<T: AsRef<[u8]>> Cursor<T> {
    #[inline]
    pub(crate) fn new(bytes: T) -> Cursor<T> {
        Cursor { bytes, pos: 0 }
    }
}

impl Cursor<Vec<u8>> {
    /// If we've advanced the position a bit in this cursor, and wish to
    /// extend the underlying vector, we may wish to unshift the "read" bytes
    /// off, and move everything else over.
    fn maybe_unshift(&mut self, additional: usize) {
        if self.pos == 0 {
            // nothing to do
            return;
        }

        if self.bytes.capacity() - self.bytes.len() >= additional {
            // there's room!
            return;
        }

        self.bytes.drain(0..self.pos);
        self.pos = 0;
    }

    fn reset(&mut self) {
        self.pos = 0;
        self.bytes.clear();
    }
}

impl<T: AsRef<[u8]>> fmt::Debug for Cursor<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Cursor")
            .field("pos", &self.pos)
            .field("len", &self.bytes.as_ref().len())
            .finish()
    }
}

impl<T: AsRef<[u8]>> Buf for Cursor<T> {
    #[inline]
    fn remaining(&self) -> usize {
        self.bytes.as_ref().len() - self.pos
    }

    #[inline]
    fn chunk(&self) -> &[u8] {
        &self.bytes.as_ref()[self.pos..]
    }

    #[inline]
    fn advance(&mut self, cnt: usize) {
        debug_assert!(self.pos + cnt <= self.bytes.as_ref().len());
        self.pos += cnt;
    }
}
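// Minimal usage sketch of the `Buf` impl above (illustrative only):
//
//   let mut c = Cursor::new(b"abcdef".to_vec());
//   c.advance(2);
//   assert_eq!(c.chunk(), b"cdef");
//   assert_eq!(c.remaining(), 4);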

// an internal buffer to collect writes before flushes
pub(super) struct WriteBuf<B> {
    /// Re-usable buffer that holds message headers
    headers: Cursor<Vec<u8>>,
    max_buf_size: usize,
    /// Deque of user buffers if strategy is Queue
    queue: BufList<B>,
    strategy: WriteStrategy,
}

impl<B: Buf> WriteBuf<B> {
    fn new(strategy: WriteStrategy) -> WriteBuf<B> {
        WriteBuf {
            headers: Cursor::new(Vec::with_capacity(INIT_BUFFER_SIZE)),
            max_buf_size: DEFAULT_MAX_BUFFER_SIZE,
            queue: BufList::new(),
            strategy,
        }
    }
}

impl<B> WriteBuf<B>
where
    B: Buf,
{
    fn set_strategy(&mut self, strategy: WriteStrategy) {
        self.strategy = strategy;
    }

    pub(super) fn buffer<BB: Buf + Into<B>>(&mut self, mut buf: BB) {
        debug_assert!(buf.has_remaining());
        match self.strategy {
            WriteStrategy::Flatten => {
                let head = self.headers_mut();

                head.maybe_unshift(buf.remaining());
                trace!(
                    self.len = head.remaining(),
                    buf.len = buf.remaining(),
                    "buffer.flatten"
                );
                // perf: This is a little faster than `<Vec as BufMut>::put`,
                // but accomplishes the same result.
                loop {
                    let adv = {
                        let slice = buf.chunk();
                        if slice.is_empty() {
                            return;
                        }
                        head.bytes.extend_from_slice(slice);
                        slice.len()
                    };
                    buf.advance(adv);
                }
            }
            WriteStrategy::Queue => {
                trace!(
                    self.len = self.remaining(),
                    buf.len = buf.remaining(),
                    "buffer.queue"
                );
                self.queue.push(buf.into());
            }
        }
    }

    fn can_buffer(&self) -> bool {
        match self.strategy {
            WriteStrategy::Flatten => self.remaining() < self.max_buf_size,
            WriteStrategy::Queue => {
                self.queue.bufs_cnt() < MAX_BUF_LIST_BUFFERS && self.remaining() < self.max_buf_size
            }
        }
    }

    fn headers_mut(&mut self) -> &mut Cursor<Vec<u8>> {
        debug_assert!(!self.queue.has_remaining());
        &mut self.headers
    }
}

impl<B: Buf> fmt::Debug for WriteBuf<B> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("WriteBuf")
            .field("remaining", &self.remaining())
            .field("strategy", &self.strategy)
            .finish()
    }
}

impl<B: Buf> Buf for WriteBuf<B> {
    #[inline]
    fn remaining(&self) -> usize {
        self.headers.remaining() + self.queue.remaining()
    }

    #[inline]
    fn chunk(&self) -> &[u8] {
        let headers = self.headers.chunk();
        if !headers.is_empty() {
            headers
        } else {
            self.queue.chunk()
        }
    }

    #[inline]
    fn advance(&mut self, cnt: usize) {
        let hrem = self.headers.remaining();

        match hrem.cmp(&cnt) {
            cmp::Ordering::Equal => self.headers.reset(),
            cmp::Ordering::Greater => self.headers.advance(cnt),
            cmp::Ordering::Less => {
                let qcnt = cnt - hrem;
                self.headers.reset();
                self.queue.advance(qcnt);
            }
        }
    }

    #[inline]
    fn chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize {
        let n = self.headers.chunks_vectored(dst);
        self.queue.chunks_vectored(&mut dst[n..]) + n
    }
}
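// Illustrative: `advance` above drains the headers cursor before the queue.
// With 5 header bytes and 10 queued bytes buffered, `advance(8)` resets the
// headers cursor and advances the queue by the remaining 3.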

#[derive(Debug)]
enum WriteStrategy {
    Flatten,
    Queue,
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    use tokio_test::io::Builder as Mock;

    // #[cfg(feature = "nightly")]
    // use test::Bencher;

    /*
    impl<T: Read> MemRead for AsyncIo<T> {
        fn read_mem(&mut self, len: usize) -> Poll<Bytes, io::Error> {
            let mut v = vec![0; len];
            let n = try_nb!(self.read(v.as_mut_slice()));
            Ok(Async::Ready(BytesMut::from(&v[..n]).freeze()))
        }
    }
    */

    #[tokio::test]
    #[ignore]
    async fn iobuf_write_empty_slice() {
        // TODO(eliza): can i have writev back pls T_T
        // // First, let's just check that the Mock would normally return an
        // // error on an unexpected write, even if the buffer is empty...
        // let mut mock = Mock::new().build();
        // futures_util::future::poll_fn(|cx| {
        //     Pin::new(&mut mock).poll_write_buf(cx, &mut Cursor::new(&[]))
        // })
        // .await
        // .expect_err("should be a broken pipe");

        // // underlying io will return the logic error upon write,
        // // so we are testing that the io_buf does not trigger a write
        // // when there is nothing to flush
        // let mock = Mock::new().build();
        // let mut io_buf = Buffered::<_, Cursor<Vec<u8>>>::new(mock);
        // io_buf.flush().await.expect("should short-circuit flush");
    }

    #[tokio::test]
    async fn parse_reads_until_blocked() {
        use crate::proto::h1::ClientTransaction;

        let _ = pretty_env_logger::try_init();
        let mock = Mock::new()
            // Split over multiple reads; `parse` will read all of it
            .read(b"HTTP/1.1 200 OK\r\n")
            .read(b"Server: hyper\r\n")
            // missing last line ending
            .wait(Duration::from_secs(1))
            .build();

        let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock);

        // We expect `parse` not to be ready, so we can't await it directly.
        // Instead, this `poll_fn` wraps the `Poll` result.
        futures_util::future::poll_fn(|cx| {
            let parse_ctx = ParseContext {
                cached_headers: &mut None,
                req_method: &mut None,
                h1_parser_config: Default::default(),
                #[cfg(feature = "runtime")]
                h1_header_read_timeout: None,
                #[cfg(feature = "runtime")]
                h1_header_read_timeout_fut: &mut None,
                #[cfg(feature = "runtime")]
                h1_header_read_timeout_running: &mut false,
                preserve_header_case: false,
                #[cfg(feature = "ffi")]
                preserve_header_order: false,
                h09_responses: false,
                #[cfg(feature = "ffi")]
                on_informational: &mut None,
                #[cfg(feature = "ffi")]
                raw_headers: false,
            };
            assert!(buffered
                .parse::<ClientTransaction>(cx, parse_ctx)
                .is_pending());
            Poll::Ready(())
        })
        .await;

        assert_eq!(
            buffered.read_buf,
            b"HTTP/1.1 200 OK\r\nServer: hyper\r\n"[..]
        );
    }

    #[test]
    fn read_strategy_adaptive_increments() {
        let mut strategy = ReadStrategy::default();
        assert_eq!(strategy.next(), 8192);

        // Grows if record == next
        strategy.record(8192);
        assert_eq!(strategy.next(), 16384);

        strategy.record(16384);
        assert_eq!(strategy.next(), 32768);

        // Enormous records still increment at the same rate
        strategy.record(::std::usize::MAX);
        assert_eq!(strategy.next(), 65536);

        let max = strategy.max();
        while strategy.next() < max {
            strategy.record(max);
        }

        assert_eq!(strategy.next(), max, "never goes over max");
        strategy.record(max + 1);
        assert_eq!(strategy.next(), max, "never goes over max");
    }

    #[test]
    fn read_strategy_adaptive_decrements() {
        let mut strategy = ReadStrategy::default();
        strategy.record(8192);
        assert_eq!(strategy.next(), 16384);

        strategy.record(1);
        assert_eq!(
            strategy.next(),
            16384,
            "first smaller record doesn't decrement yet"
        );
        strategy.record(8192);
        assert_eq!(strategy.next(), 16384, "record was within range");

        strategy.record(1);
        assert_eq!(
            strategy.next(),
            16384,
            "in-range record should make this the 'first' again"
        );

        strategy.record(1);
        assert_eq!(strategy.next(), 8192, "second smaller record decrements");

        strategy.record(1);
        assert_eq!(strategy.next(), 8192, "first doesn't decrement");
        strategy.record(1);
        assert_eq!(strategy.next(), 8192, "doesn't decrement under minimum");
    }

    #[test]
    fn read_strategy_adaptive_stays_the_same() {
        let mut strategy = ReadStrategy::default();
        strategy.record(8192);
        assert_eq!(strategy.next(), 16384);

        strategy.record(8193);
        assert_eq!(
            strategy.next(),
            16384,
            "first smaller record doesn't decrement yet"
        );

        strategy.record(8193);
        assert_eq!(
            strategy.next(),
            16384,
            "a record within the current step does not decrement"
        );
    }

    #[test]
    fn read_strategy_adaptive_max_fuzz() {
        fn fuzz(max: usize) {
            let mut strategy = ReadStrategy::with_max(max);
            while strategy.next() < max {
                strategy.record(::std::usize::MAX);
            }
            let mut next = strategy.next();
            while next > 8192 {
                strategy.record(1);
                strategy.record(1);
                next = strategy.next();
                assert!(
                    next.is_power_of_two(),
                    "decrement should be powers of two: {} (max = {})",
                    next,
                    max,
                );
            }
        }

        let mut max = 8192;
        while max < std::usize::MAX {
            fuzz(max);
            max = (max / 2).saturating_mul(3);
        }
        fuzz(::std::usize::MAX);
    }

    #[test]
    #[should_panic]
    #[cfg(debug_assertions)] // needs to trigger a debug_assert
    fn write_buf_requires_non_empty_bufs() {
        let mock = Mock::new().build();
        let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock);

        buffered.buffer(Cursor::new(Vec::new()));
    }

    /*
    TODO: needs tokio_test::io to allow configuring write_buf calls
    #[test]
    fn write_buf_queue() {
        let _ = pretty_env_logger::try_init();

        let mock = AsyncIo::new_buf(vec![], 1024);
        let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock);


        buffered.headers_buf().extend(b"hello ");
        buffered.buffer(Cursor::new(b"world, ".to_vec()));
        buffered.buffer(Cursor::new(b"it's ".to_vec()));
        buffered.buffer(Cursor::new(b"hyper!".to_vec()));
        assert_eq!(buffered.write_buf.queue.bufs_cnt(), 3);
        buffered.flush().unwrap();

        assert_eq!(buffered.io, b"hello world, it's hyper!");
        assert_eq!(buffered.io.num_writes(), 1);
        assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0);
    }
    */

    #[tokio::test]
    async fn write_buf_flatten() {
        let _ = pretty_env_logger::try_init();

        let mock = Mock::new().write(b"hello world, it's hyper!").build();

        let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock);
        buffered.write_buf.set_strategy(WriteStrategy::Flatten);

        buffered.headers_buf().extend(b"hello ");
        buffered.buffer(Cursor::new(b"world, ".to_vec()));
        buffered.buffer(Cursor::new(b"it's ".to_vec()));
        buffered.buffer(Cursor::new(b"hyper!".to_vec()));
        assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0);

        buffered.flush().await.expect("flush");
    }

    #[test]
    fn write_buf_flatten_partially_flushed() {
        let _ = pretty_env_logger::try_init();

        let b = |s: &str| Cursor::new(s.as_bytes().to_vec());

        let mut write_buf = WriteBuf::<Cursor<Vec<u8>>>::new(WriteStrategy::Flatten);

        write_buf.buffer(b("hello "));
        write_buf.buffer(b("world, "));

        assert_eq!(write_buf.chunk(), b"hello world, ");

        // advance most of the way, but not all
        write_buf.advance(11);

        assert_eq!(write_buf.chunk(), b", ");
        assert_eq!(write_buf.headers.pos, 11);
        assert_eq!(write_buf.headers.bytes.capacity(), INIT_BUFFER_SIZE);

        // there's still room in the headers buffer, so just push on the end
        write_buf.buffer(b("it's hyper!"));

        assert_eq!(write_buf.chunk(), b", it's hyper!");
        assert_eq!(write_buf.headers.pos, 11);

        let rem1 = write_buf.remaining();
        let cap = write_buf.headers.bytes.capacity();

        // but when this would go over capacity, don't copy the old bytes
        write_buf.buffer(Cursor::new(vec![b'X'; cap]));
        assert_eq!(write_buf.remaining(), cap + rem1);
        assert_eq!(write_buf.headers.pos, 0);
    }

    #[tokio::test]
    async fn write_buf_queue_disable_auto() {
        let _ = pretty_env_logger::try_init();

        let mock = Mock::new()
            .write(b"hello ")
            .write(b"world, ")
            .write(b"it's ")
            .write(b"hyper!")
            .build();

        let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock);
        buffered.write_buf.set_strategy(WriteStrategy::Queue);

        // we have 4 buffers, and vec IO disabled, but explicitly said
        // don't try to auto detect (via setting strategy above)

        buffered.headers_buf().extend(b"hello ");
        buffered.buffer(Cursor::new(b"world, ".to_vec()));
        buffered.buffer(Cursor::new(b"it's ".to_vec()));
        buffered.buffer(Cursor::new(b"hyper!".to_vec()));
        assert_eq!(buffered.write_buf.queue.bufs_cnt(), 3);

        buffered.flush().await.expect("flush");

        assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0);
    }

    // #[cfg(feature = "nightly")]
    // #[bench]
    // fn bench_write_buf_flatten_buffer_chunk(b: &mut Bencher) {
    //     let s = "Hello, World!";
    //     b.bytes = s.len() as u64;

    //     let mut write_buf = WriteBuf::<bytes::Bytes>::new();
    //     write_buf.set_strategy(WriteStrategy::Flatten);
    //     b.iter(|| {
    //         let chunk = bytes::Bytes::from(s);
    //         write_buf.buffer(chunk);
    //         ::test::black_box(&write_buf);
    //         write_buf.headers.bytes.clear();
    //     })
    // }
}