1 #![warn(rust_2018_idioms)]
2 #![cfg(feature = "full")]
3
4 use tokio::net::{TcpListener, TcpStream};
5 use tokio::sync::{mpsc, oneshot};
6 use tokio_test::assert_ok;
7
8 use std::io;
9 use std::net::{IpAddr, SocketAddr};
10
// Generates one `#[tokio::test]` per `(name, target)` pair.
//
// Each generated test binds a `TcpListener` to `target` (any value accepted by
// `TcpListener::bind`), accepts a single connection on a spawned task, sends the
// accepted socket back over a oneshot channel, and asserts that the client's
// local address equals the accepted socket's peer address.
//
// Entries are comma-separated; the trailing comma after the last entry is
// optional.
macro_rules! test_accept {
    ($(($ident:ident, $target:expr)),* $(,)?) => {
        $(
            #[tokio::test]
            async fn $ident() {
                let listener = assert_ok!(TcpListener::bind($target).await);
                let addr = listener.local_addr().unwrap();

                let (tx, rx) = oneshot::channel();

                // Accept on a separate task while this task connects, then hand
                // the server-side socket back for inspection.
                tokio::spawn(async move {
                    let (socket, _) = assert_ok!(listener.accept().await);
                    assert_ok!(tx.send(socket));
                });

                let cli = assert_ok!(TcpStream::connect(&addr).await);
                let srv = assert_ok!(rx.await);

                // Both ends must agree on the client-side address.
                assert_eq!(cli.local_addr().unwrap(), srv.peer_addr().unwrap());
            }
        )*
    };
}
34
35 test_accept! {
36 (ip_str, "127.0.0.1:0"),
37 (host_str, "localhost:0"),
38 (socket_addr, "127.0.0.1:0".parse::<SocketAddr>().unwrap()),
39 (str_port_tuple, ("127.0.0.1", 0)),
40 (ip_port_tuple, ("127.0.0.1".parse::<IpAddr>().unwrap(), 0)),
41 }
42
43 use std::pin::Pin;
44 use std::sync::{
45 atomic::{AtomicUsize, Ordering::SeqCst},
46 Arc,
47 };
48 use std::task::{Context, Poll};
49 use tokio_stream::{Stream, StreamExt};
50
51 struct TrackPolls<'a> {
52 npolls: Arc<AtomicUsize>,
53 listener: &'a mut TcpListener,
54 }
55
56 impl<'a> Stream for TrackPolls<'a> {
57 type Item = io::Result<(TcpStream, SocketAddr)>;
58
poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>59 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
60 self.npolls.fetch_add(1, SeqCst);
61 self.listener.poll_accept(cx).map(Some)
62 }
63 }
64
65 #[tokio::test]
no_extra_poll()66 async fn no_extra_poll() {
67 let mut listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
68 let addr = listener.local_addr().unwrap();
69
70 let (tx, rx) = oneshot::channel();
71 let (accepted_tx, mut accepted_rx) = mpsc::unbounded_channel();
72
73 tokio::spawn(async move {
74 let mut incoming = TrackPolls {
75 npolls: Arc::new(AtomicUsize::new(0)),
76 listener: &mut listener,
77 };
78 assert_ok!(tx.send(Arc::clone(&incoming.npolls)));
79 while incoming.next().await.is_some() {
80 accepted_tx.send(()).unwrap();
81 }
82 });
83
84 let npolls = assert_ok!(rx.await);
85 tokio::task::yield_now().await;
86
87 // should have been polled exactly once: the initial poll
88 assert_eq!(npolls.load(SeqCst), 1);
89
90 let _ = assert_ok!(TcpStream::connect(&addr).await);
91 accepted_rx.recv().await.unwrap();
92
93 // should have been polled twice more: once to yield Some(), then once to yield Pending
94 assert_eq!(npolls.load(SeqCst), 1 + 2);
95 }
96
97 #[tokio::test]
accept_many()98 async fn accept_many() {
99 use futures::future::poll_fn;
100 use std::future::Future;
101 use std::sync::atomic::AtomicBool;
102
103 const N: usize = 50;
104
105 let listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
106 let listener = Arc::new(listener);
107 let addr = listener.local_addr().unwrap();
108 let connected = Arc::new(AtomicBool::new(false));
109
110 let (pending_tx, mut pending_rx) = mpsc::unbounded_channel();
111 let (notified_tx, mut notified_rx) = mpsc::unbounded_channel();
112
113 for _ in 0..N {
114 let listener = listener.clone();
115 let connected = connected.clone();
116 let pending_tx = pending_tx.clone();
117 let notified_tx = notified_tx.clone();
118
119 tokio::spawn(async move {
120 let accept = listener.accept();
121 tokio::pin!(accept);
122
123 let mut polled = false;
124
125 poll_fn(|cx| {
126 if !polled {
127 polled = true;
128 assert!(Pin::new(&mut accept).poll(cx).is_pending());
129 pending_tx.send(()).unwrap();
130 Poll::Pending
131 } else if connected.load(SeqCst) {
132 notified_tx.send(()).unwrap();
133 Poll::Ready(())
134 } else {
135 Poll::Pending
136 }
137 })
138 .await;
139
140 pending_tx.send(()).unwrap();
141 });
142 }
143
144 // Wait for all tasks to have polled at least once
145 for _ in 0..N {
146 pending_rx.recv().await.unwrap();
147 }
148
149 // Establish a TCP connection
150 connected.store(true, SeqCst);
151 let _sock = TcpStream::connect(addr).await.unwrap();
152
153 // Wait for all notifications
154 for _ in 0..N {
155 notified_rx.recv().await.unwrap();
156 }
157 }
158