1 //! `TcpStream` owned split support.
2 //!
//! A `TcpStream` can be split into an `OwnedReadHalf` and an `OwnedWriteHalf`
4 //! with the `TcpStream::into_split` method. `OwnedReadHalf` implements
5 //! `AsyncRead` while `OwnedWriteHalf` implements `AsyncWrite`.
6 //!
7 //! Compared to the generic split of `AsyncRead + AsyncWrite`, this specialized
8 //! split has no associated overhead and enforces all invariants at the type
9 //! level.
10
11 use crate::future::poll_fn;
12 use crate::io::{AsyncRead, AsyncWrite, Interest, ReadBuf, Ready};
13 use crate::net::TcpStream;
14
15 use std::error::Error;
16 use std::net::{Shutdown, SocketAddr};
17 use std::pin::Pin;
18 use std::sync::Arc;
19 use std::task::{Context, Poll};
20 use std::{fmt, io};
21
22 cfg_io_util! {
23 use bytes::BufMut;
24 }
25
26 /// Owned read half of a [`TcpStream`], created by [`into_split`].
27 ///
28 /// Reading from an `OwnedReadHalf` is usually done using the convenience methods found
29 /// on the [`AsyncReadExt`] trait.
30 ///
31 /// [`TcpStream`]: TcpStream
32 /// [`into_split`]: TcpStream::into_split()
33 /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt
34 #[derive(Debug)]
35 pub struct OwnedReadHalf {
36 inner: Arc<TcpStream>,
37 }
38
39 /// Owned write half of a [`TcpStream`], created by [`into_split`].
40 ///
41 /// Note that in the [`AsyncWrite`] implementation of this type, [`poll_shutdown`] will
42 /// shut down the TCP stream in the write direction. Dropping the write half
43 /// will also shut down the write half of the TCP stream.
44 ///
45 /// Writing to an `OwnedWriteHalf` is usually done using the convenience methods found
46 /// on the [`AsyncWriteExt`] trait.
47 ///
48 /// [`TcpStream`]: TcpStream
49 /// [`into_split`]: TcpStream::into_split()
50 /// [`AsyncWrite`]: trait@crate::io::AsyncWrite
51 /// [`poll_shutdown`]: fn@crate::io::AsyncWrite::poll_shutdown
52 /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt
53 #[derive(Debug)]
54 pub struct OwnedWriteHalf {
55 inner: Arc<TcpStream>,
56 shutdown_on_drop: bool,
57 }
58
split_owned(stream: TcpStream) -> (OwnedReadHalf, OwnedWriteHalf)59 pub(crate) fn split_owned(stream: TcpStream) -> (OwnedReadHalf, OwnedWriteHalf) {
60 let arc = Arc::new(stream);
61 let read = OwnedReadHalf {
62 inner: Arc::clone(&arc),
63 };
64 let write = OwnedWriteHalf {
65 inner: arc,
66 shutdown_on_drop: true,
67 };
68 (read, write)
69 }
70
reunite( read: OwnedReadHalf, write: OwnedWriteHalf, ) -> Result<TcpStream, ReuniteError>71 pub(crate) fn reunite(
72 read: OwnedReadHalf,
73 write: OwnedWriteHalf,
74 ) -> Result<TcpStream, ReuniteError> {
75 if Arc::ptr_eq(&read.inner, &write.inner) {
76 write.forget();
77 // This unwrap cannot fail as the api does not allow creating more than two Arcs,
78 // and we just dropped the other half.
79 Ok(Arc::try_unwrap(read.inner).expect("TcpStream: try_unwrap failed in reunite"))
80 } else {
81 Err(ReuniteError(read, write))
82 }
83 }
84
85 /// Error indicating that two halves were not from the same socket, and thus could
86 /// not be reunited.
87 #[derive(Debug)]
88 pub struct ReuniteError(pub OwnedReadHalf, pub OwnedWriteHalf);
89
90 impl fmt::Display for ReuniteError {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result91 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
92 write!(
93 f,
94 "tried to reunite halves that are not from the same socket"
95 )
96 }
97 }
98
99 impl Error for ReuniteError {}
100
101 impl OwnedReadHalf {
102 /// Attempts to put the two halves of a `TcpStream` back together and
103 /// recover the original socket. Succeeds only if the two halves
104 /// originated from the same call to [`into_split`].
105 ///
106 /// [`into_split`]: TcpStream::into_split()
reunite(self, other: OwnedWriteHalf) -> Result<TcpStream, ReuniteError>107 pub fn reunite(self, other: OwnedWriteHalf) -> Result<TcpStream, ReuniteError> {
108 reunite(self, other)
109 }
110
111 /// Attempt to receive data on the socket, without removing that data from
112 /// the queue, registering the current task for wakeup if data is not yet
113 /// available.
114 ///
115 /// Note that on multiple calls to `poll_peek` or `poll_read`, only the
116 /// `Waker` from the `Context` passed to the most recent call is scheduled
117 /// to receive a wakeup.
118 ///
119 /// See the [`TcpStream::poll_peek`] level documentation for more details.
120 ///
121 /// # Examples
122 ///
123 /// ```no_run
124 /// use tokio::io::{self, ReadBuf};
125 /// use tokio::net::TcpStream;
126 ///
127 /// use futures::future::poll_fn;
128 ///
129 /// #[tokio::main]
130 /// async fn main() -> io::Result<()> {
131 /// let stream = TcpStream::connect("127.0.0.1:8000").await?;
132 /// let (mut read_half, _) = stream.into_split();
133 /// let mut buf = [0; 10];
134 /// let mut buf = ReadBuf::new(&mut buf);
135 ///
136 /// poll_fn(|cx| {
137 /// read_half.poll_peek(cx, &mut buf)
138 /// }).await?;
139 ///
140 /// Ok(())
141 /// }
142 /// ```
143 ///
144 /// [`TcpStream::poll_peek`]: TcpStream::poll_peek
poll_peek( &mut self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<usize>>145 pub fn poll_peek(
146 &mut self,
147 cx: &mut Context<'_>,
148 buf: &mut ReadBuf<'_>,
149 ) -> Poll<io::Result<usize>> {
150 self.inner.poll_peek(cx, buf)
151 }
152
153 /// Receives data on the socket from the remote address to which it is
154 /// connected, without removing that data from the queue. On success,
155 /// returns the number of bytes peeked.
156 ///
157 /// See the [`TcpStream::peek`] level documentation for more details.
158 ///
159 /// [`TcpStream::peek`]: TcpStream::peek
160 ///
161 /// # Examples
162 ///
163 /// ```no_run
164 /// use tokio::net::TcpStream;
165 /// use tokio::io::AsyncReadExt;
166 /// use std::error::Error;
167 ///
168 /// #[tokio::main]
169 /// async fn main() -> Result<(), Box<dyn Error>> {
170 /// // Connect to a peer
171 /// let stream = TcpStream::connect("127.0.0.1:8080").await?;
172 /// let (mut read_half, _) = stream.into_split();
173 ///
174 /// let mut b1 = [0; 10];
175 /// let mut b2 = [0; 10];
176 ///
177 /// // Peek at the data
178 /// let n = read_half.peek(&mut b1).await?;
179 ///
180 /// // Read the data
181 /// assert_eq!(n, read_half.read(&mut b2[..n]).await?);
182 /// assert_eq!(&b1[..n], &b2[..n]);
183 ///
184 /// Ok(())
185 /// }
186 /// ```
187 ///
188 /// The [`read`] method is defined on the [`AsyncReadExt`] trait.
189 ///
190 /// [`read`]: fn@crate::io::AsyncReadExt::read
191 /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt
peek(&mut self, buf: &mut [u8]) -> io::Result<usize>192 pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> {
193 let mut buf = ReadBuf::new(buf);
194 poll_fn(|cx| self.poll_peek(cx, &mut buf)).await
195 }
196
197 /// Waits for any of the requested ready states.
198 ///
199 /// This function is usually paired with `try_read()` or `try_write()`. It
200 /// can be used to concurrently read / write to the same socket on a single
201 /// task without splitting the socket.
202 ///
203 /// This function is equivalent to [`TcpStream::ready`].
204 ///
205 /// # Cancel safety
206 ///
207 /// This method is cancel safe. Once a readiness event occurs, the method
208 /// will continue to return immediately until the readiness event is
209 /// consumed by an attempt to read or write that fails with `WouldBlock` or
210 /// `Poll::Pending`.
ready(&self, interest: Interest) -> io::Result<Ready>211 pub async fn ready(&self, interest: Interest) -> io::Result<Ready> {
212 self.inner.ready(interest).await
213 }
214
215 /// Waits for the socket to become readable.
216 ///
217 /// This function is equivalent to `ready(Interest::READABLE)` and is usually
218 /// paired with `try_read()`.
219 ///
220 /// This function is also equivalent to [`TcpStream::ready`].
221 ///
222 /// # Cancel safety
223 ///
224 /// This method is cancel safe. Once a readiness event occurs, the method
225 /// will continue to return immediately until the readiness event is
226 /// consumed by an attempt to read that fails with `WouldBlock` or
227 /// `Poll::Pending`.
readable(&self) -> io::Result<()>228 pub async fn readable(&self) -> io::Result<()> {
229 self.inner.readable().await
230 }
231
232 /// Tries to read data from the stream into the provided buffer, returning how
233 /// many bytes were read.
234 ///
235 /// Receives any pending data from the socket but does not wait for new data
236 /// to arrive. On success, returns the number of bytes read. Because
237 /// `try_read()` is non-blocking, the buffer does not have to be stored by
238 /// the async task and can exist entirely on the stack.
239 ///
240 /// Usually, [`readable()`] or [`ready()`] is used with this function.
241 ///
242 /// [`readable()`]: Self::readable()
243 /// [`ready()`]: Self::ready()
244 ///
245 /// # Return
246 ///
247 /// If data is successfully read, `Ok(n)` is returned, where `n` is the
248 /// number of bytes read. `Ok(0)` indicates the stream's read half is closed
249 /// and will no longer yield data. If the stream is not ready to read data
250 /// `Err(io::ErrorKind::WouldBlock)` is returned.
try_read(&self, buf: &mut [u8]) -> io::Result<usize>251 pub fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
252 self.inner.try_read(buf)
253 }
254
255 /// Tries to read data from the stream into the provided buffers, returning
256 /// how many bytes were read.
257 ///
258 /// Data is copied to fill each buffer in order, with the final buffer
259 /// written to possibly being only partially filled. This method behaves
260 /// equivalently to a single call to [`try_read()`] with concatenated
261 /// buffers.
262 ///
263 /// Receives any pending data from the socket but does not wait for new data
264 /// to arrive. On success, returns the number of bytes read. Because
265 /// `try_read_vectored()` is non-blocking, the buffer does not have to be
266 /// stored by the async task and can exist entirely on the stack.
267 ///
268 /// Usually, [`readable()`] or [`ready()`] is used with this function.
269 ///
270 /// [`try_read()`]: Self::try_read()
271 /// [`readable()`]: Self::readable()
272 /// [`ready()`]: Self::ready()
273 ///
274 /// # Return
275 ///
276 /// If data is successfully read, `Ok(n)` is returned, where `n` is the
277 /// number of bytes read. `Ok(0)` indicates the stream's read half is closed
278 /// and will no longer yield data. If the stream is not ready to read data
279 /// `Err(io::ErrorKind::WouldBlock)` is returned.
try_read_vectored(&self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize>280 pub fn try_read_vectored(&self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> {
281 self.inner.try_read_vectored(bufs)
282 }
283
284 cfg_io_util! {
285 /// Tries to read data from the stream into the provided buffer, advancing the
286 /// buffer's internal cursor, returning how many bytes were read.
287 ///
288 /// Receives any pending data from the socket but does not wait for new data
289 /// to arrive. On success, returns the number of bytes read. Because
290 /// `try_read_buf()` is non-blocking, the buffer does not have to be stored by
291 /// the async task and can exist entirely on the stack.
292 ///
293 /// Usually, [`readable()`] or [`ready()`] is used with this function.
294 ///
295 /// [`readable()`]: Self::readable()
296 /// [`ready()`]: Self::ready()
297 ///
298 /// # Return
299 ///
300 /// If data is successfully read, `Ok(n)` is returned, where `n` is the
301 /// number of bytes read. `Ok(0)` indicates the stream's read half is closed
302 /// and will no longer yield data. If the stream is not ready to read data
303 /// `Err(io::ErrorKind::WouldBlock)` is returned.
304 pub fn try_read_buf<B: BufMut>(&self, buf: &mut B) -> io::Result<usize> {
305 self.inner.try_read_buf(buf)
306 }
307 }
308
309 /// Returns the remote address that this stream is connected to.
peer_addr(&self) -> io::Result<SocketAddr>310 pub fn peer_addr(&self) -> io::Result<SocketAddr> {
311 self.inner.peer_addr()
312 }
313
314 /// Returns the local address that this stream is bound to.
local_addr(&self) -> io::Result<SocketAddr>315 pub fn local_addr(&self) -> io::Result<SocketAddr> {
316 self.inner.local_addr()
317 }
318 }
319
320 impl AsyncRead for OwnedReadHalf {
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>321 fn poll_read(
322 self: Pin<&mut Self>,
323 cx: &mut Context<'_>,
324 buf: &mut ReadBuf<'_>,
325 ) -> Poll<io::Result<()>> {
326 self.inner.poll_read_priv(cx, buf)
327 }
328 }
329
330 impl OwnedWriteHalf {
331 /// Attempts to put the two halves of a `TcpStream` back together and
332 /// recover the original socket. Succeeds only if the two halves
333 /// originated from the same call to [`into_split`].
334 ///
335 /// [`into_split`]: TcpStream::into_split()
reunite(self, other: OwnedReadHalf) -> Result<TcpStream, ReuniteError>336 pub fn reunite(self, other: OwnedReadHalf) -> Result<TcpStream, ReuniteError> {
337 reunite(other, self)
338 }
339
340 /// Destroys the write half, but don't close the write half of the stream
341 /// until the read half is dropped. If the read half has already been
342 /// dropped, this closes the stream.
forget(mut self)343 pub fn forget(mut self) {
344 self.shutdown_on_drop = false;
345 drop(self);
346 }
347
348 /// Waits for any of the requested ready states.
349 ///
350 /// This function is usually paired with `try_read()` or `try_write()`. It
351 /// can be used to concurrently read / write to the same socket on a single
352 /// task without splitting the socket.
353 ///
354 /// This function is equivalent to [`TcpStream::ready`].
355 ///
356 /// # Cancel safety
357 ///
358 /// This method is cancel safe. Once a readiness event occurs, the method
359 /// will continue to return immediately until the readiness event is
360 /// consumed by an attempt to read or write that fails with `WouldBlock` or
361 /// `Poll::Pending`.
ready(&self, interest: Interest) -> io::Result<Ready>362 pub async fn ready(&self, interest: Interest) -> io::Result<Ready> {
363 self.inner.ready(interest).await
364 }
365
366 /// Waits for the socket to become writable.
367 ///
368 /// This function is equivalent to `ready(Interest::WRITABLE)` and is usually
369 /// paired with `try_write()`.
370 ///
371 /// # Cancel safety
372 ///
373 /// This method is cancel safe. Once a readiness event occurs, the method
374 /// will continue to return immediately until the readiness event is
375 /// consumed by an attempt to write that fails with `WouldBlock` or
376 /// `Poll::Pending`.
writable(&self) -> io::Result<()>377 pub async fn writable(&self) -> io::Result<()> {
378 self.inner.writable().await
379 }
380
381 /// Tries to write a buffer to the stream, returning how many bytes were
382 /// written.
383 ///
384 /// The function will attempt to write the entire contents of `buf`, but
385 /// only part of the buffer may be written.
386 ///
387 /// This function is usually paired with `writable()`.
388 ///
389 /// # Return
390 ///
391 /// If data is successfully written, `Ok(n)` is returned, where `n` is the
392 /// number of bytes written. If the stream is not ready to write data,
393 /// `Err(io::ErrorKind::WouldBlock)` is returned.
try_write(&self, buf: &[u8]) -> io::Result<usize>394 pub fn try_write(&self, buf: &[u8]) -> io::Result<usize> {
395 self.inner.try_write(buf)
396 }
397
398 /// Tries to write several buffers to the stream, returning how many bytes
399 /// were written.
400 ///
401 /// Data is written from each buffer in order, with the final buffer read
402 /// from possible being only partially consumed. This method behaves
403 /// equivalently to a single call to [`try_write()`] with concatenated
404 /// buffers.
405 ///
406 /// This function is usually paired with `writable()`.
407 ///
408 /// [`try_write()`]: Self::try_write()
409 ///
410 /// # Return
411 ///
412 /// If data is successfully written, `Ok(n)` is returned, where `n` is the
413 /// number of bytes written. If the stream is not ready to write data,
414 /// `Err(io::ErrorKind::WouldBlock)` is returned.
try_write_vectored(&self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize>415 pub fn try_write_vectored(&self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
416 self.inner.try_write_vectored(bufs)
417 }
418
419 /// Returns the remote address that this stream is connected to.
peer_addr(&self) -> io::Result<SocketAddr>420 pub fn peer_addr(&self) -> io::Result<SocketAddr> {
421 self.inner.peer_addr()
422 }
423
424 /// Returns the local address that this stream is bound to.
local_addr(&self) -> io::Result<SocketAddr>425 pub fn local_addr(&self) -> io::Result<SocketAddr> {
426 self.inner.local_addr()
427 }
428 }
429
430 impl Drop for OwnedWriteHalf {
drop(&mut self)431 fn drop(&mut self) {
432 if self.shutdown_on_drop {
433 let _ = self.inner.shutdown_std(Shutdown::Write);
434 }
435 }
436 }
437
438 impl AsyncWrite for OwnedWriteHalf {
poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>439 fn poll_write(
440 self: Pin<&mut Self>,
441 cx: &mut Context<'_>,
442 buf: &[u8],
443 ) -> Poll<io::Result<usize>> {
444 self.inner.poll_write_priv(cx, buf)
445 }
446
poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>], ) -> Poll<io::Result<usize>>447 fn poll_write_vectored(
448 self: Pin<&mut Self>,
449 cx: &mut Context<'_>,
450 bufs: &[io::IoSlice<'_>],
451 ) -> Poll<io::Result<usize>> {
452 self.inner.poll_write_vectored_priv(cx, bufs)
453 }
454
is_write_vectored(&self) -> bool455 fn is_write_vectored(&self) -> bool {
456 self.inner.is_write_vectored()
457 }
458
459 #[inline]
poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>>460 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
461 // tcp flush is a no-op
462 Poll::Ready(Ok(()))
463 }
464
465 // `poll_shutdown` on a write half shutdowns the stream in the "write" direction.
poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>>466 fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
467 let res = self.inner.shutdown_std(Shutdown::Write);
468 if res.is_ok() {
469 Pin::into_inner(self).shutdown_on_drop = false;
470 }
471 res.into()
472 }
473 }
474
475 impl AsRef<TcpStream> for OwnedReadHalf {
as_ref(&self) -> &TcpStream476 fn as_ref(&self) -> &TcpStream {
477 &*self.inner
478 }
479 }
480
481 impl AsRef<TcpStream> for OwnedWriteHalf {
as_ref(&self) -> &TcpStream482 fn as_ref(&self) -> &TcpStream {
483 &*self.inner
484 }
485 }
486