• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use super::copy::CopyBuffer;
2 
3 use crate::io::{AsyncRead, AsyncWrite};
4 
5 use std::future::Future;
6 use std::io;
7 use std::pin::Pin;
8 use std::task::{Context, Poll};
9 
10 enum TransferState {
11     Running(CopyBuffer),
12     ShuttingDown(u64),
13     Done(u64),
14 }
15 
16 struct CopyBidirectional<'a, A: ?Sized, B: ?Sized> {
17     a: &'a mut A,
18     b: &'a mut B,
19     a_to_b: TransferState,
20     b_to_a: TransferState,
21 }
22 
transfer_one_direction<A, B>( cx: &mut Context<'_>, state: &mut TransferState, r: &mut A, w: &mut B, ) -> Poll<io::Result<u64>> where A: AsyncRead + AsyncWrite + Unpin + ?Sized, B: AsyncRead + AsyncWrite + Unpin + ?Sized,23 fn transfer_one_direction<A, B>(
24     cx: &mut Context<'_>,
25     state: &mut TransferState,
26     r: &mut A,
27     w: &mut B,
28 ) -> Poll<io::Result<u64>>
29 where
30     A: AsyncRead + AsyncWrite + Unpin + ?Sized,
31     B: AsyncRead + AsyncWrite + Unpin + ?Sized,
32 {
33     let mut r = Pin::new(r);
34     let mut w = Pin::new(w);
35 
36     loop {
37         match state {
38             TransferState::Running(buf) => {
39                 let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?;
40                 *state = TransferState::ShuttingDown(count);
41             }
42             TransferState::ShuttingDown(count) => {
43                 ready!(w.as_mut().poll_shutdown(cx))?;
44 
45                 *state = TransferState::Done(*count);
46             }
47             TransferState::Done(count) => return Poll::Ready(Ok(*count)),
48         }
49     }
50 }
51 
52 impl<'a, A, B> Future for CopyBidirectional<'a, A, B>
53 where
54     A: AsyncRead + AsyncWrite + Unpin + ?Sized,
55     B: AsyncRead + AsyncWrite + Unpin + ?Sized,
56 {
57     type Output = io::Result<(u64, u64)>;
58 
poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>59     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
60         // Unpack self into mut refs to each field to avoid borrow check issues.
61         let CopyBidirectional {
62             a,
63             b,
64             a_to_b,
65             b_to_a,
66         } = &mut *self;
67 
68         let a_to_b = transfer_one_direction(cx, a_to_b, &mut *a, &mut *b)?;
69         let b_to_a = transfer_one_direction(cx, b_to_a, &mut *b, &mut *a)?;
70 
71         // It is not a problem if ready! returns early because transfer_one_direction for the
72         // other direction will keep returning TransferState::Done(count) in future calls to poll
73         let a_to_b = ready!(a_to_b);
74         let b_to_a = ready!(b_to_a);
75 
76         Poll::Ready(Ok((a_to_b, b_to_a)))
77     }
78 }
79 
80 /// Copies data in both directions between `a` and `b`.
81 ///
82 /// This function returns a future that will read from both streams,
83 /// writing any data read to the opposing stream.
84 /// This happens in both directions concurrently.
85 ///
86 /// If an EOF is observed on one stream, [`shutdown()`] will be invoked on
87 /// the other, and reading from that stream will stop. Copying of data in
88 /// the other direction will continue.
89 ///
90 /// The future will complete successfully once both directions of communication has been shut down.
91 /// A direction is shut down when the reader reports EOF,
92 /// at which point [`shutdown()`] is called on the corresponding writer. When finished,
93 /// it will return a tuple of the number of bytes copied from a to b
94 /// and the number of bytes copied from b to a, in that order.
95 ///
96 /// [`shutdown()`]: crate::io::AsyncWriteExt::shutdown
97 ///
98 /// # Errors
99 ///
100 /// The future will immediately return an error if any IO operation on `a`
101 /// or `b` returns an error. Some data read from either stream may be lost (not
102 /// written to the other stream) in this case.
103 ///
104 /// # Return value
105 ///
106 /// Returns a tuple of bytes copied `a` to `b` and bytes copied `b` to `a`.
107 #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
copy_bidirectional<A, B>(a: &mut A, b: &mut B) -> Result<(u64, u64), std::io::Error> where A: AsyncRead + AsyncWrite + Unpin + ?Sized, B: AsyncRead + AsyncWrite + Unpin + ?Sized,108 pub async fn copy_bidirectional<A, B>(a: &mut A, b: &mut B) -> Result<(u64, u64), std::io::Error>
109 where
110     A: AsyncRead + AsyncWrite + Unpin + ?Sized,
111     B: AsyncRead + AsyncWrite + Unpin + ?Sized,
112 {
113     CopyBidirectional {
114         a,
115         b,
116         a_to_b: TransferState::Running(CopyBuffer::new()),
117         b_to_a: TransferState::Running(CopyBuffer::new()),
118     }
119     .await
120 }
121