//! Split a single value implementing `AsyncRead + AsyncWrite` into separate
//! `AsyncRead` and `AsyncWrite` handles.
//!
//! To restore this read/write object from its `split::ReadHalf` and
//! `split::WriteHalf`, use [`unsplit`](ReadHalf::unsplit()).
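//!
//! # Examples
//!
//! A minimal usage sketch. It is illustrative only and uses the in-memory
//! `duplex` stream from `tokio::io` so that no network setup is required.
//!
//! ```
//! use tokio::io::{self, AsyncReadExt, AsyncWriteExt};
//!
//! # #[tokio::main(flavor = "current_thread")]
//! # async fn main() -> io::Result<()> {
//! // `duplex` returns two connected in-memory streams.
//! let (local, mut remote) = io::duplex(64);
//!
//! // Split one end into independent read and write halves.
//! let (mut rd, mut wr) = io::split(local);
//!
//! wr.write_all(b"ping").await?;
//!
//! let mut buf = [0u8; 4];
//! remote.read_exact(&mut buf).await?;
//! assert_eq!(&buf, b"ping");
//! remote.write_all(b"pong").await?;
//!
//! rd.read_exact(&mut buf).await?;
//! assert_eq!(&buf, b"pong");
//!
//! // The halves can be reunited into the original stream.
//! let _stream = rd.unsplit(wr);
//! # Ok(())
//! # }
//! ```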

use crate::io::{AsyncRead, AsyncWrite, ReadBuf};

use std::cell::UnsafeCell;
use std::fmt;
use std::io;
use std::pin::Pin;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering::{Acquire, Release};
use std::sync::Arc;
use std::task::{Context, Poll};

cfg_io_util! {
    /// The readable half of a value returned from [`split`](split()).
    pub struct ReadHalf<T> {
        inner: Arc<Inner<T>>,
    }

    /// The writable half of a value returned from [`split`](split()).
    pub struct WriteHalf<T> {
        inner: Arc<Inner<T>>,
    }

    /// Splits a single value implementing `AsyncRead + AsyncWrite` into separate
    /// `AsyncRead` and `AsyncWrite` handles.
    ///
    /// To restore this read/write object from its `ReadHalf` and
    /// `WriteHalf`, use [`unsplit`](ReadHalf::unsplit()).
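    ///
    /// # Examples
    ///
    /// An illustrative sketch only; the address and payload below are made up.
    /// It splits a `TcpStream` so the two halves can be driven independently.
    ///
    /// ```no_run
    /// use tokio::io::{self, AsyncReadExt, AsyncWriteExt};
    /// use tokio::net::TcpStream;
    ///
    /// # #[tokio::main]
    /// # async fn main() -> io::Result<()> {
    /// let stream = TcpStream::connect("127.0.0.1:8080").await?;
    /// let (mut rd, mut wr) = io::split(stream);
    ///
    /// wr.write_all(b"hello\r\n").await?;
    ///
    /// let mut buf = vec![0u8; 1024];
    /// let n = rd.read(&mut buf).await?;
    /// println!("received {} bytes", n);
    /// # Ok(())
    /// # }
    /// ```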
    pub fn split<T>(stream: T) -> (ReadHalf<T>, WriteHalf<T>)
    where
        T: AsyncRead + AsyncWrite,
    {
        let is_write_vectored = stream.is_write_vectored();

        let inner = Arc::new(Inner {
            locked: AtomicBool::new(false),
            stream: UnsafeCell::new(stream),
            is_write_vectored,
        });

        let rd = ReadHalf {
            inner: inner.clone(),
        };

        let wr = WriteHalf { inner };

        (rd, wr)
    }
}

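// State shared by the two halves.
//
// `locked` is a simple atomic flag used as a lock: `poll_lock` sets it with a
// compare-exchange before handing out a `Guard`, and `Guard::drop` clears it.
// The stream itself lives in an `UnsafeCell` and is only touched while the
// flag is held. `is_write_vectored` is captured once in `split` so that
// `WriteHalf::is_write_vectored` can answer without taking the lock.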
struct Inner<T> {
    locked: AtomicBool,
    stream: UnsafeCell<T>,
    is_write_vectored: bool,
}

struct Guard<'a, T> {
    inner: &'a Inner<T>,
}

impl<T> ReadHalf<T> {
    /// Checks if this `ReadHalf` and some `WriteHalf` were split from the same
    /// stream.
    pub fn is_pair_of(&self, other: &WriteHalf<T>) -> bool {
        other.is_pair_of(self)
    }

    /// Reunites with a previously split `WriteHalf`.
    ///
    /// # Panics
    ///
    /// If this `ReadHalf` and the given `WriteHalf` do not originate from the
    /// same `split` operation, this method will panic.
    /// This can be checked ahead of time by calling
    /// [`is_pair_of`](ReadHalf::is_pair_of).
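    ///
    /// # Examples
    ///
    /// A minimal sketch, using the in-memory `duplex` stream purely for
    /// illustration:
    ///
    /// ```
    /// use tokio::io;
    ///
    /// # #[tokio::main(flavor = "current_thread")]
    /// # async fn main() {
    /// let (stream, _other) = io::duplex(64);
    /// let (rd, wr) = io::split(stream);
    ///
    /// // Confirm the halves belong together before reuniting them.
    /// assert!(rd.is_pair_of(&wr));
    /// let _stream = rd.unsplit(wr);
    /// # }
    /// ```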
    #[track_caller]
    pub fn unsplit(self, wr: WriteHalf<T>) -> T
    where
        T: Unpin,
    {
        if self.is_pair_of(&wr) {
            drop(wr);

            let inner = Arc::try_unwrap(self.inner)
                .ok()
                .expect("`Arc::try_unwrap` failed");

            inner.stream.into_inner()
        } else {
            panic!("Unrelated `split::WriteHalf` passed to `split::ReadHalf::unsplit`.")
        }
    }
}

impl<T> WriteHalf<T> {
    /// Checks if this `WriteHalf` and some `ReadHalf` were split from the same
    /// stream.
    pub fn is_pair_of(&self, other: &ReadHalf<T>) -> bool {
        Arc::ptr_eq(&self.inner, &other.inner)
    }
}

impl<T: AsyncRead> AsyncRead for ReadHalf<T> {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let mut inner = ready!(self.inner.poll_lock(cx));
        inner.stream_pin().poll_read(cx, buf)
    }
}

impl<T: AsyncWrite> AsyncWrite for WriteHalf<T> {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        let mut inner = ready!(self.inner.poll_lock(cx));
        inner.stream_pin().poll_write(cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        let mut inner = ready!(self.inner.poll_lock(cx));
        inner.stream_pin().poll_flush(cx)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        let mut inner = ready!(self.inner.poll_lock(cx));
        inner.stream_pin().poll_shutdown(cx)
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<Result<usize, io::Error>> {
        let mut inner = ready!(self.inner.poll_lock(cx));
        inner.stream_pin().poll_write_vectored(cx, bufs)
    }

    fn is_write_vectored(&self) -> bool {
        self.inner.is_write_vectored
    }
}

impl<T> Inner<T> {
    fn poll_lock(&self, cx: &mut Context<'_>) -> Poll<Guard<'_, T>> {
        if self
            .locked
            .compare_exchange(false, true, Acquire, Acquire)
            .is_ok()
        {
            Poll::Ready(Guard { inner: self })
        } else {
            // Spin... but investigate a better strategy

            std::thread::yield_now();
            cx.waker().wake_by_ref();

            Poll::Pending
        }
    }
}

impl<T> Guard<'_, T> {
    fn stream_pin(&mut self) -> Pin<&mut T> {
        // safety: the stream is pinned in `Arc` and the `Guard` ensures mutual
        // exclusion.
        unsafe { Pin::new_unchecked(&mut *self.inner.stream.get()) }
    }
}

impl<T> Drop for Guard<'_, T> {
    fn drop(&mut self) {
        self.inner.locked.store(false, Release);
    }
}

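// `Inner` stores the stream in an `UnsafeCell`, so `ReadHalf`/`WriteHalf` do
// not get `Send`/`Sync` automatically. Access to the stream is serialized by
// `Inner::poll_lock`, so each half only forwards the corresponding bound on `T`.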
unsafe impl<T: Send> Send for ReadHalf<T> {}
unsafe impl<T: Send> Send for WriteHalf<T> {}
unsafe impl<T: Sync> Sync for ReadHalf<T> {}
unsafe impl<T: Sync> Sync for WriteHalf<T> {}

impl<T: fmt::Debug> fmt::Debug for ReadHalf<T> {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.debug_struct("split::ReadHalf").finish()
    }
}

impl<T: fmt::Debug> fmt::Debug for WriteHalf<T> {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.debug_struct("split::WriteHalf").finish()
    }
}