1 //! In-process memory IO types.
2
3 use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
4 use crate::loom::sync::Mutex;
5
6 use bytes::{Buf, BytesMut};
7 use std::{
8 pin::Pin,
9 sync::Arc,
10 task::{self, Poll, Waker},
11 };
12
13 /// A bidirectional pipe to read and write bytes in memory.
14 ///
15 /// A pair of `DuplexStream`s are created together, and they act as a "channel"
16 /// that can be used as in-memory IO types. Writing to one of the pairs will
17 /// allow that data to be read from the other, and vice versa.
18 ///
19 /// # Closing a `DuplexStream`
20 ///
21 /// If one end of the `DuplexStream` channel is dropped, any pending reads on
22 /// the other side will continue to read data until the buffer is drained, then
23 /// they will signal EOF by returning 0 bytes. Any writes to the other side,
24 /// including pending ones (that are waiting for free space in the buffer) will
25 /// return `Err(BrokenPipe)` immediately.
26 ///
27 /// # Example
28 ///
29 /// ```
30 /// # async fn ex() -> std::io::Result<()> {
31 /// # use tokio::io::{AsyncReadExt, AsyncWriteExt};
32 /// let (mut client, mut server) = tokio::io::duplex(64);
33 ///
34 /// client.write_all(b"ping").await?;
35 ///
36 /// let mut buf = [0u8; 4];
37 /// server.read_exact(&mut buf).await?;
38 /// assert_eq!(&buf, b"ping");
39 ///
40 /// server.write_all(b"pong").await?;
41 ///
42 /// client.read_exact(&mut buf).await?;
43 /// assert_eq!(&buf, b"pong");
44 /// # Ok(())
45 /// # }
46 /// ```
47 #[derive(Debug)]
48 #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
49 pub struct DuplexStream {
50 read: Arc<Mutex<Pipe>>,
51 write: Arc<Mutex<Pipe>>,
52 }
53
54 /// A unidirectional IO over a piece of memory.
55 ///
56 /// Data can be written to the pipe, and reading will return that data.
57 #[derive(Debug)]
58 struct Pipe {
59 /// The buffer storing the bytes written, also read from.
60 ///
61 /// Using a `BytesMut` because it has efficient `Buf` and `BufMut`
62 /// functionality already. Additionally, it can try to copy data in the
63 /// same buffer if there read index has advanced far enough.
64 buffer: BytesMut,
65 /// Determines if the write side has been closed.
66 is_closed: bool,
67 /// The maximum amount of bytes that can be written before returning
68 /// `Poll::Pending`.
69 max_buf_size: usize,
70 /// If the `read` side has been polled and is pending, this is the waker
71 /// for that parked task.
72 read_waker: Option<Waker>,
73 /// If the `write` side has filled the `max_buf_size` and returned
74 /// `Poll::Pending`, this is the waker for that parked task.
75 write_waker: Option<Waker>,
76 }
77
78 // ===== impl DuplexStream =====
79
80 /// Create a new pair of `DuplexStream`s that act like a pair of connected sockets.
81 ///
82 /// The `max_buf_size` argument is the maximum amount of bytes that can be
83 /// written to a side before the write returns `Poll::Pending`.
84 #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream)85 pub fn duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream) {
86 let one = Arc::new(Mutex::new(Pipe::new(max_buf_size)));
87 let two = Arc::new(Mutex::new(Pipe::new(max_buf_size)));
88
89 (
90 DuplexStream {
91 read: one.clone(),
92 write: two.clone(),
93 },
94 DuplexStream {
95 read: two,
96 write: one,
97 },
98 )
99 }
100
101 impl AsyncRead for DuplexStream {
102 // Previous rustc required this `self` to be `mut`, even though newer
103 // versions recognize it isn't needed to call `lock()`. So for
104 // compatibility, we include the `mut` and `allow` the lint.
105 //
106 // See https://github.com/rust-lang/rust/issues/73592
107 #[allow(unused_mut)]
poll_read( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<std::io::Result<()>>108 fn poll_read(
109 mut self: Pin<&mut Self>,
110 cx: &mut task::Context<'_>,
111 buf: &mut ReadBuf<'_>,
112 ) -> Poll<std::io::Result<()>> {
113 Pin::new(&mut *self.read.lock()).poll_read(cx, buf)
114 }
115 }
116
117 impl AsyncWrite for DuplexStream {
118 #[allow(unused_mut)]
poll_write( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8], ) -> Poll<std::io::Result<usize>>119 fn poll_write(
120 mut self: Pin<&mut Self>,
121 cx: &mut task::Context<'_>,
122 buf: &[u8],
123 ) -> Poll<std::io::Result<usize>> {
124 Pin::new(&mut *self.write.lock()).poll_write(cx, buf)
125 }
126
127 #[allow(unused_mut)]
poll_flush( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, ) -> Poll<std::io::Result<()>>128 fn poll_flush(
129 mut self: Pin<&mut Self>,
130 cx: &mut task::Context<'_>,
131 ) -> Poll<std::io::Result<()>> {
132 Pin::new(&mut *self.write.lock()).poll_flush(cx)
133 }
134
135 #[allow(unused_mut)]
poll_shutdown( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, ) -> Poll<std::io::Result<()>>136 fn poll_shutdown(
137 mut self: Pin<&mut Self>,
138 cx: &mut task::Context<'_>,
139 ) -> Poll<std::io::Result<()>> {
140 Pin::new(&mut *self.write.lock()).poll_shutdown(cx)
141 }
142 }
143
144 impl Drop for DuplexStream {
drop(&mut self)145 fn drop(&mut self) {
146 // notify the other side of the closure
147 self.write.lock().close_write();
148 self.read.lock().close_read();
149 }
150 }
151
152 // ===== impl Pipe =====
153
154 impl Pipe {
new(max_buf_size: usize) -> Self155 fn new(max_buf_size: usize) -> Self {
156 Pipe {
157 buffer: BytesMut::new(),
158 is_closed: false,
159 max_buf_size,
160 read_waker: None,
161 write_waker: None,
162 }
163 }
164
close_write(&mut self)165 fn close_write(&mut self) {
166 self.is_closed = true;
167 // needs to notify any readers that no more data will come
168 if let Some(waker) = self.read_waker.take() {
169 waker.wake();
170 }
171 }
172
close_read(&mut self)173 fn close_read(&mut self) {
174 self.is_closed = true;
175 // needs to notify any writers that they have to abort
176 if let Some(waker) = self.write_waker.take() {
177 waker.wake();
178 }
179 }
180
poll_read_internal( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<std::io::Result<()>>181 fn poll_read_internal(
182 mut self: Pin<&mut Self>,
183 cx: &mut task::Context<'_>,
184 buf: &mut ReadBuf<'_>,
185 ) -> Poll<std::io::Result<()>> {
186 if self.buffer.has_remaining() {
187 let max = self.buffer.remaining().min(buf.remaining());
188 buf.put_slice(&self.buffer[..max]);
189 self.buffer.advance(max);
190 if max > 0 {
191 // The passed `buf` might have been empty, don't wake up if
192 // no bytes have been moved.
193 if let Some(waker) = self.write_waker.take() {
194 waker.wake();
195 }
196 }
197 Poll::Ready(Ok(()))
198 } else if self.is_closed {
199 Poll::Ready(Ok(()))
200 } else {
201 self.read_waker = Some(cx.waker().clone());
202 Poll::Pending
203 }
204 }
205
poll_write_internal( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8], ) -> Poll<std::io::Result<usize>>206 fn poll_write_internal(
207 mut self: Pin<&mut Self>,
208 cx: &mut task::Context<'_>,
209 buf: &[u8],
210 ) -> Poll<std::io::Result<usize>> {
211 if self.is_closed {
212 return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
213 }
214 let avail = self.max_buf_size - self.buffer.len();
215 if avail == 0 {
216 self.write_waker = Some(cx.waker().clone());
217 return Poll::Pending;
218 }
219
220 let len = buf.len().min(avail);
221 self.buffer.extend_from_slice(&buf[..len]);
222 if let Some(waker) = self.read_waker.take() {
223 waker.wake();
224 }
225 Poll::Ready(Ok(len))
226 }
227 }
228
229 impl AsyncRead for Pipe {
230 cfg_coop! {
231 fn poll_read(
232 self: Pin<&mut Self>,
233 cx: &mut task::Context<'_>,
234 buf: &mut ReadBuf<'_>,
235 ) -> Poll<std::io::Result<()>> {
236 ready!(crate::trace::trace_leaf(cx));
237 let coop = ready!(crate::runtime::coop::poll_proceed(cx));
238
239 let ret = self.poll_read_internal(cx, buf);
240 if ret.is_ready() {
241 coop.made_progress();
242 }
243 ret
244 }
245 }
246
247 cfg_not_coop! {
248 fn poll_read(
249 self: Pin<&mut Self>,
250 cx: &mut task::Context<'_>,
251 buf: &mut ReadBuf<'_>,
252 ) -> Poll<std::io::Result<()>> {
253 ready!(crate::trace::trace_leaf(cx));
254 self.poll_read_internal(cx, buf)
255 }
256 }
257 }
258
259 impl AsyncWrite for Pipe {
260 cfg_coop! {
261 fn poll_write(
262 self: Pin<&mut Self>,
263 cx: &mut task::Context<'_>,
264 buf: &[u8],
265 ) -> Poll<std::io::Result<usize>> {
266 ready!(crate::trace::trace_leaf(cx));
267 let coop = ready!(crate::runtime::coop::poll_proceed(cx));
268
269 let ret = self.poll_write_internal(cx, buf);
270 if ret.is_ready() {
271 coop.made_progress();
272 }
273 ret
274 }
275 }
276
277 cfg_not_coop! {
278 fn poll_write(
279 self: Pin<&mut Self>,
280 cx: &mut task::Context<'_>,
281 buf: &[u8],
282 ) -> Poll<std::io::Result<usize>> {
283 ready!(crate::trace::trace_leaf(cx));
284 self.poll_write_internal(cx, buf)
285 }
286 }
287
poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>>288 fn poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>> {
289 Poll::Ready(Ok(()))
290 }
291
poll_shutdown( mut self: Pin<&mut Self>, _: &mut task::Context<'_>, ) -> Poll<std::io::Result<()>>292 fn poll_shutdown(
293 mut self: Pin<&mut Self>,
294 _: &mut task::Context<'_>,
295 ) -> Poll<std::io::Result<()>> {
296 self.close_write();
297 Poll::Ready(Ok(()))
298 }
299 }
300