• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #![warn(rust_2018_idioms)]
2 #![cfg(all(feature = "full", not(target_os = "wasi")))]
3 #![cfg(tokio_unstable)]
4 
5 use tokio::io::{AsyncReadExt, AsyncWriteExt};
6 use tokio::net::{TcpListener, TcpStream};
7 use tokio::runtime;
8 use tokio::sync::oneshot;
9 use tokio_test::{assert_err, assert_ok};
10 
11 use futures::future::poll_fn;
12 use std::future::Future;
13 use std::pin::Pin;
14 use std::sync::atomic::AtomicUsize;
15 use std::sync::atomic::Ordering::Relaxed;
16 use std::sync::{mpsc, Arc, Mutex};
17 use std::task::{Context, Poll, Waker};
18 
// Wraps the given statements in a block that is only compiled when the
// `tokio_unstable` cfg is set, so metric assertions can be embedded directly
// inside a test body.
macro_rules! cfg_metrics {
    ($($body:tt)*) => {
        #[cfg(tokio_unstable)]
        {
            $($body)*
        }
    }
}
27 
#[test]
fn single_thread() {
    // Building a runtime with exactly one worker thread must not panic.
    let mut builder = runtime::Builder::new_multi_thread_alt();
    builder.enable_all().worker_threads(1);

    let _ = builder.build().unwrap();
}
37 
#[test]
fn many_oneshot_futures() {
    // Number of short-lived tasks spawned per round.
    const NUM: usize = 1_000;

    for _ in 0..5 {
        // Used by the last task to finish to notify the main thread.
        let (done_tx, done_rx) = mpsc::channel();

        let pool = rt();
        let completed = Arc::new(AtomicUsize::new(0));

        for _ in 0..NUM {
            let completed = Arc::clone(&completed);
            let done_tx = done_tx.clone();

            pool.spawn(async move {
                // `fetch_add` returns the previous value, hence the `+ 1`.
                if completed.fetch_add(1, Relaxed) + 1 == NUM {
                    done_tx.send(()).unwrap();
                }
            });
        }

        // Block until every task of this round has run.
        done_rx.recv().unwrap();

        // Wait for the pool to shutdown
        drop(pool);
    }
}
68 
#[test]
fn spawn_two() {
    let rt = rt();

    // A task spawns a second task, which sends the value back to `block_on`.
    let received = rt.block_on(async {
        let (tx, rx) = oneshot::channel();

        let inner = async move {
            tx.send("ZOMG").unwrap();
        };

        tokio::spawn(async move {
            tokio::spawn(inner);
        });

        assert_ok!(rx.await)
    });

    assert_eq!(received, "ZOMG");

    cfg_metrics! {
        let metrics = rt.metrics();
        drop(rt);

        // Exactly one task is expected to have been scheduled remotely ...
        assert_eq!(1, metrics.remote_schedule_count());

        // ... and exactly one locally, summed over every worker.
        let mut local_total = 0;
        for worker in 0..metrics.num_workers() {
            local_total += metrics.worker_local_schedule_count(worker);
        }

        assert_eq!(1, local_total);
    }
}
100 
#[test]
fn many_multishot_futures() {
    // Number of forwarding tasks in each track.
    const CHAIN: usize = 200;
    // Number of times a message travels through a track before finishing.
    const CYCLES: usize = 5;
    // Number of independent tracks.
    const TRACKS: usize = 50;

    for _ in 0..50 {
        let rt = rt();
        let mut heads = Vec::with_capacity(TRACKS);
        let mut tails = Vec::with_capacity(TRACKS);

        for _ in 0..TRACKS {
            let (start_tx, mut link_rx) = tokio::sync::mpsc::channel(10);

            for _ in 0..CHAIN {
                let (fwd_tx, fwd_rx) = tokio::sync::mpsc::channel(10);

                // Forward every message to the next link of the chain.
                rt.spawn(async move {
                    while let Some(msg) = link_rx.recv().await {
                        fwd_tx.send(msg).await.unwrap();
                    }
                });

                link_rx = fwd_rx;
            }

            // The last task either feeds the message back to the head of the
            // chain or, on the final cycle, hands it to the test.
            let (final_tx, final_rx) = tokio::sync::mpsc::channel(10);
            let cycle_tx = start_tx.clone();

            rt.spawn(async move {
                for cycle in 1..=CYCLES {
                    let msg = link_rx.recv().await.unwrap();

                    if cycle == CYCLES {
                        final_tx.send(msg).await.unwrap();
                    } else {
                        cycle_tx.send(msg).await.unwrap();
                    }
                }
            });

            heads.push(start_tx);
            tails.push(final_rx);
        }

        rt.block_on(async move {
            // Kick off every track ...
            for head in heads {
                head.send("ping").await.unwrap();
            }

            // ... and wait until each one delivered its final message.
            for mut tail in tails {
                tail.recv().await.unwrap();
            }
        });
    }
}
164 
#[test]
fn lifo_slot_budget() {
    // `respawn` and `spawn_next` call each other, so every task spawned here
    // immediately spawns a successor, forever.
    async fn respawn() {
        spawn_next();
    }

    fn spawn_next() {
        tokio::spawn(respawn());
    }

    let mut builder = runtime::Builder::new_multi_thread_alt();
    builder.enable_all().worker_threads(1);
    let rt = builder.build().unwrap();

    let (done_tx, done_rx) = oneshot::channel();

    rt.spawn(async move {
        tokio::spawn(respawn());
        let _ = done_tx.send(());
    });

    // Completing at all is the assertion: the single worker must still get
    // around to delivering the notification despite the self-respawning task.
    let _ = rt.block_on(done_rx);
}
190 
#[test]
fn spawn_shutdown() {
    let rt = rt();
    let (tx, rx) = mpsc::channel();

    // Spawn from inside the runtime context
    let tx_inner = tx.clone();
    rt.block_on(async move {
        tokio::spawn(client_server(tx_inner));
    });

    // Spawn through the runtime handle itself
    rt.spawn(client_server(tx));

    // Both client/server exchanges must report completion.
    assert_ok!(rx.recv());
    assert_ok!(rx.recv());

    // After shutdown no further notification may be pending.
    drop(rt);
    assert_err!(rx.try_recv());
}
209 
client_server(tx: mpsc::Sender<()>)210 async fn client_server(tx: mpsc::Sender<()>) {
211     let server = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
212 
213     // Get the assigned address
214     let addr = assert_ok!(server.local_addr());
215 
216     // Spawn the server
217     tokio::spawn(async move {
218         // Accept a socket
219         let (mut socket, _) = server.accept().await.unwrap();
220 
221         // Write some data
222         socket.write_all(b"hello").await.unwrap();
223     });
224 
225     let mut client = TcpStream::connect(&addr).await.unwrap();
226 
227     let mut buf = vec![];
228     client.read_to_end(&mut buf).await.unwrap();
229 
230     assert_eq!(buf, b"hello");
231     tx.send(()).unwrap();
232 }
233 
#[test]
fn drop_threadpool_drops_futures() {
    // A future that never completes and bumps the wrapped counter on drop.
    struct Never(Arc<AtomicUsize>);

    impl Future for Never {
        type Output = ();

        fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
            Poll::Pending
        }
    }

    impl Drop for Never {
        fn drop(&mut self) {
            self.0.fetch_add(1, Relaxed);
        }
    }

    for _ in 0..1_000 {
        let started = Arc::new(AtomicUsize::new(0));
        let stopped = Arc::new(AtomicUsize::new(0));
        let dropped = Arc::new(AtomicUsize::new(0));

        let on_start = Arc::clone(&started);
        let on_stop = Arc::clone(&stopped);

        let rt = runtime::Builder::new_multi_thread_alt()
            .enable_all()
            .on_thread_start(move || {
                on_start.fetch_add(1, Relaxed);
            })
            .on_thread_stop(move || {
                on_stop.fetch_add(1, Relaxed);
            })
            .build()
            .unwrap();

        rt.spawn(Never(Arc::clone(&dropped)));

        // Wait for the pool to shutdown
        drop(rt);

        // Assert that at least one thread was started.
        let started = started.load(Relaxed);
        assert!(started >= 1);

        // Assert that every started thread also ran its stop callback.
        let stopped = stopped.load(Relaxed);
        assert_eq!(started, stopped);

        // Assert that the never-completing future was dropped exactly once.
        let dropped = dropped.load(Relaxed);
        assert_eq!(dropped, 1);
    }
}
289 
#[test]
fn start_stop_callbacks_called() {
    use std::sync::atomic::{AtomicUsize, Ordering};

    // Counters bumped by the thread start / stop callbacks respectively.
    let after_start = Arc::new(AtomicUsize::new(0));
    let before_stop = Arc::new(AtomicUsize::new(0));

    let after_inner = after_start.clone();
    let before_inner = before_stop.clone();
    let rt = tokio::runtime::Builder::new_multi_thread_alt()
        .enable_all()
        // The closures own their `Arc` handle; `fetch_add` only needs `&self`,
        // so no additional clone per invocation is required.
        .on_thread_start(move || {
            after_inner.fetch_add(1, Ordering::Relaxed);
        })
        .on_thread_stop(move || {
            before_inner.fetch_add(1, Ordering::Relaxed);
        })
        .build()
        .unwrap();

    let (tx, rx) = oneshot::channel();

    // Run one task so that at least one worker thread is known to be active.
    rt.spawn(async move {
        assert_ok!(tx.send(()));
    });

    assert_ok!(rt.block_on(rx));

    // The counters are read only after the runtime has been dropped.
    drop(rt);

    assert!(after_start.load(Ordering::Relaxed) > 0);
    assert!(before_stop.load(Ordering::Relaxed) > 0);
}
323 
#[test]
fn blocking_task() {
    // Number of short tasks spawned while the blocking tasks are parked.
    const NUM: usize = 1_000;

    for _ in 0..10 {
        // Used by the last short task to notify the main thread.
        let (done_tx, done_rx) = mpsc::channel();

        let rt = rt();
        let finished = Arc::new(AtomicUsize::new(0));

        // Park four tasks inside `block_in_place`. The barrier has five
        // participants: those four tasks plus the main thread.
        // NOTE(review): the original comment assumes the pool has four
        // workers, but `rt()` does not set a worker count — confirm intent.
        let barrier = Arc::new(std::sync::Barrier::new(5));
        for _ in 0..4 {
            let barrier = Arc::clone(&barrier);
            rt.spawn(async move {
                tokio::task::block_in_place(move || {
                    // First wait: signal that this task is now blocking.
                    barrier.wait();
                    // Second wait: stay blocked until the main thread is done.
                    barrier.wait();
                })
            });
        }

        // Proceed only once all four tasks are inside `block_in_place`.
        barrier.wait();

        for _ in 0..NUM {
            let finished = Arc::clone(&finished);
            let done_tx = done_tx.clone();

            rt.spawn(async move {
                // `fetch_add` returns the previous value, hence the `+ 1`.
                if finished.fetch_add(1, Relaxed) + 1 == NUM {
                    done_tx.send(()).unwrap();
                }
            });
        }

        done_rx.recv().unwrap();

        // Release the four blocked tasks; the runtime itself is dropped at the
        // end of the iteration.
        barrier.wait();
    }
}
368 
#[test]
fn multi_threadpool() {
    use tokio::sync::oneshot;

    let first = rt();
    let second = rt();

    // `tx`/`rx` cross from the first runtime to the second; `done_*` reports
    // back to the test thread.
    let (tx, rx) = oneshot::channel();
    let (done_tx, done_rx) = mpsc::channel();

    // The receiving task lives on the second runtime ...
    second.spawn(async move {
        rx.await.unwrap();
        done_tx.send(()).unwrap();
    });

    // ... and is woken by a task running on the first one.
    first.spawn(async move {
        tx.send(()).unwrap();
    });

    done_rx.recv().unwrap();
}
390 
// When `block_in_place` returns, it attempts to reclaim the yielded runtime
// worker. In this case, the remainder of the task is on the runtime worker and
// must take part in the cooperative task budgeting system.
//
// The test ensures that, when this happens, attempting to consume from a
// channel yields occasionally even if there are values ready to receive.
#[test]
fn coop_and_block_in_place() {
    let mut builder = tokio::runtime::Builder::new_multi_thread_alt();
    // Limiting the blocking pool to a single thread prevents another thread
    // from claiming the runtime worker yielded as part of `block_in_place`
    // and guarantees the same thread will reclaim the worker at the end of
    // the `block_in_place` call.
    builder.max_blocking_threads(1);
    let rt = builder.build().unwrap();

    rt.block_on(async move {
        let (tx, mut rx) = tokio::sync::mpsc::channel(1024);

        // Fill the channel to capacity, then close the sending side.
        for _ in 0..1024 {
            tx.send(()).await.unwrap();
        }

        drop(tx);

        let consumer = tokio::spawn(async move {
            // Block in place without doing anything
            tokio::task::block_in_place(|| {});

            // Drain the channel within a single poll. The coop budget must
            // produce a `Pending` before the channel reports it is closed.
            poll_fn(|cx| {
                loop {
                    tokio::pin! {
                        let fut = rx.recv();
                    }

                    match fut.as_mut().poll(cx) {
                        // A value was ready: keep draining.
                        Poll::Ready(Some(_)) => {}
                        // Reached the end without ever being asked to yield.
                        Poll::Ready(None) => panic!("did not yield"),
                        // The budget kicked in: this is the expected outcome.
                        Poll::Pending => return Poll::Ready(()),
                    }
                }
            })
            .await
        });

        consumer.await.unwrap();
    });
}
445 
#[test]
fn yield_after_block_in_place() {
    let mut builder = tokio::runtime::Builder::new_multi_thread_alt();
    builder.worker_threads(1);
    let outer = builder.build().unwrap();

    outer.block_on(async {
        let task = tokio::spawn(async move {
            // Block in place, then build and run a nested current-thread
            // runtime from within the blocking section.
            tokio::task::block_in_place(|| {
                let nested = tokio::runtime::Builder::new_current_thread()
                    .build()
                    .unwrap();

                nested.block_on(async {});
            });

            // Yield, then complete
            tokio::task::yield_now().await;
        });

        task.await.unwrap()
    });
}
471 
// Building a runtime with a blocking-thread limit of one must not panic.
#[test]
fn max_blocking_threads() {
    let mut builder = tokio::runtime::Builder::new_multi_thread_alt();
    builder.max_blocking_threads(1);

    let _rt = builder.build().unwrap();
}
480 
// A blocking-thread limit of zero is expected to panic.
#[test]
#[should_panic]
fn max_blocking_threads_set_to_zero() {
    let mut builder = tokio::runtime::Builder::new_multi_thread_alt();
    builder.max_blocking_threads(0);

    let _rt = builder.build().unwrap();
}
489 
490 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
hang_on_shutdown()491 async fn hang_on_shutdown() {
492     let (sync_tx, sync_rx) = std::sync::mpsc::channel::<()>();
493     tokio::spawn(async move {
494         tokio::task::block_in_place(|| sync_rx.recv().ok());
495     });
496 
497     tokio::spawn(async {
498         tokio::time::sleep(std::time::Duration::from_secs(2)).await;
499         drop(sync_tx);
500     });
501     tokio::time::sleep(std::time::Duration::from_secs(1)).await;
502 }
503 
/// Demonstrates tokio-rs/tokio#3869
#[test]
fn wake_during_shutdown() {
    // State shared by the two futures: the waker stored by the first one.
    struct Shared {
        waker: Option<Waker>,
    }

    // A never-completing future. With `put_waker` set it stores its waker in
    // `shared` on every poll; without it, it wakes the stored waker on drop.
    struct MyFuture {
        shared: Arc<Mutex<Shared>>,
        put_waker: bool,
    }

    impl MyFuture {
        // Returns the (waker-storing, waker-waking) pair over one shared slot.
        fn new() -> (Self, Self) {
            let shared = Arc::new(Mutex::new(Shared { waker: None }));
            let storing = MyFuture {
                shared: Arc::clone(&shared),
                put_waker: true,
            };
            let waking = MyFuture {
                shared,
                put_waker: false,
            };
            (storing, waking)
        }
    }

    impl Future for MyFuture {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            let this = self.get_mut();
            let mut guard = this.shared.lock().unwrap();
            if this.put_waker {
                guard.waker = Some(cx.waker().clone());
            }
            Poll::Pending
        }
    }

    impl Drop for MyFuture {
        fn drop(&mut self) {
            // The wake-up is issued while the mutex is still held.
            let mut guard = self.shared.lock().unwrap();
            if !self.put_waker {
                guard.waker.take().unwrap().wake();
            }
            drop(guard);
        }
    }

    let mut builder = tokio::runtime::Builder::new_multi_thread_alt();
    builder.worker_threads(1).enable_all();
    let rt = builder.build().unwrap();

    let (storing, waking) = MyFuture::new();

    rt.spawn(storing);
    rt.spawn(waking);

    // Sleep briefly before the runtime is dropped at the end of the test.
    rt.block_on(async { tokio::time::sleep(tokio::time::Duration::from_millis(20)).await });
}
567 
568 #[should_panic]
569 #[tokio::test]
test_block_in_place1()570 async fn test_block_in_place1() {
571     tokio::task::block_in_place(|| {});
572 }
573 
574 #[tokio::test(flavor = "multi_thread")]
test_block_in_place2()575 async fn test_block_in_place2() {
576     tokio::task::block_in_place(|| {});
577 }
578 
579 #[should_panic]
580 #[tokio::main(flavor = "current_thread")]
581 #[test]
test_block_in_place3()582 async fn test_block_in_place3() {
583     tokio::task::block_in_place(|| {});
584 }
585 
586 #[tokio::main]
587 #[test]
test_block_in_place4()588 async fn test_block_in_place4() {
589     tokio::task::block_in_place(|| {});
590 }
591 
// Testing the tuning logic is tricky as it is inherently timing based, and more
// of a heuristic than an exact behavior. This test checks that the interval
// changes over time based on load factors. There are no assertions, completion
// is sufficient. If there is a regression, this test will hang. In theory, we
// could add limits, but that would be likely to fail on CI.
#[test]
#[cfg(not(tokio_no_tuning_tests))]
fn test_tuning() {
    use std::sync::atomic::AtomicBool;
    use std::time::Duration;

    let rt = runtime::Builder::new_multi_thread_alt()
        .worker_threads(1)
        .build()
        .unwrap();

    // One step of a self-respawning task: while `flag` is set it optionally
    // stalls the worker for a few microseconds, bumps `counter` and spawns the
    // next step. Clearing `flag` ends the chain.
    fn iter(flag: Arc<AtomicBool>, counter: Arc<AtomicUsize>, stall: bool) {
        if flag.load(Relaxed) {
            if stall {
                std::thread::sleep(Duration::from_micros(5));
            }

            counter.fetch_add(1, Relaxed);
            tokio::spawn(async move { iter(flag, counter, stall) });
        }
    }

    let flag = Arc::new(AtomicBool::new(true));
    let counter = Arc::new(AtomicUsize::new(61));
    let interval = Arc::new(AtomicUsize::new(61));

    {
        let flag = flag.clone();
        let counter = counter.clone();
        rt.spawn(async move { iter(flag, counter, true) });
    }

    // Now, hammer the injection queue until the interval drops.
    //
    // Each probe task spawned below swaps `counter` to zero and publishes the
    // previous value as `interval`, i.e. the number of `iter` steps that ran
    // between two consecutive probes.
    let mut n = 0;
    loop {
        let curr = interval.load(Relaxed);

        if curr <= 8 {
            n += 1;
        } else {
            n = 0;
        }

        // Make sure we get a few good rounds. Jitter in the tuning could result
        // in one "good" value without being representative of reaching a good
        // state.
        if n == 3 {
            break;
        }

        // Every pending probe holds a clone of `interval`, so its strong count
        // bounds the number of outstanding probes.
        if Arc::strong_count(&interval) < 5_000 {
            let counter = counter.clone();
            let interval = interval.clone();

            rt.spawn(async move {
                let prev = counter.swap(0, Relaxed);
                interval.store(prev, Relaxed);
            });
        }

        // Yield on every iteration, including when the probe cap is reached,
        // so the main thread does not busy-spin while probes are pending. This
        // matches the second loop below.
        std::thread::yield_now();
    }

    flag.store(false, Relaxed);

    // Wait until every outstanding probe has run and released its clone of
    // `interval`.
    let w = Arc::downgrade(&interval);
    drop(interval);

    while w.strong_count() > 0 {
        std::thread::sleep(Duration::from_micros(500));
    }

    // Now, run it again with a faster task
    let flag = Arc::new(AtomicBool::new(true));
    // Set it high, we know it shouldn't ever really be this high
    let counter = Arc::new(AtomicUsize::new(10_000));
    let interval = Arc::new(AtomicUsize::new(10_000));

    {
        let flag = flag.clone();
        let counter = counter.clone();
        rt.spawn(async move { iter(flag, counter, false) });
    }

    // Now, hammer the injection queue until the interval reaches the expected range.
    let mut n = 0;
    loop {
        let curr = interval.load(Relaxed);

        if curr <= 1_000 && curr > 32 {
            n += 1;
        } else {
            n = 0;
        }

        if n == 3 {
            break;
        }

        if Arc::strong_count(&interval) <= 5_000 {
            let counter = counter.clone();
            let interval = interval.clone();

            rt.spawn(async move {
                let prev = counter.swap(0, Relaxed);
                interval.store(prev, Relaxed);
            });
        }

        std::thread::yield_now();
    }

    flag.store(false, Relaxed);
}
711 
rt() -> runtime::Runtime712 fn rt() -> runtime::Runtime {
713     runtime::Builder::new_multi_thread_alt()
714         .enable_all()
715         .build()
716         .unwrap()
717 }
718