• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //! `TcpStream` owned split support.
2 //!
//! A `TcpStream` can be split into an `OwnedReadHalf` and an `OwnedWriteHalf`
//! with the `TcpStream::into_split` method. `OwnedReadHalf` implements
//! `AsyncRead` while `OwnedWriteHalf` implements `AsyncWrite`.
6 //!
7 //! Compared to the generic split of `AsyncRead + AsyncWrite`, this specialized
8 //! split has no associated overhead and enforces all invariants at the type
9 //! level.
10 
11 use crate::io::{AsyncRead, AsyncWrite, Interest, ReadBuf, Ready};
12 use crate::net::TcpStream;
13 
14 use std::error::Error;
15 use std::future::poll_fn;
16 use std::net::{Shutdown, SocketAddr};
17 use std::pin::Pin;
18 use std::sync::Arc;
19 use std::task::{Context, Poll};
20 use std::{fmt, io};
21 
22 cfg_io_util! {
23     use bytes::BufMut;
24 }
25 
26 /// Owned read half of a [`TcpStream`], created by [`into_split`].
27 ///
28 /// Reading from an `OwnedReadHalf` is usually done using the convenience methods found
29 /// on the [`AsyncReadExt`] trait.
30 ///
31 /// [`TcpStream`]: TcpStream
32 /// [`into_split`]: TcpStream::into_split()
33 /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt
34 #[derive(Debug)]
35 pub struct OwnedReadHalf {
36     inner: Arc<TcpStream>,
37 }
38 
39 /// Owned write half of a [`TcpStream`], created by [`into_split`].
40 ///
41 /// Note that in the [`AsyncWrite`] implementation of this type, [`poll_shutdown`] will
42 /// shut down the TCP stream in the write direction.  Dropping the write half
43 /// will also shut down the write half of the TCP stream.
44 ///
45 /// Writing to an `OwnedWriteHalf` is usually done using the convenience methods found
46 /// on the [`AsyncWriteExt`] trait.
47 ///
48 /// [`TcpStream`]: TcpStream
49 /// [`into_split`]: TcpStream::into_split()
50 /// [`AsyncWrite`]: trait@crate::io::AsyncWrite
51 /// [`poll_shutdown`]: fn@crate::io::AsyncWrite::poll_shutdown
52 /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt
53 #[derive(Debug)]
54 pub struct OwnedWriteHalf {
55     inner: Arc<TcpStream>,
56     shutdown_on_drop: bool,
57 }
58 
split_owned(stream: TcpStream) -> (OwnedReadHalf, OwnedWriteHalf)59 pub(crate) fn split_owned(stream: TcpStream) -> (OwnedReadHalf, OwnedWriteHalf) {
60     let arc = Arc::new(stream);
61     let read = OwnedReadHalf {
62         inner: Arc::clone(&arc),
63     };
64     let write = OwnedWriteHalf {
65         inner: arc,
66         shutdown_on_drop: true,
67     };
68     (read, write)
69 }
70 
reunite( read: OwnedReadHalf, write: OwnedWriteHalf, ) -> Result<TcpStream, ReuniteError>71 pub(crate) fn reunite(
72     read: OwnedReadHalf,
73     write: OwnedWriteHalf,
74 ) -> Result<TcpStream, ReuniteError> {
75     if Arc::ptr_eq(&read.inner, &write.inner) {
76         write.forget();
77         // This unwrap cannot fail as the api does not allow creating more than two Arcs,
78         // and we just dropped the other half.
79         Ok(Arc::try_unwrap(read.inner).expect("TcpStream: try_unwrap failed in reunite"))
80     } else {
81         Err(ReuniteError(read, write))
82     }
83 }
84 
85 /// Error indicating that two halves were not from the same socket, and thus could
86 /// not be reunited.
87 #[derive(Debug)]
88 pub struct ReuniteError(pub OwnedReadHalf, pub OwnedWriteHalf);
89 
90 impl fmt::Display for ReuniteError {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result91     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
92         write!(
93             f,
94             "tried to reunite halves that are not from the same socket"
95         )
96     }
97 }
98 
99 impl Error for ReuniteError {}
100 
101 impl OwnedReadHalf {
102     /// Attempts to put the two halves of a `TcpStream` back together and
103     /// recover the original socket. Succeeds only if the two halves
104     /// originated from the same call to [`into_split`].
105     ///
106     /// [`into_split`]: TcpStream::into_split()
reunite(self, other: OwnedWriteHalf) -> Result<TcpStream, ReuniteError>107     pub fn reunite(self, other: OwnedWriteHalf) -> Result<TcpStream, ReuniteError> {
108         reunite(self, other)
109     }
110 
111     /// Attempt to receive data on the socket, without removing that data from
112     /// the queue, registering the current task for wakeup if data is not yet
113     /// available.
114     ///
115     /// Note that on multiple calls to `poll_peek` or `poll_read`, only the
116     /// `Waker` from the `Context` passed to the most recent call is scheduled
117     /// to receive a wakeup.
118     ///
119     /// See the [`TcpStream::poll_peek`] level documentation for more details.
120     ///
121     /// # Examples
122     ///
123     /// ```no_run
124     /// use tokio::io::{self, ReadBuf};
125     /// use tokio::net::TcpStream;
126     ///
127     /// use std::future::poll_fn;
128     ///
129     /// #[tokio::main]
130     /// async fn main() -> io::Result<()> {
131     ///     let stream = TcpStream::connect("127.0.0.1:8000").await?;
132     ///     let (mut read_half, _) = stream.into_split();
133     ///     let mut buf = [0; 10];
134     ///     let mut buf = ReadBuf::new(&mut buf);
135     ///
136     ///     poll_fn(|cx| {
137     ///         read_half.poll_peek(cx, &mut buf)
138     ///     }).await?;
139     ///
140     ///     Ok(())
141     /// }
142     /// ```
143     ///
144     /// [`TcpStream::poll_peek`]: TcpStream::poll_peek
poll_peek( &mut self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<usize>>145     pub fn poll_peek(
146         &mut self,
147         cx: &mut Context<'_>,
148         buf: &mut ReadBuf<'_>,
149     ) -> Poll<io::Result<usize>> {
150         self.inner.poll_peek(cx, buf)
151     }
152 
153     /// Receives data on the socket from the remote address to which it is
154     /// connected, without removing that data from the queue. On success,
155     /// returns the number of bytes peeked.
156     ///
157     /// See the [`TcpStream::peek`] level documentation for more details.
158     ///
159     /// [`TcpStream::peek`]: TcpStream::peek
160     ///
161     /// # Examples
162     ///
163     /// ```no_run
164     /// use tokio::net::TcpStream;
165     /// use tokio::io::AsyncReadExt;
166     /// use std::error::Error;
167     ///
168     /// #[tokio::main]
169     /// async fn main() -> Result<(), Box<dyn Error>> {
170     ///     // Connect to a peer
171     ///     let stream = TcpStream::connect("127.0.0.1:8080").await?;
172     ///     let (mut read_half, _) = stream.into_split();
173     ///
174     ///     let mut b1 = [0; 10];
175     ///     let mut b2 = [0; 10];
176     ///
177     ///     // Peek at the data
178     ///     let n = read_half.peek(&mut b1).await?;
179     ///
180     ///     // Read the data
181     ///     assert_eq!(n, read_half.read(&mut b2[..n]).await?);
182     ///     assert_eq!(&b1[..n], &b2[..n]);
183     ///
184     ///     Ok(())
185     /// }
186     /// ```
187     ///
188     /// The [`read`] method is defined on the [`AsyncReadExt`] trait.
189     ///
190     /// [`read`]: fn@crate::io::AsyncReadExt::read
191     /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt
peek(&mut self, buf: &mut [u8]) -> io::Result<usize>192     pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> {
193         let mut buf = ReadBuf::new(buf);
194         poll_fn(|cx| self.poll_peek(cx, &mut buf)).await
195     }
196 
197     /// Waits for any of the requested ready states.
198     ///
199     /// This function is usually paired with [`try_read()`]. It can be used instead
200     /// of [`readable()`] to check the returned ready set for [`Ready::READABLE`]
201     /// and [`Ready::READ_CLOSED`] events.
202     ///
203     /// The function may complete without the socket being ready. This is a
204     /// false-positive and attempting an operation will return with
205     /// `io::ErrorKind::WouldBlock`. The function can also return with an empty
206     /// [`Ready`] set, so you should always check the returned value and possibly
207     /// wait again if the requested states are not set.
208     ///
209     /// This function is equivalent to [`TcpStream::ready`].
210     ///
211     /// [`try_read()`]: Self::try_read
212     /// [`readable()`]: Self::readable
213     ///
214     /// # Cancel safety
215     ///
216     /// This method is cancel safe. Once a readiness event occurs, the method
217     /// will continue to return immediately until the readiness event is
218     /// consumed by an attempt to read or write that fails with `WouldBlock` or
219     /// `Poll::Pending`.
ready(&self, interest: Interest) -> io::Result<Ready>220     pub async fn ready(&self, interest: Interest) -> io::Result<Ready> {
221         self.inner.ready(interest).await
222     }
223 
224     /// Waits for the socket to become readable.
225     ///
226     /// This function is equivalent to `ready(Interest::READABLE)` and is usually
227     /// paired with `try_read()`.
228     ///
229     /// This function is also equivalent to [`TcpStream::ready`].
230     ///
231     /// # Cancel safety
232     ///
233     /// This method is cancel safe. Once a readiness event occurs, the method
234     /// will continue to return immediately until the readiness event is
235     /// consumed by an attempt to read that fails with `WouldBlock` or
236     /// `Poll::Pending`.
readable(&self) -> io::Result<()>237     pub async fn readable(&self) -> io::Result<()> {
238         self.inner.readable().await
239     }
240 
241     /// Tries to read data from the stream into the provided buffer, returning how
242     /// many bytes were read.
243     ///
244     /// Receives any pending data from the socket but does not wait for new data
245     /// to arrive. On success, returns the number of bytes read. Because
246     /// `try_read()` is non-blocking, the buffer does not have to be stored by
247     /// the async task and can exist entirely on the stack.
248     ///
249     /// Usually, [`readable()`] or [`ready()`] is used with this function.
250     ///
251     /// [`readable()`]: Self::readable()
252     /// [`ready()`]: Self::ready()
253     ///
254     /// # Return
255     ///
256     /// If data is successfully read, `Ok(n)` is returned, where `n` is the
257     /// number of bytes read. If `n` is `0`, then it can indicate one of two scenarios:
258     ///
259     /// 1. The stream's read half is closed and will no longer yield data.
260     /// 2. The specified buffer was 0 bytes in length.
261     ///
262     /// If the stream is not ready to read data,
263     /// `Err(io::ErrorKind::WouldBlock)` is returned.
try_read(&self, buf: &mut [u8]) -> io::Result<usize>264     pub fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
265         self.inner.try_read(buf)
266     }
267 
268     /// Tries to read data from the stream into the provided buffers, returning
269     /// how many bytes were read.
270     ///
271     /// Data is copied to fill each buffer in order, with the final buffer
272     /// written to possibly being only partially filled. This method behaves
273     /// equivalently to a single call to [`try_read()`] with concatenated
274     /// buffers.
275     ///
276     /// Receives any pending data from the socket but does not wait for new data
277     /// to arrive. On success, returns the number of bytes read. Because
278     /// `try_read_vectored()` is non-blocking, the buffer does not have to be
279     /// stored by the async task and can exist entirely on the stack.
280     ///
281     /// Usually, [`readable()`] or [`ready()`] is used with this function.
282     ///
283     /// [`try_read()`]: Self::try_read()
284     /// [`readable()`]: Self::readable()
285     /// [`ready()`]: Self::ready()
286     ///
287     /// # Return
288     ///
289     /// If data is successfully read, `Ok(n)` is returned, where `n` is the
290     /// number of bytes read. `Ok(0)` indicates the stream's read half is closed
291     /// and will no longer yield data. If the stream is not ready to read data
292     /// `Err(io::ErrorKind::WouldBlock)` is returned.
try_read_vectored(&self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize>293     pub fn try_read_vectored(&self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> {
294         self.inner.try_read_vectored(bufs)
295     }
296 
297     cfg_io_util! {
298         /// Tries to read data from the stream into the provided buffer, advancing the
299         /// buffer's internal cursor, returning how many bytes were read.
300         ///
301         /// Receives any pending data from the socket but does not wait for new data
302         /// to arrive. On success, returns the number of bytes read. Because
303         /// `try_read_buf()` is non-blocking, the buffer does not have to be stored by
304         /// the async task and can exist entirely on the stack.
305         ///
306         /// Usually, [`readable()`] or [`ready()`] is used with this function.
307         ///
308         /// [`readable()`]: Self::readable()
309         /// [`ready()`]: Self::ready()
310         ///
311         /// # Return
312         ///
313         /// If data is successfully read, `Ok(n)` is returned, where `n` is the
314         /// number of bytes read. `Ok(0)` indicates the stream's read half is closed
315         /// and will no longer yield data. If the stream is not ready to read data
316         /// `Err(io::ErrorKind::WouldBlock)` is returned.
317         pub fn try_read_buf<B: BufMut>(&self, buf: &mut B) -> io::Result<usize> {
318             self.inner.try_read_buf(buf)
319         }
320     }
321 
322     /// Returns the remote address that this stream is connected to.
peer_addr(&self) -> io::Result<SocketAddr>323     pub fn peer_addr(&self) -> io::Result<SocketAddr> {
324         self.inner.peer_addr()
325     }
326 
327     /// Returns the local address that this stream is bound to.
local_addr(&self) -> io::Result<SocketAddr>328     pub fn local_addr(&self) -> io::Result<SocketAddr> {
329         self.inner.local_addr()
330     }
331 }
332 
333 impl AsyncRead for OwnedReadHalf {
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>334     fn poll_read(
335         self: Pin<&mut Self>,
336         cx: &mut Context<'_>,
337         buf: &mut ReadBuf<'_>,
338     ) -> Poll<io::Result<()>> {
339         self.inner.poll_read_priv(cx, buf)
340     }
341 }
342 
343 impl OwnedWriteHalf {
344     /// Attempts to put the two halves of a `TcpStream` back together and
345     /// recover the original socket. Succeeds only if the two halves
346     /// originated from the same call to [`into_split`].
347     ///
348     /// [`into_split`]: TcpStream::into_split()
reunite(self, other: OwnedReadHalf) -> Result<TcpStream, ReuniteError>349     pub fn reunite(self, other: OwnedReadHalf) -> Result<TcpStream, ReuniteError> {
350         reunite(other, self)
351     }
352 
353     /// Destroys the write half, but don't close the write half of the stream
354     /// until the read half is dropped. If the read half has already been
355     /// dropped, this closes the stream.
forget(mut self)356     pub fn forget(mut self) {
357         self.shutdown_on_drop = false;
358         drop(self);
359     }
360 
361     /// Waits for any of the requested ready states.
362     ///
363     /// This function is usually paired with [`try_write()`]. It can be used instead
364     /// of [`writable()`] to check the returned ready set for [`Ready::WRITABLE`]
365     /// and [`Ready::WRITE_CLOSED`] events.
366     ///
367     /// The function may complete without the socket being ready. This is a
368     /// false-positive and attempting an operation will return with
369     /// `io::ErrorKind::WouldBlock`. The function can also return with an empty
370     /// [`Ready`] set, so you should always check the returned value and possibly
371     /// wait again if the requested states are not set.
372     ///
373     /// This function is equivalent to [`TcpStream::ready`].
374     ///
375     /// [`try_write()`]: Self::try_write
376     /// [`writable()`]: Self::writable
377     ///
378     /// # Cancel safety
379     ///
380     /// This method is cancel safe. Once a readiness event occurs, the method
381     /// will continue to return immediately until the readiness event is
382     /// consumed by an attempt to read or write that fails with `WouldBlock` or
383     /// `Poll::Pending`.
ready(&self, interest: Interest) -> io::Result<Ready>384     pub async fn ready(&self, interest: Interest) -> io::Result<Ready> {
385         self.inner.ready(interest).await
386     }
387 
388     /// Waits for the socket to become writable.
389     ///
390     /// This function is equivalent to `ready(Interest::WRITABLE)` and is usually
391     /// paired with `try_write()`.
392     ///
393     /// # Cancel safety
394     ///
395     /// This method is cancel safe. Once a readiness event occurs, the method
396     /// will continue to return immediately until the readiness event is
397     /// consumed by an attempt to write that fails with `WouldBlock` or
398     /// `Poll::Pending`.
writable(&self) -> io::Result<()>399     pub async fn writable(&self) -> io::Result<()> {
400         self.inner.writable().await
401     }
402 
403     /// Tries to write a buffer to the stream, returning how many bytes were
404     /// written.
405     ///
406     /// The function will attempt to write the entire contents of `buf`, but
407     /// only part of the buffer may be written.
408     ///
409     /// This function is usually paired with `writable()`.
410     ///
411     /// # Return
412     ///
413     /// If data is successfully written, `Ok(n)` is returned, where `n` is the
414     /// number of bytes written. If the stream is not ready to write data,
415     /// `Err(io::ErrorKind::WouldBlock)` is returned.
try_write(&self, buf: &[u8]) -> io::Result<usize>416     pub fn try_write(&self, buf: &[u8]) -> io::Result<usize> {
417         self.inner.try_write(buf)
418     }
419 
420     /// Tries to write several buffers to the stream, returning how many bytes
421     /// were written.
422     ///
423     /// Data is written from each buffer in order, with the final buffer read
424     /// from possible being only partially consumed. This method behaves
425     /// equivalently to a single call to [`try_write()`] with concatenated
426     /// buffers.
427     ///
428     /// This function is usually paired with `writable()`.
429     ///
430     /// [`try_write()`]: Self::try_write()
431     ///
432     /// # Return
433     ///
434     /// If data is successfully written, `Ok(n)` is returned, where `n` is the
435     /// number of bytes written. If the stream is not ready to write data,
436     /// `Err(io::ErrorKind::WouldBlock)` is returned.
try_write_vectored(&self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize>437     pub fn try_write_vectored(&self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
438         self.inner.try_write_vectored(bufs)
439     }
440 
441     /// Returns the remote address that this stream is connected to.
peer_addr(&self) -> io::Result<SocketAddr>442     pub fn peer_addr(&self) -> io::Result<SocketAddr> {
443         self.inner.peer_addr()
444     }
445 
446     /// Returns the local address that this stream is bound to.
local_addr(&self) -> io::Result<SocketAddr>447     pub fn local_addr(&self) -> io::Result<SocketAddr> {
448         self.inner.local_addr()
449     }
450 }
451 
452 impl Drop for OwnedWriteHalf {
drop(&mut self)453     fn drop(&mut self) {
454         if self.shutdown_on_drop {
455             let _ = self.inner.shutdown_std(Shutdown::Write);
456         }
457     }
458 }
459 
460 impl AsyncWrite for OwnedWriteHalf {
poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>461     fn poll_write(
462         self: Pin<&mut Self>,
463         cx: &mut Context<'_>,
464         buf: &[u8],
465     ) -> Poll<io::Result<usize>> {
466         self.inner.poll_write_priv(cx, buf)
467     }
468 
poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>], ) -> Poll<io::Result<usize>>469     fn poll_write_vectored(
470         self: Pin<&mut Self>,
471         cx: &mut Context<'_>,
472         bufs: &[io::IoSlice<'_>],
473     ) -> Poll<io::Result<usize>> {
474         self.inner.poll_write_vectored_priv(cx, bufs)
475     }
476 
is_write_vectored(&self) -> bool477     fn is_write_vectored(&self) -> bool {
478         self.inner.is_write_vectored()
479     }
480 
481     #[inline]
poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>>482     fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
483         // tcp flush is a no-op
484         Poll::Ready(Ok(()))
485     }
486 
487     // `poll_shutdown` on a write half shutdowns the stream in the "write" direction.
poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>>488     fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
489         let res = self.inner.shutdown_std(Shutdown::Write);
490         if res.is_ok() {
491             Pin::into_inner(self).shutdown_on_drop = false;
492         }
493         res.into()
494     }
495 }
496 
497 impl AsRef<TcpStream> for OwnedReadHalf {
as_ref(&self) -> &TcpStream498     fn as_ref(&self) -> &TcpStream {
499         &self.inner
500     }
501 }
502 
503 impl AsRef<TcpStream> for OwnedWriteHalf {
as_ref(&self) -> &TcpStream504     fn as_ref(&self) -> &TcpStream {
505         &self.inner
506     }
507 }
508