1 use std::cmp;
2 use std::fmt;
3 #[cfg(all(feature = "server", feature = "runtime"))]
4 use std::future::Future;
5 use std::io::{self, IoSlice};
6 use std::marker::Unpin;
7 use std::mem::MaybeUninit;
8 use std::pin::Pin;
9 use std::task::{Context, Poll};
10 #[cfg(all(feature = "server", feature = "runtime"))]
11 use std::time::Duration;
12
13 use bytes::{Buf, BufMut, Bytes, BytesMut};
14 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
15 #[cfg(all(feature = "server", feature = "runtime"))]
16 use tokio::time::Instant;
17 use tracing::{debug, trace};
18
19 use super::{Http1Transaction, ParseContext, ParsedMessage};
20 use crate::common::buf::BufList;
21
/// Bytes reserved for the read buffer before the first read from the IO.
///
/// The adaptive read strategy starts here and never shrinks below it.
pub(crate) const INIT_BUFFER_SIZE: usize = 8192;

/// Smallest value accepted by `Buffered::set_max_buf_size`.
pub(crate) const MINIMUM_MAX_BUFFER_SIZE: usize = INIT_BUFFER_SIZE;

/// Default upper bound for the read buffer (also the default write limit).
///
/// When the read buffer reaches this size and a message head is still not
/// complete, a `TooLarge` error is returned.
// Note: if this changes, update server::conn::Http::max_buf_size docs.
pub(crate) const DEFAULT_MAX_BUFFER_SIZE: usize = INIT_BUFFER_SIZE + 4096 * 100;

/// Upper bound on the number of distinct `Buf`s held in the write queue.
///
/// Only relevant for the `Queue` write strategy. Reaching it makes
/// `can_buffer` report `false`, which forces a flush. A flush can still
/// happen before this limit is reached.
const MAX_BUF_LIST_BUFFERS: usize = 16;
39
40 pub(crate) struct Buffered<T, B> {
41 flush_pipeline: bool,
42 io: T,
43 partial_len: Option<usize>,
44 read_blocked: bool,
45 read_buf: BytesMut,
46 read_buf_strategy: ReadStrategy,
47 write_buf: WriteBuf<B>,
48 }
49
50 impl<T, B> fmt::Debug for Buffered<T, B>
51 where
52 B: Buf,
53 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result54 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55 f.debug_struct("Buffered")
56 .field("read_buf", &self.read_buf)
57 .field("write_buf", &self.write_buf)
58 .finish()
59 }
60 }
61
62 impl<T, B> Buffered<T, B>
63 where
64 T: AsyncRead + AsyncWrite + Unpin,
65 B: Buf,
66 {
new(io: T) -> Buffered<T, B>67 pub(crate) fn new(io: T) -> Buffered<T, B> {
68 let strategy = if io.is_write_vectored() {
69 WriteStrategy::Queue
70 } else {
71 WriteStrategy::Flatten
72 };
73 let write_buf = WriteBuf::new(strategy);
74 Buffered {
75 flush_pipeline: false,
76 io,
77 partial_len: None,
78 read_blocked: false,
79 read_buf: BytesMut::with_capacity(0),
80 read_buf_strategy: ReadStrategy::default(),
81 write_buf,
82 }
83 }
84
85 #[cfg(feature = "server")]
set_flush_pipeline(&mut self, enabled: bool)86 pub(crate) fn set_flush_pipeline(&mut self, enabled: bool) {
87 debug_assert!(!self.write_buf.has_remaining());
88 self.flush_pipeline = enabled;
89 if enabled {
90 self.set_write_strategy_flatten();
91 }
92 }
93
set_max_buf_size(&mut self, max: usize)94 pub(crate) fn set_max_buf_size(&mut self, max: usize) {
95 assert!(
96 max >= MINIMUM_MAX_BUFFER_SIZE,
97 "The max_buf_size cannot be smaller than {}.",
98 MINIMUM_MAX_BUFFER_SIZE,
99 );
100 self.read_buf_strategy = ReadStrategy::with_max(max);
101 self.write_buf.max_buf_size = max;
102 }
103
104 #[cfg(feature = "client")]
set_read_buf_exact_size(&mut self, sz: usize)105 pub(crate) fn set_read_buf_exact_size(&mut self, sz: usize) {
106 self.read_buf_strategy = ReadStrategy::Exact(sz);
107 }
108
set_write_strategy_flatten(&mut self)109 pub(crate) fn set_write_strategy_flatten(&mut self) {
110 // this should always be called only at construction time,
111 // so this assert is here to catch myself
112 debug_assert!(self.write_buf.queue.bufs_cnt() == 0);
113 self.write_buf.set_strategy(WriteStrategy::Flatten);
114 }
115
set_write_strategy_queue(&mut self)116 pub(crate) fn set_write_strategy_queue(&mut self) {
117 // this should always be called only at construction time,
118 // so this assert is here to catch myself
119 debug_assert!(self.write_buf.queue.bufs_cnt() == 0);
120 self.write_buf.set_strategy(WriteStrategy::Queue);
121 }
122
read_buf(&self) -> &[u8]123 pub(crate) fn read_buf(&self) -> &[u8] {
124 self.read_buf.as_ref()
125 }
126
127 #[cfg(test)]
128 #[cfg(feature = "nightly")]
read_buf_mut(&mut self) -> &mut BytesMut129 pub(super) fn read_buf_mut(&mut self) -> &mut BytesMut {
130 &mut self.read_buf
131 }
132
133 /// Return the "allocated" available space, not the potential space
134 /// that could be allocated in the future.
read_buf_remaining_mut(&self) -> usize135 fn read_buf_remaining_mut(&self) -> usize {
136 self.read_buf.capacity() - self.read_buf.len()
137 }
138
139 /// Return whether we can append to the headers buffer.
140 ///
141 /// Reasons we can't:
142 /// - The write buf is in queue mode, and some of the past body is still
143 /// needing to be flushed.
can_headers_buf(&self) -> bool144 pub(crate) fn can_headers_buf(&self) -> bool {
145 !self.write_buf.queue.has_remaining()
146 }
147
headers_buf(&mut self) -> &mut Vec<u8>148 pub(crate) fn headers_buf(&mut self) -> &mut Vec<u8> {
149 let buf = self.write_buf.headers_mut();
150 &mut buf.bytes
151 }
152
write_buf(&mut self) -> &mut WriteBuf<B>153 pub(super) fn write_buf(&mut self) -> &mut WriteBuf<B> {
154 &mut self.write_buf
155 }
156
buffer<BB: Buf + Into<B>>(&mut self, buf: BB)157 pub(crate) fn buffer<BB: Buf + Into<B>>(&mut self, buf: BB) {
158 self.write_buf.buffer(buf)
159 }
160
can_buffer(&self) -> bool161 pub(crate) fn can_buffer(&self) -> bool {
162 self.flush_pipeline || self.write_buf.can_buffer()
163 }
164
consume_leading_lines(&mut self)165 pub(crate) fn consume_leading_lines(&mut self) {
166 if !self.read_buf.is_empty() {
167 let mut i = 0;
168 while i < self.read_buf.len() {
169 match self.read_buf[i] {
170 b'\r' | b'\n' => i += 1,
171 _ => break,
172 }
173 }
174 self.read_buf.advance(i);
175 }
176 }
177
parse<S>( &mut self, cx: &mut Context<'_>, parse_ctx: ParseContext<'_>, ) -> Poll<crate::Result<ParsedMessage<S::Incoming>>> where S: Http1Transaction,178 pub(super) fn parse<S>(
179 &mut self,
180 cx: &mut Context<'_>,
181 parse_ctx: ParseContext<'_>,
182 ) -> Poll<crate::Result<ParsedMessage<S::Incoming>>>
183 where
184 S: Http1Transaction,
185 {
186 loop {
187 match super::role::parse_headers::<S>(
188 &mut self.read_buf,
189 self.partial_len,
190 ParseContext {
191 cached_headers: parse_ctx.cached_headers,
192 req_method: parse_ctx.req_method,
193 h1_parser_config: parse_ctx.h1_parser_config.clone(),
194 #[cfg(all(feature = "server", feature = "runtime"))]
195 h1_header_read_timeout: parse_ctx.h1_header_read_timeout,
196 #[cfg(all(feature = "server", feature = "runtime"))]
197 h1_header_read_timeout_fut: parse_ctx.h1_header_read_timeout_fut,
198 #[cfg(all(feature = "server", feature = "runtime"))]
199 h1_header_read_timeout_running: parse_ctx.h1_header_read_timeout_running,
200 preserve_header_case: parse_ctx.preserve_header_case,
201 #[cfg(feature = "ffi")]
202 preserve_header_order: parse_ctx.preserve_header_order,
203 h09_responses: parse_ctx.h09_responses,
204 #[cfg(feature = "ffi")]
205 on_informational: parse_ctx.on_informational,
206 #[cfg(feature = "ffi")]
207 raw_headers: parse_ctx.raw_headers,
208 },
209 )? {
210 Some(msg) => {
211 debug!("parsed {} headers", msg.head.headers.len());
212
213 #[cfg(all(feature = "server", feature = "runtime"))]
214 {
215 *parse_ctx.h1_header_read_timeout_running = false;
216
217 if let Some(h1_header_read_timeout_fut) =
218 parse_ctx.h1_header_read_timeout_fut
219 {
220 // Reset the timer in order to avoid woken up when the timeout finishes
221 h1_header_read_timeout_fut
222 .as_mut()
223 .reset(Instant::now() + Duration::from_secs(30 * 24 * 60 * 60));
224 }
225 }
226 self.partial_len = None;
227 return Poll::Ready(Ok(msg));
228 }
229 None => {
230 let max = self.read_buf_strategy.max();
231 let curr_len = self.read_buf.len();
232 if curr_len >= max {
233 debug!("max_buf_size ({}) reached, closing", max);
234 return Poll::Ready(Err(crate::Error::new_too_large()));
235 }
236
237 #[cfg(all(feature = "server", feature = "runtime"))]
238 if *parse_ctx.h1_header_read_timeout_running {
239 if let Some(h1_header_read_timeout_fut) =
240 parse_ctx.h1_header_read_timeout_fut
241 {
242 if Pin::new(h1_header_read_timeout_fut).poll(cx).is_ready() {
243 *parse_ctx.h1_header_read_timeout_running = false;
244
245 tracing::warn!("read header from client timeout");
246 return Poll::Ready(Err(crate::Error::new_header_timeout()));
247 }
248 }
249 }
250 if curr_len > 0 {
251 trace!("partial headers; {} bytes so far", curr_len);
252 self.partial_len = Some(curr_len);
253 } else {
254 // 1xx gobled some bytes
255 self.partial_len = None;
256 }
257 }
258 }
259 if ready!(self.poll_read_from_io(cx)).map_err(crate::Error::new_io)? == 0 {
260 trace!("parse eof");
261 return Poll::Ready(Err(crate::Error::new_incomplete()));
262 }
263 }
264 }
265
poll_read_from_io(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<usize>>266 pub(crate) fn poll_read_from_io(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
267 self.read_blocked = false;
268 let next = self.read_buf_strategy.next();
269 if self.read_buf_remaining_mut() < next {
270 self.read_buf.reserve(next);
271 }
272
273 let dst = self.read_buf.chunk_mut();
274 let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit<u8>]) };
275 let mut buf = ReadBuf::uninit(dst);
276 match Pin::new(&mut self.io).poll_read(cx, &mut buf) {
277 Poll::Ready(Ok(_)) => {
278 let n = buf.filled().len();
279 trace!("received {} bytes", n);
280 unsafe {
281 // Safety: we just read that many bytes into the
282 // uninitialized part of the buffer, so this is okay.
283 // @tokio pls give me back `poll_read_buf` thanks
284 self.read_buf.advance_mut(n);
285 }
286 self.read_buf_strategy.record(n);
287 Poll::Ready(Ok(n))
288 }
289 Poll::Pending => {
290 self.read_blocked = true;
291 Poll::Pending
292 }
293 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
294 }
295 }
296
into_inner(self) -> (T, Bytes)297 pub(crate) fn into_inner(self) -> (T, Bytes) {
298 (self.io, self.read_buf.freeze())
299 }
300
io_mut(&mut self) -> &mut T301 pub(crate) fn io_mut(&mut self) -> &mut T {
302 &mut self.io
303 }
304
is_read_blocked(&self) -> bool305 pub(crate) fn is_read_blocked(&self) -> bool {
306 self.read_blocked
307 }
308
poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>>309 pub(crate) fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
310 if self.flush_pipeline && !self.read_buf.is_empty() {
311 Poll::Ready(Ok(()))
312 } else if self.write_buf.remaining() == 0 {
313 Pin::new(&mut self.io).poll_flush(cx)
314 } else {
315 if let WriteStrategy::Flatten = self.write_buf.strategy {
316 return self.poll_flush_flattened(cx);
317 }
318
319 const MAX_WRITEV_BUFS: usize = 64;
320 loop {
321 let n = {
322 let mut iovs = [IoSlice::new(&[]); MAX_WRITEV_BUFS];
323 let len = self.write_buf.chunks_vectored(&mut iovs);
324 ready!(Pin::new(&mut self.io).poll_write_vectored(cx, &iovs[..len]))?
325 };
326 // TODO(eliza): we have to do this manually because
327 // `poll_write_buf` doesn't exist in Tokio 0.3 yet...when
328 // `poll_write_buf` comes back, the manual advance will need to leave!
329 self.write_buf.advance(n);
330 debug!("flushed {} bytes", n);
331 if self.write_buf.remaining() == 0 {
332 break;
333 } else if n == 0 {
334 trace!(
335 "write returned zero, but {} bytes remaining",
336 self.write_buf.remaining()
337 );
338 return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
339 }
340 }
341 Pin::new(&mut self.io).poll_flush(cx)
342 }
343 }
344
345 /// Specialized version of `flush` when strategy is Flatten.
346 ///
347 /// Since all buffered bytes are flattened into the single headers buffer,
348 /// that skips some bookkeeping around using multiple buffers.
poll_flush_flattened(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>>349 fn poll_flush_flattened(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
350 loop {
351 let n = ready!(Pin::new(&mut self.io).poll_write(cx, self.write_buf.headers.chunk()))?;
352 debug!("flushed {} bytes", n);
353 self.write_buf.headers.advance(n);
354 if self.write_buf.headers.remaining() == 0 {
355 self.write_buf.headers.reset();
356 break;
357 } else if n == 0 {
358 trace!(
359 "write returned zero, but {} bytes remaining",
360 self.write_buf.remaining()
361 );
362 return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
363 }
364 }
365 Pin::new(&mut self.io).poll_flush(cx)
366 }
367
368 #[cfg(test)]
flush<'a>(&'a mut self) -> impl std::future::Future<Output = io::Result<()>> + 'a369 fn flush<'a>(&'a mut self) -> impl std::future::Future<Output = io::Result<()>> + 'a {
370 futures_util::future::poll_fn(move |cx| self.poll_flush(cx))
371 }
372 }
373
374 // The `B` is a `Buf`, we never project a pin to it
375 impl<T: Unpin, B> Unpin for Buffered<T, B> {}
376
377 // TODO: This trait is old... at least rename to PollBytes or something...
378 pub(crate) trait MemRead {
read_mem(&mut self, cx: &mut Context<'_>, len: usize) -> Poll<io::Result<Bytes>>379 fn read_mem(&mut self, cx: &mut Context<'_>, len: usize) -> Poll<io::Result<Bytes>>;
380 }
381
382 impl<T, B> MemRead for Buffered<T, B>
383 where
384 T: AsyncRead + AsyncWrite + Unpin,
385 B: Buf,
386 {
read_mem(&mut self, cx: &mut Context<'_>, len: usize) -> Poll<io::Result<Bytes>>387 fn read_mem(&mut self, cx: &mut Context<'_>, len: usize) -> Poll<io::Result<Bytes>> {
388 if !self.read_buf.is_empty() {
389 let n = std::cmp::min(len, self.read_buf.len());
390 Poll::Ready(Ok(self.read_buf.split_to(n).freeze()))
391 } else {
392 let n = ready!(self.poll_read_from_io(cx))?;
393 Poll::Ready(Ok(self.read_buf.split_to(::std::cmp::min(len, n)).freeze()))
394 }
395 }
396 }
397
/// Decides how much space to reserve in the read buffer before each read.
#[derive(Clone, Copy, Debug)]
enum ReadStrategy {
    /// Grows and shrinks the reservation based on recent read sizes.
    Adaptive {
        /// Set after one read smaller than the previous power of two; a
        /// second such read then shrinks `next`.
        decrease_now: bool,
        /// Bytes to reserve for the next read.
        next: usize,
        /// Upper bound for `next`, and the maximum read buffer size.
        max: usize,
    },
    /// Always reserves exactly this many bytes; also used as the maximum.
    #[cfg(feature = "client")]
    Exact(usize),
}
408
409 impl ReadStrategy {
with_max(max: usize) -> ReadStrategy410 fn with_max(max: usize) -> ReadStrategy {
411 ReadStrategy::Adaptive {
412 decrease_now: false,
413 next: INIT_BUFFER_SIZE,
414 max,
415 }
416 }
417
next(&self) -> usize418 fn next(&self) -> usize {
419 match *self {
420 ReadStrategy::Adaptive { next, .. } => next,
421 #[cfg(feature = "client")]
422 ReadStrategy::Exact(exact) => exact,
423 }
424 }
425
max(&self) -> usize426 fn max(&self) -> usize {
427 match *self {
428 ReadStrategy::Adaptive { max, .. } => max,
429 #[cfg(feature = "client")]
430 ReadStrategy::Exact(exact) => exact,
431 }
432 }
433
record(&mut self, bytes_read: usize)434 fn record(&mut self, bytes_read: usize) {
435 match *self {
436 ReadStrategy::Adaptive {
437 ref mut decrease_now,
438 ref mut next,
439 max,
440 ..
441 } => {
442 if bytes_read >= *next {
443 *next = cmp::min(incr_power_of_two(*next), max);
444 *decrease_now = false;
445 } else {
446 let decr_to = prev_power_of_two(*next);
447 if bytes_read < decr_to {
448 if *decrease_now {
449 *next = cmp::max(decr_to, INIT_BUFFER_SIZE);
450 *decrease_now = false;
451 } else {
452 // Decreasing is a two "record" process.
453 *decrease_now = true;
454 }
455 } else {
456 // A read within the current range should cancel
457 // a potential decrease, since we just saw proof
458 // that we still need this size.
459 *decrease_now = false;
460 }
461 }
462 }
463 #[cfg(feature = "client")]
464 ReadStrategy::Exact(_) => (),
465 }
466 }
467 }
468
incr_power_of_two(n: usize) -> usize469 fn incr_power_of_two(n: usize) -> usize {
470 n.saturating_mul(2)
471 }
472
prev_power_of_two(n: usize) -> usize473 fn prev_power_of_two(n: usize) -> usize {
474 // Only way this shift can underflow is if n is less than 4.
475 // (Which would means `usize::MAX >> 64` and underflowed!)
476 debug_assert!(n >= 4);
477 (::std::usize::MAX >> (n.leading_zeros() + 2)) + 1
478 }
479
480 impl Default for ReadStrategy {
default() -> ReadStrategy481 fn default() -> ReadStrategy {
482 ReadStrategy::with_max(DEFAULT_MAX_BUFFER_SIZE)
483 }
484 }
485
/// A byte container paired with a read position.
///
/// Its unread contents are `bytes[pos..]`.
#[derive(Clone)]
pub(crate) struct Cursor<T> {
    /// The underlying storage.
    bytes: T,
    /// Index of the first byte that has not been consumed yet.
    pos: usize,
}
491
492 impl<T: AsRef<[u8]>> Cursor<T> {
493 #[inline]
new(bytes: T) -> Cursor<T>494 pub(crate) fn new(bytes: T) -> Cursor<T> {
495 Cursor { bytes, pos: 0 }
496 }
497 }
498
499 impl Cursor<Vec<u8>> {
500 /// If we've advanced the position a bit in this cursor, and wish to
501 /// extend the underlying vector, we may wish to unshift the "read" bytes
502 /// off, and move everything else over.
maybe_unshift(&mut self, additional: usize)503 fn maybe_unshift(&mut self, additional: usize) {
504 if self.pos == 0 {
505 // nothing to do
506 return;
507 }
508
509 if self.bytes.capacity() - self.bytes.len() >= additional {
510 // there's room!
511 return;
512 }
513
514 self.bytes.drain(0..self.pos);
515 self.pos = 0;
516 }
517
reset(&mut self)518 fn reset(&mut self) {
519 self.pos = 0;
520 self.bytes.clear();
521 }
522 }
523
524 impl<T: AsRef<[u8]>> fmt::Debug for Cursor<T> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result525 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
526 f.debug_struct("Cursor")
527 .field("pos", &self.pos)
528 .field("len", &self.bytes.as_ref().len())
529 .finish()
530 }
531 }
532
533 impl<T: AsRef<[u8]>> Buf for Cursor<T> {
534 #[inline]
remaining(&self) -> usize535 fn remaining(&self) -> usize {
536 self.bytes.as_ref().len() - self.pos
537 }
538
539 #[inline]
chunk(&self) -> &[u8]540 fn chunk(&self) -> &[u8] {
541 &self.bytes.as_ref()[self.pos..]
542 }
543
544 #[inline]
advance(&mut self, cnt: usize)545 fn advance(&mut self, cnt: usize) {
546 debug_assert!(self.pos + cnt <= self.bytes.as_ref().len());
547 self.pos += cnt;
548 }
549 }
550
551 // an internal buffer to collect writes before flushes
552 pub(super) struct WriteBuf<B> {
553 /// Re-usable buffer that holds message headers
554 headers: Cursor<Vec<u8>>,
555 max_buf_size: usize,
556 /// Deque of user buffers if strategy is Queue
557 queue: BufList<B>,
558 strategy: WriteStrategy,
559 }
560
561 impl<B: Buf> WriteBuf<B> {
new(strategy: WriteStrategy) -> WriteBuf<B>562 fn new(strategy: WriteStrategy) -> WriteBuf<B> {
563 WriteBuf {
564 headers: Cursor::new(Vec::with_capacity(INIT_BUFFER_SIZE)),
565 max_buf_size: DEFAULT_MAX_BUFFER_SIZE,
566 queue: BufList::new(),
567 strategy,
568 }
569 }
570 }
571
572 impl<B> WriteBuf<B>
573 where
574 B: Buf,
575 {
set_strategy(&mut self, strategy: WriteStrategy)576 fn set_strategy(&mut self, strategy: WriteStrategy) {
577 self.strategy = strategy;
578 }
579
buffer<BB: Buf + Into<B>>(&mut self, mut buf: BB)580 pub(super) fn buffer<BB: Buf + Into<B>>(&mut self, mut buf: BB) {
581 debug_assert!(buf.has_remaining());
582 match self.strategy {
583 WriteStrategy::Flatten => {
584 let head = self.headers_mut();
585
586 head.maybe_unshift(buf.remaining());
587 trace!(
588 self.len = head.remaining(),
589 buf.len = buf.remaining(),
590 "buffer.flatten"
591 );
592 //perf: This is a little faster than <Vec as BufMut>>::put,
593 //but accomplishes the same result.
594 loop {
595 let adv = {
596 let slice = buf.chunk();
597 if slice.is_empty() {
598 return;
599 }
600 head.bytes.extend_from_slice(slice);
601 slice.len()
602 };
603 buf.advance(adv);
604 }
605 }
606 WriteStrategy::Queue => {
607 trace!(
608 self.len = self.remaining(),
609 buf.len = buf.remaining(),
610 "buffer.queue"
611 );
612 self.queue.push(buf.into());
613 }
614 }
615 }
616
can_buffer(&self) -> bool617 fn can_buffer(&self) -> bool {
618 match self.strategy {
619 WriteStrategy::Flatten => self.remaining() < self.max_buf_size,
620 WriteStrategy::Queue => {
621 self.queue.bufs_cnt() < MAX_BUF_LIST_BUFFERS && self.remaining() < self.max_buf_size
622 }
623 }
624 }
625
headers_mut(&mut self) -> &mut Cursor<Vec<u8>>626 fn headers_mut(&mut self) -> &mut Cursor<Vec<u8>> {
627 debug_assert!(!self.queue.has_remaining());
628 &mut self.headers
629 }
630 }
631
632 impl<B: Buf> fmt::Debug for WriteBuf<B> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result633 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
634 f.debug_struct("WriteBuf")
635 .field("remaining", &self.remaining())
636 .field("strategy", &self.strategy)
637 .finish()
638 }
639 }
640
641 impl<B: Buf> Buf for WriteBuf<B> {
642 #[inline]
remaining(&self) -> usize643 fn remaining(&self) -> usize {
644 self.headers.remaining() + self.queue.remaining()
645 }
646
647 #[inline]
chunk(&self) -> &[u8]648 fn chunk(&self) -> &[u8] {
649 let headers = self.headers.chunk();
650 if !headers.is_empty() {
651 headers
652 } else {
653 self.queue.chunk()
654 }
655 }
656
657 #[inline]
advance(&mut self, cnt: usize)658 fn advance(&mut self, cnt: usize) {
659 let hrem = self.headers.remaining();
660
661 match hrem.cmp(&cnt) {
662 cmp::Ordering::Equal => self.headers.reset(),
663 cmp::Ordering::Greater => self.headers.advance(cnt),
664 cmp::Ordering::Less => {
665 let qcnt = cnt - hrem;
666 self.headers.reset();
667 self.queue.advance(qcnt);
668 }
669 }
670 }
671
672 #[inline]
chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize673 fn chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize {
674 let n = self.headers.chunks_vectored(dst);
675 self.queue.chunks_vectored(&mut dst[n..]) + n
676 }
677 }
678
/// How `WriteBuf` stores the buffers handed to it.
#[derive(Debug)]
enum WriteStrategy {
    /// Copy every buffer's bytes into the single headers vector.
    Flatten,
    /// Keep each buffer as-is in a queue, flushed with vectored writes.
    Queue,
}
684
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    use tokio_test::io::Builder as Mock;

    // #[cfg(feature = "nightly")]
    // use test::Bencher;

    /*
    impl<T: Read> MemRead for AsyncIo<T> {
        fn read_mem(&mut self, len: usize) -> Poll<Bytes, io::Error> {
            let mut v = vec![0; len];
            let n = try_nb!(self.read(v.as_mut_slice()));
            Ok(Async::Ready(BytesMut::from(&v[..n]).freeze()))
        }
    }
    */

    #[tokio::test]
    #[ignore]
    async fn iobuf_write_empty_slice() {
        // TODO(eliza): can i have writev back pls T_T
        // // First, let's just check that the Mock would normally return an
        // // error on an unexpected write, even if the buffer is empty...
        // let mut mock = Mock::new().build();
        // futures_util::future::poll_fn(|cx| {
        //     Pin::new(&mut mock).poll_write_buf(cx, &mut Cursor::new(&[]))
        // })
        // .await
        // .expect_err("should be a broken pipe");

        // // underlying io will return the logic error upon write,
        // // so we are testing that the io_buf does not trigger a write
        // // when there is nothing to flush
        // let mock = Mock::new().build();
        // let mut io_buf = Buffered::<_, Cursor<Vec<u8>>>::new(mock);
        // io_buf.flush().await.expect("should short-circuit flush");
    }

    #[tokio::test]
    async fn parse_reads_until_blocked() {
        use crate::proto::h1::ClientTransaction;

        let _ = pretty_env_logger::try_init();
        let mock = Mock::new()
            // Split over multiple reads will read all of it
            .read(b"HTTP/1.1 200 OK\r\n")
            .read(b"Server: hyper\r\n")
            // missing last line ending
            .wait(Duration::from_secs(1))
            .build();

        let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock);

        // We expect a `parse` to be not ready, and so can't await it directly.
        // Rather, this `poll_fn` will wrap the `Poll` result.
        futures_util::future::poll_fn(|cx| {
            // The header-read-timeout fields only exist when both the
            // `server` and `runtime` features are enabled (matching the
            // gating used by `Buffered::parse`).
            let parse_ctx = ParseContext {
                cached_headers: &mut None,
                req_method: &mut None,
                h1_parser_config: Default::default(),
                #[cfg(all(feature = "server", feature = "runtime"))]
                h1_header_read_timeout: None,
                #[cfg(all(feature = "server", feature = "runtime"))]
                h1_header_read_timeout_fut: &mut None,
                #[cfg(all(feature = "server", feature = "runtime"))]
                h1_header_read_timeout_running: &mut false,
                preserve_header_case: false,
                #[cfg(feature = "ffi")]
                preserve_header_order: false,
                h09_responses: false,
                #[cfg(feature = "ffi")]
                on_informational: &mut None,
                #[cfg(feature = "ffi")]
                raw_headers: false,
            };
            assert!(buffered
                .parse::<ClientTransaction>(cx, parse_ctx)
                .is_pending());
            Poll::Ready(())
        })
        .await;

        assert_eq!(
            buffered.read_buf,
            b"HTTP/1.1 200 OK\r\nServer: hyper\r\n"[..]
        );
    }

    #[test]
    fn read_strategy_adaptive_increments() {
        let mut strategy = ReadStrategy::default();
        assert_eq!(strategy.next(), 8192);

        // Grows if record == next
        strategy.record(8192);
        assert_eq!(strategy.next(), 16384);

        strategy.record(16384);
        assert_eq!(strategy.next(), 32768);

        // Enormous records still increment at same rate
        strategy.record(usize::MAX);
        assert_eq!(strategy.next(), 65536);

        let max = strategy.max();
        while strategy.next() < max {
            strategy.record(max);
        }

        assert_eq!(strategy.next(), max, "never goes over max");
        strategy.record(max + 1);
        assert_eq!(strategy.next(), max, "never goes over max");
    }

    #[test]
    fn read_strategy_adaptive_decrements() {
        let mut strategy = ReadStrategy::default();
        strategy.record(8192);
        assert_eq!(strategy.next(), 16384);

        strategy.record(1);
        assert_eq!(
            strategy.next(),
            16384,
            "first smaller record doesn't decrement yet"
        );
        strategy.record(8192);
        assert_eq!(strategy.next(), 16384, "record was with range");

        strategy.record(1);
        assert_eq!(
            strategy.next(),
            16384,
            "in-range record should make this the 'first' again"
        );

        strategy.record(1);
        assert_eq!(strategy.next(), 8192, "second smaller record decrements");

        strategy.record(1);
        assert_eq!(strategy.next(), 8192, "first doesn't decrement");
        strategy.record(1);
        assert_eq!(strategy.next(), 8192, "doesn't decrement under minimum");
    }

    #[test]
    fn read_strategy_adaptive_stays_the_same() {
        let mut strategy = ReadStrategy::default();
        strategy.record(8192);
        assert_eq!(strategy.next(), 16384);

        strategy.record(8193);
        assert_eq!(
            strategy.next(),
            16384,
            "first smaller record doesn't decrement yet"
        );

        strategy.record(8193);
        assert_eq!(
            strategy.next(),
            16384,
            "with current step does not decrement"
        );
    }

    #[test]
    fn read_strategy_adaptive_max_fuzz() {
        fn fuzz(max: usize) {
            let mut strategy = ReadStrategy::with_max(max);
            while strategy.next() < max {
                strategy.record(usize::MAX);
            }
            let mut next = strategy.next();
            while next > 8192 {
                strategy.record(1);
                strategy.record(1);
                next = strategy.next();
                assert!(
                    next.is_power_of_two(),
                    "decrement should be powers of two: {} (max = {})",
                    next,
                    max,
                );
            }
        }

        let mut max = 8192;
        while max < usize::MAX {
            fuzz(max);
            max = (max / 2).saturating_mul(3);
        }
        fuzz(usize::MAX);
    }

    #[test]
    #[should_panic]
    #[cfg(debug_assertions)] // needs to trigger a debug_assert
    fn write_buf_requires_non_empty_bufs() {
        let mock = Mock::new().build();
        let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock);

        buffered.buffer(Cursor::new(Vec::new()));
    }

    /*
    TODO: needs tokio_test::io to allow configure write_buf calls
    #[test]
    fn write_buf_queue() {
        let _ = pretty_env_logger::try_init();

        let mock = AsyncIo::new_buf(vec![], 1024);
        let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock);


        buffered.headers_buf().extend(b"hello ");
        buffered.buffer(Cursor::new(b"world, ".to_vec()));
        buffered.buffer(Cursor::new(b"it's ".to_vec()));
        buffered.buffer(Cursor::new(b"hyper!".to_vec()));
        assert_eq!(buffered.write_buf.queue.bufs_cnt(), 3);
        buffered.flush().unwrap();

        assert_eq!(buffered.io, b"hello world, it's hyper!");
        assert_eq!(buffered.io.num_writes(), 1);
        assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0);
    }
    */

    #[tokio::test]
    async fn write_buf_flatten() {
        let _ = pretty_env_logger::try_init();

        let mock = Mock::new().write(b"hello world, it's hyper!").build();

        let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock);
        buffered.write_buf.set_strategy(WriteStrategy::Flatten);

        buffered.headers_buf().extend(b"hello ");
        buffered.buffer(Cursor::new(b"world, ".to_vec()));
        buffered.buffer(Cursor::new(b"it's ".to_vec()));
        buffered.buffer(Cursor::new(b"hyper!".to_vec()));
        assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0);

        buffered.flush().await.expect("flush");
    }

    #[test]
    fn write_buf_flatten_partially_flushed() {
        let _ = pretty_env_logger::try_init();

        let b = |s: &str| Cursor::new(s.as_bytes().to_vec());

        let mut write_buf = WriteBuf::<Cursor<Vec<u8>>>::new(WriteStrategy::Flatten);

        write_buf.buffer(b("hello "));
        write_buf.buffer(b("world, "));

        assert_eq!(write_buf.chunk(), b"hello world, ");

        // advance most of the way, but not all
        write_buf.advance(11);

        assert_eq!(write_buf.chunk(), b", ");
        assert_eq!(write_buf.headers.pos, 11);
        assert_eq!(write_buf.headers.bytes.capacity(), INIT_BUFFER_SIZE);

        // there's still room in the headers buffer, so just push on the end
        write_buf.buffer(b("it's hyper!"));

        assert_eq!(write_buf.chunk(), b", it's hyper!");
        assert_eq!(write_buf.headers.pos, 11);

        let rem1 = write_buf.remaining();
        let cap = write_buf.headers.bytes.capacity();

        // but when this would go over capacity, don't copy the old bytes
        write_buf.buffer(Cursor::new(vec![b'X'; cap]));
        assert_eq!(write_buf.remaining(), cap + rem1);
        assert_eq!(write_buf.headers.pos, 0);
    }

    #[tokio::test]
    async fn write_buf_queue_disable_auto() {
        let _ = pretty_env_logger::try_init();

        let mock = Mock::new()
            .write(b"hello ")
            .write(b"world, ")
            .write(b"it's ")
            .write(b"hyper!")
            .build();

        let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock);
        buffered.write_buf.set_strategy(WriteStrategy::Queue);

        // we have 4 buffers, and vec IO disabled, but explicitly said
        // don't try to auto detect (via setting strategy above)

        buffered.headers_buf().extend(b"hello ");
        buffered.buffer(Cursor::new(b"world, ".to_vec()));
        buffered.buffer(Cursor::new(b"it's ".to_vec()));
        buffered.buffer(Cursor::new(b"hyper!".to_vec()));
        assert_eq!(buffered.write_buf.queue.bufs_cnt(), 3);

        buffered.flush().await.expect("flush");

        assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0);
    }

    // #[cfg(feature = "nightly")]
    // #[bench]
    // fn bench_write_buf_flatten_buffer_chunk(b: &mut Bencher) {
    //     let s = "Hello, World!";
    //     b.bytes = s.len() as u64;

    //     let mut write_buf = WriteBuf::<bytes::Bytes>::new();
    //     write_buf.set_strategy(WriteStrategy::Flatten);
    //     b.iter(|| {
    //         let chunk = bytes::Bytes::from(s);
    //         write_buf.buffer(chunk);
    //         ::test::black_box(&write_buf);
    //         write_buf.headers.bytes.clear();
    //     })
    // }
}
1013