• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //     http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 use std::collections::VecDeque;
15 use std::future::Future;
16 use std::mem;
17 use std::pin::Pin;
18 #[cfg(feature = "metrics")]
19 use std::sync::atomic::AtomicU64;
20 use std::sync::atomic::AtomicUsize;
21 use std::sync::atomic::Ordering::{AcqRel, Acquire};
22 use std::sync::{Arc, Condvar, Mutex};
23 use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
24 
25 use crate::executor::driver::{Driver, Handle, ParkFlag};
26 use crate::executor::Schedule;
27 use crate::task::{JoinHandle, Task, TaskBuilder, VirtualTableType};
28 
/// The parker is neither parked nor holding a pending notification.
const IDLE: usize = 0;
/// The parker is suspended on its condition variable.
const PARKED_ON_CONDVAR: usize = 1;
/// The parker is suspended inside the driver.
const PARKED_ON_DRIVER: usize = 2;
/// The parker was notified because a task was scheduled or spawned.
const NOTIFIED: usize = 3;
/// The parker was notified by the waker of the blocked-on task.
const NOTIFIED_BLOCK: usize = 4;
39 
40 pub(crate) struct CurrentThreadSpawner {
41     pub(crate) scheduler: Arc<CurrentThreadScheduler>,
42     pub(crate) driver: Arc<Mutex<Driver>>,
43     pub(crate) handle: Arc<Handle>,
44 }
45 
46 #[derive(Default)]
47 pub(crate) struct CurrentThreadScheduler {
48     pub(crate) inner: Mutex<VecDeque<Task>>,
49     pub(crate) parker_list: Mutex<Vec<Arc<Parker>>>,
50     /// Total task count
51     #[cfg(feature = "metrics")]
52     pub(crate) count: AtomicU64,
53 }
54 
55 unsafe impl Sync for CurrentThreadScheduler {}
56 
57 impl Schedule for CurrentThreadScheduler {
58     #[inline]
schedule(&self, task: Task, _lifo: bool)59     fn schedule(&self, task: Task, _lifo: bool) {
60         let mut queue = self.inner.lock().unwrap();
61         #[cfg(feature = "metrics")]
62         self.count.fetch_add(1, AcqRel);
63         queue.push_back(task);
64 
65         let parker_list = self.parker_list.lock().unwrap();
66         for parker in &*parker_list {
67             parker.unpark(false);
68         }
69     }
70 }
71 
72 impl CurrentThreadScheduler {
pop(&self) -> Option<Task>73     fn pop(&self) -> Option<Task> {
74         let mut queue = self.inner.lock().unwrap();
75         queue.pop_front()
76     }
77 }
78 
79 pub(crate) struct Parker {
80     state: AtomicUsize,
81     mutex: Mutex<bool>,
82     condvar: Condvar,
83     driver: Arc<Mutex<Driver>>,
84     handle: Arc<Handle>,
85 }
86 
87 impl Parker {
new(driver: Arc<Mutex<Driver>>, handle: Arc<Handle>) -> Parker88     fn new(driver: Arc<Mutex<Driver>>, handle: Arc<Handle>) -> Parker {
89         Parker {
90             state: AtomicUsize::new(IDLE),
91             mutex: Mutex::new(false),
92             condvar: Condvar::new(),
93             driver,
94             handle,
95         }
96     }
97 
park(&self) -> bool98     fn park(&self) -> bool {
99         let (mut park, mut wake) = (true, false);
100         if let Ok(mut driver) = self.driver.try_lock() {
101             (park, wake) = self.park_on_driver(&mut driver);
102         }
103         if park {
104             self.park_on_condvar()
105         } else {
106             wake
107         }
108     }
109 
park_on_driver(&self, driver: &mut Driver) -> (bool, bool)110     fn park_on_driver(&self, driver: &mut Driver) -> (bool, bool) {
111         match self
112             .state
113             .compare_exchange(IDLE, PARKED_ON_DRIVER, AcqRel, Acquire)
114         {
115             Ok(_) => {}
116             Err(NOTIFIED_BLOCK) | Err(NOTIFIED) => {
117                 return match self.state.swap(IDLE, AcqRel) {
118                     // No need to park on condvar, need to awaken the blocked task.
119                     NOTIFIED_BLOCK => (false, true),
120                     // No need to park on condvar, no need to awaken the blocked task.
121                     NOTIFIED => (false, false),
122                     actual => panic!("invalid park state when notifying; actual = {actual}"),
123                 };
124             }
125             Err(actual) => panic!("inconsistent park state; actual = {actual}"),
126         }
127 
128         let park = match driver.run() {
129             ParkFlag::NotPark => false,
130             ParkFlag::Park => true,
131             ParkFlag::ParkTimeout(_) => false,
132         };
133 
134         match self.state.swap(IDLE, AcqRel) {
135             NOTIFIED => (false, false),
136             NOTIFIED_BLOCK => (false, true),
137             PARKED_ON_DRIVER => (park, false),
138             n => panic!("inconsistent park_timeout state: {n}"),
139         }
140     }
141 
park_on_condvar(&self) -> bool142     fn park_on_condvar(&self) -> bool {
143         let mut lock = self.mutex.lock().unwrap();
144         match self
145             .state
146             .compare_exchange(IDLE, PARKED_ON_CONDVAR, AcqRel, Acquire)
147         {
148             Ok(_) => {}
149             Err(NOTIFIED_BLOCK) | Err(NOTIFIED) => {
150                 return match self.state.swap(IDLE, AcqRel) {
151                     // Need to awaken the blocked task.
152                     NOTIFIED_BLOCK => true,
153                     // No need to awaken the blocked task.
154                     NOTIFIED => false,
155                     actual => panic!("invalid park state when notifying; actual = {actual}"),
156                 };
157             }
158             Err(actual) => panic!("inconsistent park state; actual = {actual}"),
159         }
160 
161         while !*lock {
162             lock = self.condvar.wait(lock).unwrap();
163         }
164         *lock = false;
165 
166         match self.state.swap(IDLE, AcqRel) {
167             NOTIFIED => false,
168             NOTIFIED_BLOCK => true,
169             n => panic!("inconsistent park_timeout state: {n}"),
170         }
171     }
172 
unpark(&self, wake: bool)173     fn unpark(&self, wake: bool) {
174         if wake {
175             match self.state.swap(NOTIFIED_BLOCK, AcqRel) {
176                 IDLE | NOTIFIED | NOTIFIED_BLOCK => {}
177                 PARKED_ON_CONDVAR => {
178                     let mut lock = self.mutex.lock().unwrap();
179                     *lock = true;
180                     mem::drop(lock);
181                     self.condvar.notify_one();
182                 }
183                 PARKED_ON_DRIVER => self.handle.wake(),
184                 actual => panic!("inconsistent state in unpark; actual = {actual}"),
185             }
186         } else {
187             match self.state.swap(NOTIFIED, AcqRel) {
188                 IDLE | NOTIFIED => {}
189                 NOTIFIED_BLOCK => self.unpark(true),
190                 PARKED_ON_CONDVAR => {
191                     let mut lock = self.mutex.lock().unwrap();
192                     *lock = true;
193                     mem::drop(lock);
194                     self.condvar.notify_one();
195                 }
196                 PARKED_ON_DRIVER => self.handle.wake(),
197                 actual => panic!("inconsistent state in unpark; actual = {actual}"),
198             }
199         }
200     }
201 }
202 
waker(parker: Arc<Parker>) -> Waker203 fn waker(parker: Arc<Parker>) -> Waker {
204     let data = Arc::into_raw(parker).cast::<()>();
205     unsafe { Waker::from_raw(RawWaker::new(data, &CURRENT_THREAD_RAW_WAKER_VIRTUAL_TABLE)) }
206 }
207 
208 static CURRENT_THREAD_RAW_WAKER_VIRTUAL_TABLE: RawWakerVTable =
209     RawWakerVTable::new(clone, wake, wake_by_ref, drop);
210 
clone(ptr: *const ()) -> RawWaker211 fn clone(ptr: *const ()) -> RawWaker {
212     let parker = unsafe { Arc::from_raw(ptr.cast::<Parker>()) };
213 
214     // increment the ref count
215     mem::forget(parker.clone());
216 
217     let data = Arc::into_raw(parker).cast::<()>();
218     RawWaker::new(data, &CURRENT_THREAD_RAW_WAKER_VIRTUAL_TABLE)
219 }
220 
wake(ptr: *const ())221 fn wake(ptr: *const ()) {
222     let parker = unsafe { Arc::from_raw(ptr.cast::<Parker>()) };
223     parker.unpark(true);
224 }
225 
wake_by_ref(ptr: *const ())226 fn wake_by_ref(ptr: *const ()) {
227     let parker = unsafe { Arc::from_raw(ptr.cast::<Parker>()) };
228     parker.unpark(true);
229     mem::forget(parker);
230 }
231 
drop(ptr: *const ())232 fn drop(ptr: *const ()) {
233     unsafe { mem::drop(Arc::from_raw(ptr.cast::<Parker>())) };
234 }
235 
236 impl CurrentThreadSpawner {
new() -> Self237     pub(crate) fn new() -> Self {
238         let (handle, driver) = Driver::initialize();
239         Self {
240             scheduler: Default::default(),
241             driver,
242             handle,
243         }
244     }
245 
get_parker(&self) -> Parker246     fn get_parker(&self) -> Parker {
247         Parker::new(self.driver.clone(), self.handle.clone())
248     }
249 
spawn<T>(&self, builder: &TaskBuilder, task: T) -> JoinHandle<T::Output> where T: Future + Send + 'static, T::Output: Send + 'static,250     pub(crate) fn spawn<T>(&self, builder: &TaskBuilder, task: T) -> JoinHandle<T::Output>
251     where
252         T: Future + Send + 'static,
253         T::Output: Send + 'static,
254     {
255         let scheduler = Arc::downgrade(&self.scheduler);
256         let (task, handle) = Task::create_task(builder, scheduler, task, VirtualTableType::Ylong);
257 
258         let mut queue = self.scheduler.inner.lock().unwrap();
259         queue.push_back(task);
260         #[cfg(feature = "metrics")]
261         self.scheduler.count.fetch_add(1, AcqRel);
262 
263         let parker_list = self.scheduler.parker_list.lock().unwrap();
264         for parker in &*parker_list {
265             parker.unpark(false);
266         }
267         handle
268     }
269 
block_on<T>(&self, future: T) -> T::Output where T: Future,270     pub(crate) fn block_on<T>(&self, future: T) -> T::Output
271     where
272         T: Future,
273     {
274         let parker = Arc::new(self.get_parker());
275         let mut parker_list = self.scheduler.parker_list.lock().unwrap();
276         parker_list.push(parker.clone());
277         mem::drop(parker_list);
278 
279         let waker = waker(parker.clone());
280         let mut cx = Context::from_waker(&waker);
281 
282         let mut future = future;
283         let mut future = unsafe { Pin::new_unchecked(&mut future) };
284 
285         let mut wake = true;
286 
287         loop {
288             if wake {
289                 if let Poll::Ready(res) = future.as_mut().poll(&mut cx) {
290                     return res;
291                 }
292             }
293 
294             while let Some(task) = self.scheduler.pop() {
295                 task.run();
296             }
297 
298             wake = parker.park();
299         }
300     }
301 }
302 
#[cfg(test)]
mod test {
    /// Compiles the wrapped items only when the `sync` feature is enabled.
    macro_rules! cfg_sync {
        ($($item:item)*) => {
            $(
                #[cfg(feature = "sync")]
                $item
            )*
        }
    }

    use crate::executor::current_thread::CurrentThreadSpawner;
    use crate::task::{yield_now, TaskBuilder};

    cfg_sync! {
        use crate::sync::Waiter;
        use std::sync::atomic::AtomicUsize;
        use std::sync::atomic::Ordering::{Acquire, Release};
        use std::sync::{Condvar, Mutex};
        use std::sync::Arc;

        /// One-shot notification used by the tests to wait for a spawned task.
        pub(crate) struct Parker {
            mutex: Mutex<bool>,
            condvar: Condvar,
        }

        impl Parker {
            fn new() -> Parker {
                Parker {
                    condvar: Condvar::new(),
                    mutex: Mutex::new(false),
                }
            }

            /// Blocks until `notify_one` has been called, then resets the flag.
            fn notified(&self) {
                let pending = self.mutex.lock().unwrap();
                let mut raised = self
                    .condvar
                    .wait_while(pending, |flag| !*flag)
                    .unwrap();
                *raised = false;
            }

            /// Raises the flag and wakes one waiting thread.
            fn notify_one(&self) {
                *self.mutex.lock().unwrap() = true;
                self.condvar.notify_one();
            }
        }
    }

    cfg_net! {
        use std::net::SocketAddr;
        use crate::net::{TcpListener, TcpStream};
        use crate::io::{AsyncReadExt, AsyncWriteExt};

        const ADDR: &str = "127.0.0.1:0";

        // Accepts one connection and performs three read/write exchanges on it.
        pub async fn ylong_tcp_server(tx: crate::sync::oneshot::Sender<SocketAddr>) {
            let listener = TcpListener::bind(ADDR).await.unwrap();
            // Publish the bound address so that the client knows where to connect.
            tx.send(listener.local_addr().unwrap()).unwrap();

            let (mut stream, _) = listener.accept().await.unwrap();
            for _ in 0..3 {
                let mut request = [0; 100];
                stream.read_exact(&mut request).await.unwrap();
                assert_eq!(request, [3; 100]);

                let response = [2; 100];
                stream.write_all(&response).await.unwrap();
            }
        }

        // Connects to the published address and performs three write/read exchanges.
        pub async fn ylong_tcp_client(rx: crate::sync::oneshot::Receiver<SocketAddr>) {
            let addr = rx.await.unwrap();
            // Keep retrying until the connection is established.
            let mut stream = loop {
                if let Ok(stream) = TcpStream::connect(addr).await {
                    break stream;
                }
            };

            for _ in 0..3 {
                let request = [3; 100];
                stream.write_all(&request).await.unwrap();

                let mut response = [0; 100];
                stream.read_exact(&mut response).await.unwrap();
                assert_eq!(response, [2; 100]);
            }
        }
    }

    /// UT test cases for `block_on()`.
    ///
    /// # Brief
    /// 1. Spawn two tasks, then block on a yielding task and check that the
    ///    run queue has been drained and both tasks produced their results.
    #[test]
    fn ut_current_thread_block_on() {
        let spawner = CurrentThreadSpawner::new();
        let builder = TaskBuilder::default();
        let first = spawner.spawn(&builder, async move { 1 });
        let second = spawner.spawn(&builder, async move { 1 });

        spawner.block_on(yield_now());
        assert_eq!(spawner.scheduler.inner.lock().unwrap().len(), 0);

        assert_eq!(spawner.block_on(first).unwrap(), 1);
        assert_eq!(spawner.block_on(second).unwrap(), 1);
    }

    /// UT test cases for `spawn()` and `block_on()`.
    ///
    /// # Brief
    /// 1. Spawn two tasks before the blocked task starts running and check
    ///    that both have finished.
    /// 2. Spawn two tasks after the blocked task started running and check
    ///    that both have finished.
    #[test]
    #[cfg(feature = "sync")]
    fn ut_current_thread_run_queue() {
        use crate::builder::RuntimeBuilder;
        let spawner = Arc::new(RuntimeBuilder::new_current_thread().build().unwrap());

        let finished = Arc::new(AtomicUsize::new(0));

        // Spawns a task that bumps `finished` and then signals the returned parker.
        let spawn_counting_task = || {
            let counter = finished.clone();
            let done = Arc::new(Parker::new());
            let done_in_task = done.clone();
            spawner.spawn(async move {
                counter.fetch_add(1, Release);
                done_in_task.notify_one();
            });
            done
        };

        let first_done = spawn_counting_task();
        let second_done = spawn_counting_task();

        let waiter = Arc::new(Waiter::new());
        let blocked_waiter = waiter.clone();
        let blocked_spawner = spawner.clone();
        let join = std::thread::spawn(move || {
            blocked_spawner.block_on(async move { blocked_waiter.wait().await })
        });

        first_done.notified();
        second_done.notified();
        assert_eq!(finished.load(Acquire), 2);

        let third_done = spawn_counting_task();
        let fourth_done = spawn_counting_task();

        third_done.notified();
        fourth_done.notified();
        assert_eq!(finished.load(Acquire), 4);

        waiter.wake_one();
        join.join().unwrap();

        #[cfg(feature = "net")]
        crate::executor::worker::CURRENT_WORKER.with(|ctx| {
            ctx.set(std::ptr::null());
        });
    }

    /// UT test cases for io tasks.
    ///
    /// # Brief
    /// 1. Spawns a tcp server to read and write data for three times.
    /// 2. Spawns a tcp client to read and write data for three times.
    #[test]
    #[cfg(feature = "net")]
    fn ut_current_thread_io() {
        use crate::builder::RuntimeBuilder;

        // Server spawned as a task, client driven by `block_on`.
        let spawner = RuntimeBuilder::new_current_thread().build().unwrap();
        let (tx, rx) = crate::sync::oneshot::channel();
        let server = spawner.spawn(ylong_tcp_server(tx));

        spawner.block_on(ylong_tcp_client(rx));
        spawner.block_on(server).unwrap();

        // Roles swapped: client spawned as a task, server driven by `block_on`.
        let spawner = RuntimeBuilder::new_current_thread().build().unwrap();
        let (tx, rx) = crate::sync::oneshot::channel();
        let client = spawner.spawn(ylong_tcp_client(rx));
        spawner.block_on(ylong_tcp_server(tx));
        spawner.block_on(client).unwrap();
    }
}
506