1 use futures_io::{self as io, AsyncBufRead, AsyncRead, AsyncWrite}; 2 use pin_project::pin_project; 3 use std::{ 4 cmp, 5 pin::Pin, 6 task::{Context, Poll}, 7 }; 8 9 /// I/O wrapper that limits the number of bytes written or read on each call. 10 /// 11 /// See the [`limited`] and [`limited_write`] methods. 12 /// 13 /// [`limited`]: super::AsyncReadTestExt::limited 14 /// [`limited_write`]: super::AsyncWriteTestExt::limited_write 15 #[pin_project] 16 #[derive(Debug)] 17 pub struct Limited<Io> { 18 #[pin] 19 io: Io, 20 limit: usize, 21 } 22 23 impl<Io> Limited<Io> { new(io: Io, limit: usize) -> Self24 pub(crate) fn new(io: Io, limit: usize) -> Self { 25 Self { io, limit } 26 } 27 28 /// Acquires a reference to the underlying I/O object that this adaptor is 29 /// wrapping. get_ref(&self) -> &Io30 pub fn get_ref(&self) -> &Io { 31 &self.io 32 } 33 34 /// Acquires a mutable reference to the underlying I/O object that this 35 /// adaptor is wrapping. get_mut(&mut self) -> &mut Io36 pub fn get_mut(&mut self) -> &mut Io { 37 &mut self.io 38 } 39 40 /// Acquires a pinned mutable reference to the underlying I/O object that 41 /// this adaptor is wrapping. get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut Io>42 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut Io> { 43 self.project().io 44 } 45 46 /// Consumes this adaptor returning the underlying I/O object. into_inner(self) -> Io47 pub fn into_inner(self) -> Io { 48 self.io 49 } 50 } 51 52 impl<W: AsyncWrite> AsyncWrite for Limited<W> { poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>53 fn poll_write( 54 self: Pin<&mut Self>, 55 cx: &mut Context<'_>, 56 buf: &[u8], 57 ) -> Poll<io::Result<usize>> { 58 let this = self.project(); 59 this.io.poll_write(cx, &buf[..cmp::min(*this.limit, buf.len())]) 60 } 61 poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>62 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { 63 self.project().io.poll_flush(cx) 64 } 65 poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>66 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { 67 self.project().io.poll_close(cx) 68 } 69 } 70 71 impl<R: AsyncRead> AsyncRead for Limited<R> { poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll<io::Result<usize>>72 fn poll_read( 73 self: Pin<&mut Self>, 74 cx: &mut Context<'_>, 75 buf: &mut [u8], 76 ) -> Poll<io::Result<usize>> { 77 let this = self.project(); 78 let limit = cmp::min(*this.limit, buf.len()); 79 this.io.poll_read(cx, &mut buf[..limit]) 80 } 81 } 82 83 impl<R: AsyncBufRead> AsyncBufRead for Limited<R> { poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>>84 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { 85 self.project().io.poll_fill_buf(cx) 86 } 87 consume(self: Pin<&mut Self>, amount: usize)88 fn consume(self: Pin<&mut Self>, amount: usize) { 89 self.project().io.consume(amount) 90 } 91 } 92