1 //! Split a single value implementing `AsyncRead + AsyncWrite` into separate 2 //! `AsyncRead` and `AsyncWrite` handles. 3 //! 4 //! To restore this read/write object from its `split::ReadHalf` and 5 //! `split::WriteHalf` use `unsplit`. 6 7 use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; 8 9 use std::cell::UnsafeCell; 10 use std::fmt; 11 use std::io; 12 use std::pin::Pin; 13 use std::sync::atomic::AtomicBool; 14 use std::sync::atomic::Ordering::{Acquire, Release}; 15 use std::sync::Arc; 16 use std::task::{Context, Poll}; 17 18 cfg_io_util! { 19 /// The readable half of a value returned from [`split`](split()). 20 pub struct ReadHalf<T> { 21 inner: Arc<Inner<T>>, 22 } 23 24 /// The writable half of a value returned from [`split`](split()). 25 pub struct WriteHalf<T> { 26 inner: Arc<Inner<T>>, 27 } 28 29 /// Splits a single value implementing `AsyncRead + AsyncWrite` into separate 30 /// `AsyncRead` and `AsyncWrite` handles. 31 /// 32 /// To restore this read/write object from its `ReadHalf` and 33 /// `WriteHalf` use [`unsplit`](ReadHalf::unsplit()). 34 pub fn split<T>(stream: T) -> (ReadHalf<T>, WriteHalf<T>) 35 where 36 T: AsyncRead + AsyncWrite, 37 { 38 let is_write_vectored = stream.is_write_vectored(); 39 40 let inner = Arc::new(Inner { 41 locked: AtomicBool::new(false), 42 stream: UnsafeCell::new(stream), 43 is_write_vectored, 44 }); 45 46 let rd = ReadHalf { 47 inner: inner.clone(), 48 }; 49 50 let wr = WriteHalf { inner }; 51 52 (rd, wr) 53 } 54 } 55 56 struct Inner<T> { 57 locked: AtomicBool, 58 stream: UnsafeCell<T>, 59 is_write_vectored: bool, 60 } 61 62 struct Guard<'a, T> { 63 inner: &'a Inner<T>, 64 } 65 66 impl<T> ReadHalf<T> { 67 /// Checks if this `ReadHalf` and some `WriteHalf` were split from the same 68 /// stream. is_pair_of(&self, other: &WriteHalf<T>) -> bool69 pub fn is_pair_of(&self, other: &WriteHalf<T>) -> bool { 70 other.is_pair_of(self) 71 } 72 73 /// Reunites with a previously split `WriteHalf`. 74 /// 75 /// # Panics 76 /// 77 /// If this `ReadHalf` and the given `WriteHalf` do not originate from the 78 /// same `split` operation this method will panic. 79 /// This can be checked ahead of time by comparing the stream ID 80 /// of the two halves. 81 #[track_caller] unsplit(self, wr: WriteHalf<T>) -> T where T: Unpin,82 pub fn unsplit(self, wr: WriteHalf<T>) -> T 83 where 84 T: Unpin, 85 { 86 if self.is_pair_of(&wr) { 87 drop(wr); 88 89 let inner = Arc::try_unwrap(self.inner) 90 .ok() 91 .expect("`Arc::try_unwrap` failed"); 92 93 inner.stream.into_inner() 94 } else { 95 panic!("Unrelated `split::Write` passed to `split::Read::unsplit`.") 96 } 97 } 98 } 99 100 impl<T> WriteHalf<T> { 101 /// Checks if this `WriteHalf` and some `ReadHalf` were split from the same 102 /// stream. is_pair_of(&self, other: &ReadHalf<T>) -> bool103 pub fn is_pair_of(&self, other: &ReadHalf<T>) -> bool { 104 Arc::ptr_eq(&self.inner, &other.inner) 105 } 106 } 107 108 impl<T: AsyncRead> AsyncRead for ReadHalf<T> { poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>109 fn poll_read( 110 self: Pin<&mut Self>, 111 cx: &mut Context<'_>, 112 buf: &mut ReadBuf<'_>, 113 ) -> Poll<io::Result<()>> { 114 let mut inner = ready!(self.inner.poll_lock(cx)); 115 inner.stream_pin().poll_read(cx, buf) 116 } 117 } 118 119 impl<T: AsyncWrite> AsyncWrite for WriteHalf<T> { poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<Result<usize, io::Error>>120 fn poll_write( 121 self: Pin<&mut Self>, 122 cx: &mut Context<'_>, 123 buf: &[u8], 124 ) -> Poll<Result<usize, io::Error>> { 125 let mut inner = ready!(self.inner.poll_lock(cx)); 126 inner.stream_pin().poll_write(cx, buf) 127 } 128 poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>129 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { 130 let mut inner = ready!(self.inner.poll_lock(cx)); 131 inner.stream_pin().poll_flush(cx) 132 } 133 poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>134 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { 135 let mut inner = ready!(self.inner.poll_lock(cx)); 136 inner.stream_pin().poll_shutdown(cx) 137 } 138 poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>], ) -> Poll<Result<usize, io::Error>>139 fn poll_write_vectored( 140 self: Pin<&mut Self>, 141 cx: &mut Context<'_>, 142 bufs: &[io::IoSlice<'_>], 143 ) -> Poll<Result<usize, io::Error>> { 144 let mut inner = ready!(self.inner.poll_lock(cx)); 145 inner.stream_pin().poll_write_vectored(cx, bufs) 146 } 147 is_write_vectored(&self) -> bool148 fn is_write_vectored(&self) -> bool { 149 self.inner.is_write_vectored 150 } 151 } 152 153 impl<T> Inner<T> { poll_lock(&self, cx: &mut Context<'_>) -> Poll<Guard<'_, T>>154 fn poll_lock(&self, cx: &mut Context<'_>) -> Poll<Guard<'_, T>> { 155 if self 156 .locked 157 .compare_exchange(false, true, Acquire, Acquire) 158 .is_ok() 159 { 160 Poll::Ready(Guard { inner: self }) 161 } else { 162 // Spin... but investigate a better strategy 163 164 std::thread::yield_now(); 165 cx.waker().wake_by_ref(); 166 167 Poll::Pending 168 } 169 } 170 } 171 172 impl<T> Guard<'_, T> { stream_pin(&mut self) -> Pin<&mut T>173 fn stream_pin(&mut self) -> Pin<&mut T> { 174 // safety: the stream is pinned in `Arc` and the `Guard` ensures mutual 175 // exclusion. 176 unsafe { Pin::new_unchecked(&mut *self.inner.stream.get()) } 177 } 178 } 179 180 impl<T> Drop for Guard<'_, T> { drop(&mut self)181 fn drop(&mut self) { 182 self.inner.locked.store(false, Release); 183 } 184 } 185 186 unsafe impl<T: Send> Send for ReadHalf<T> {} 187 unsafe impl<T: Send> Send for WriteHalf<T> {} 188 unsafe impl<T: Sync> Sync for ReadHalf<T> {} 189 unsafe impl<T: Sync> Sync for WriteHalf<T> {} 190 191 impl<T: fmt::Debug> fmt::Debug for ReadHalf<T> { fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result192 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { 193 fmt.debug_struct("split::ReadHalf").finish() 194 } 195 } 196 197 impl<T: fmt::Debug> fmt::Debug for WriteHalf<T> { fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result198 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { 199 fmt.debug_struct("split::WriteHalf").finish() 200 } 201 } 202