1 // Copyright (c) 2023 Huawei Device Co., Ltd.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13
14 use std::collections::VecDeque;
15 use std::future::Future;
16 use std::mem;
17 use std::pin::Pin;
18 #[cfg(feature = "metrics")]
19 use std::sync::atomic::AtomicU64;
20 use std::sync::atomic::AtomicUsize;
21 use std::sync::atomic::Ordering::{AcqRel, Acquire};
22 use std::sync::{Arc, Condvar, Mutex};
23 use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
24
25 use crate::executor::driver::{Driver, Handle, ParkFlag};
26 use crate::executor::Schedule;
27 use crate::task::{JoinHandle, Task, TaskBuilder, VirtualTableType};
28
// States of `Parker::state`.

// Idle: not parked and no notification pending.
const IDLE: usize = 0;
// Suspended on the parker's condvar.
const PARKED_ON_CONDVAR: usize = 1;
// Running the shared driver (`driver.run()`); `unpark` interrupts it through
// the driver `Handle`.
const PARKED_ON_DRIVER: usize = 2;
// Notified by `unpark(false)`, which `schedule`/`spawn` call after pushing a
// task onto the run queue; the future passed to `block_on` is not re-polled.
const NOTIFIED: usize = 3;
// Notified by `unpark(true)`, i.e. the waker of the future passed to
// `block_on` fired; that future must be polled again. Never downgraded to
// `NOTIFIED` by a later `unpark(false)`.
const NOTIFIED_BLOCK: usize = 4;
39
/// Spawner of the current-thread executor.
///
/// Owns the shared run queue and the driver that threads blocked in
/// `block_on` run whenever they can acquire its lock.
pub(crate) struct CurrentThreadSpawner {
    /// Run queue plus the parkers of threads currently blocked in `block_on`.
    pub(crate) scheduler: Arc<CurrentThreadScheduler>,
    /// Driver shared by every parker; only the thread holding the lock runs it.
    pub(crate) driver: Arc<Mutex<Driver>>,
    /// Handle used to wake a thread that is parked on `driver`.
    pub(crate) handle: Arc<Handle>,
}
45
/// FIFO run queue shared by every thread blocked in `block_on`.
#[derive(Default)]
pub(crate) struct CurrentThreadScheduler {
    /// Tasks waiting to run; pushed at the back, popped from the front.
    pub(crate) inner: Mutex<VecDeque<Task>>,
    /// Parkers to notify whenever a task is pushed onto `inner`.
    pub(crate) parker_list: Mutex<Vec<Arc<Parker>>>,
    /// Number of pushes onto the run queue (incremented on every
    /// `schedule`/`spawn`, so rescheduled tasks are counted again).
    #[cfg(feature = "metrics")]
    pub(crate) count: AtomicU64,
}

// NOTE(review): `Sync` is asserted manually, presumably because `Task` is not
// `Send`/`Sync` on its own. Every field is only accessed through a `Mutex` or
// an atomic; confirm that handing a `Task` to another thread is sound.
unsafe impl Sync for CurrentThreadScheduler {}
56
57 impl Schedule for CurrentThreadScheduler {
58 #[inline]
schedule(&self, task: Task, _lifo: bool)59 fn schedule(&self, task: Task, _lifo: bool) {
60 let mut queue = self.inner.lock().unwrap();
61 #[cfg(feature = "metrics")]
62 self.count.fetch_add(1, AcqRel);
63 queue.push_back(task);
64
65 let parker_list = self.parker_list.lock().unwrap();
66 for parker in &*parker_list {
67 parker.unpark(false);
68 }
69 }
70 }
71
72 impl CurrentThreadScheduler {
pop(&self) -> Option<Task>73 fn pop(&self) -> Option<Task> {
74 let mut queue = self.inner.lock().unwrap();
75 queue.pop_front()
76 }
77 }
78
/// Parking primitive of one `block_on` call.
///
/// A parked thread either runs the shared driver or waits on `condvar`;
/// `state` holds one of the `IDLE` / `PARKED_*` / `NOTIFIED*` constants.
pub(crate) struct Parker {
    /// Current parking state (one of the state constants).
    state: AtomicUsize,
    /// "Signalled" flag guarding the condvar wait against spurious wake-ups.
    mutex: Mutex<bool>,
    /// Signalled by `unpark` when the thread is parked on the condvar.
    condvar: Condvar,
    /// Driver shared with the spawner; run by whichever parker can lock it.
    driver: Arc<Mutex<Driver>>,
    /// Used by `unpark` to wake a thread that is parked on the driver.
    handle: Arc<Handle>,
}
86
impl Parker {
    /// Creates a parker in the `IDLE` state that shares `driver` and `handle`
    /// with the spawner.
    fn new(driver: Arc<Mutex<Driver>>, handle: Arc<Handle>) -> Parker {
        Parker {
            state: AtomicUsize::new(IDLE),
            mutex: Mutex::new(false),
            condvar: Condvar::new(),
            driver,
            handle,
        }
    }

    /// Suspends the calling thread until a notification arrives.
    ///
    /// If the shared driver can be locked without blocking, the thread runs it
    /// first (`park_on_driver`). It waits on the condvar instead when the lock
    /// is unavailable, or when the driver returned `ParkFlag::Park` and no
    /// notification arrived meanwhile.
    ///
    /// Returns `true` when an `unpark(true)` notification was consumed, i.e.
    /// the future passed to `block_on` should be polled again.
    fn park(&self) -> bool {
        // Defaults used when another thread currently owns the driver.
        let (mut park, mut wake) = (true, false);
        if let Ok(mut driver) = self.driver.try_lock() {
            (park, wake) = self.park_on_driver(&mut driver);
        }
        // The driver guard is released by now, so other parkers may take it.
        if park {
            self.park_on_condvar()
        } else {
            wake
        }
    }

    /// Calls `driver.run()` with the state set to `PARKED_ON_DRIVER`.
    ///
    /// Returns `(park, wake)`:
    /// - `park` is `true` only if no notification arrived and the driver
    ///   returned `ParkFlag::Park`; the caller then waits on the condvar.
    /// - `wake` is `true` if an `unpark(true)` notification was consumed.
    ///
    /// A notification that is already pending is consumed without running
    /// the driver.
    fn park_on_driver(&self, driver: &mut Driver) -> (bool, bool) {
        match self
            .state
            .compare_exchange(IDLE, PARKED_ON_DRIVER, AcqRel, Acquire)
        {
            Ok(_) => {}
            Err(NOTIFIED_BLOCK) | Err(NOTIFIED) => {
                // Swap rather than store: a concurrent `unpark(true)` may have
                // turned NOTIFIED into NOTIFIED_BLOCK since the CAS failed.
                return match self.state.swap(IDLE, AcqRel) {
                    // No need to park on condvar, need to awaken the blocked task.
                    NOTIFIED_BLOCK => (false, true),
                    // No need to park on condvar, no need to awaken the blocked task.
                    NOTIFIED => (false, false),
                    actual => panic!("invalid park state when notifying; actual = {actual}"),
                };
            }
            Err(actual) => panic!("inconsistent park state; actual = {actual}"),
        }

        // NOTE(review): presumably `run` may block; `unpark` interrupts it
        // through `self.handle.wake()` when the state is PARKED_ON_DRIVER.
        let park = match driver.run() {
            ParkFlag::NotPark => false,
            ParkFlag::Park => true,
            ParkFlag::ParkTimeout(_) => false,
        };

        // Reset to IDLE and check whether a notification arrived while the
        // driver was running; a notification always overrides `park`.
        match self.state.swap(IDLE, AcqRel) {
            NOTIFIED => (false, false),
            NOTIFIED_BLOCK => (false, true),
            PARKED_ON_DRIVER => (park, false),
            n => panic!("inconsistent park_timeout state: {n}"),
        }
    }

    /// Blocks on the condvar until an `unpark` call sets the mutex flag.
    ///
    /// Returns `true` if the consumed notification was NOTIFIED_BLOCK.
    fn park_on_condvar(&self) -> bool {
        // The transition to PARKED_ON_CONDVAR happens while holding `mutex`.
        // An `unpark` that observes that state must take the same mutex
        // before signalling, so the signal cannot be missed.
        let mut lock = self.mutex.lock().unwrap();
        match self
            .state
            .compare_exchange(IDLE, PARKED_ON_CONDVAR, AcqRel, Acquire)
        {
            Ok(_) => {}
            Err(NOTIFIED_BLOCK) | Err(NOTIFIED) => {
                // Swap (not store) so a concurrent upgrade to NOTIFIED_BLOCK is kept.
                return match self.state.swap(IDLE, AcqRel) {
                    // Need to awaken the blocked task.
                    NOTIFIED_BLOCK => true,
                    // No need to awaken the blocked task.
                    NOTIFIED => false,
                    actual => panic!("invalid park state when notifying; actual = {actual}"),
                };
            }
            Err(actual) => panic!("inconsistent park state; actual = {actual}"),
        }

        // Loop to absorb spurious wake-ups: only `unpark` sets the flag to true.
        while !*lock {
            lock = self.condvar.wait(lock).unwrap();
        }
        *lock = false;

        // `unpark` sets the NOTIFIED* state before setting the flag, so one of
        // those two values must be observed here.
        match self.state.swap(IDLE, AcqRel) {
            NOTIFIED => false,
            NOTIFIED_BLOCK => true,
            n => panic!("inconsistent park_timeout state: {n}"),
        }
    }

    /// Records a notification and wakes the owning thread if it is parked.
    ///
    /// `wake == true` (from the `block_on` waker) records NOTIFIED_BLOCK.
    /// `wake == false` (from `schedule`/`spawn`) records NOTIFIED. A pending
    /// NOTIFIED_BLOCK is never lost: if `unpark(false)` overwrote it, the
    /// nested `unpark(true)` restores it.
    fn unpark(&self, wake: bool) {
        if wake {
            match self.state.swap(NOTIFIED_BLOCK, AcqRel) {
                // Not parked: the owner will see the state on its next park.
                IDLE | NOTIFIED | NOTIFIED_BLOCK => {}
                PARKED_ON_CONDVAR => {
                    let mut lock = self.mutex.lock().unwrap();
                    *lock = true;
                    mem::drop(lock);
                    self.condvar.notify_one();
                }
                PARKED_ON_DRIVER => self.handle.wake(),
                actual => panic!("inconsistent state in unpark; actual = {actual}"),
            }
        } else {
            match self.state.swap(NOTIFIED, AcqRel) {
                IDLE | NOTIFIED => {}
                // We just downgraded a pending NOTIFIED_BLOCK; put it back.
                NOTIFIED_BLOCK => self.unpark(true),
                PARKED_ON_CONDVAR => {
                    let mut lock = self.mutex.lock().unwrap();
                    *lock = true;
                    mem::drop(lock);
                    self.condvar.notify_one();
                }
                PARKED_ON_DRIVER => self.handle.wake(),
                actual => panic!("inconsistent state in unpark; actual = {actual}"),
            }
        }
    }
}
202
waker(parker: Arc<Parker>) -> Waker203 fn waker(parker: Arc<Parker>) -> Waker {
204 let data = Arc::into_raw(parker).cast::<()>();
205 unsafe { Waker::from_raw(RawWaker::new(data, &CURRENT_THREAD_RAW_WAKER_VIRTUAL_TABLE)) }
206 }
207
// Vtable for wakers created by `waker`. Every entry treats the data pointer as
// an `Arc<Parker>` obtained from `Arc::into_raw`. Note that `drop` here is this
// module's vtable function, not `std::mem::drop`.
static CURRENT_THREAD_RAW_WAKER_VIRTUAL_TABLE: RawWakerVTable =
    RawWakerVTable::new(clone, wake, wake_by_ref, drop);
210
clone(ptr: *const ()) -> RawWaker211 fn clone(ptr: *const ()) -> RawWaker {
212 let parker = unsafe { Arc::from_raw(ptr.cast::<Parker>()) };
213
214 // increment the ref count
215 mem::forget(parker.clone());
216
217 let data = Arc::into_raw(parker).cast::<()>();
218 RawWaker::new(data, &CURRENT_THREAD_RAW_WAKER_VIRTUAL_TABLE)
219 }
220
wake(ptr: *const ())221 fn wake(ptr: *const ()) {
222 let parker = unsafe { Arc::from_raw(ptr.cast::<Parker>()) };
223 parker.unpark(true);
224 }
225
wake_by_ref(ptr: *const ())226 fn wake_by_ref(ptr: *const ()) {
227 let parker = unsafe { Arc::from_raw(ptr.cast::<Parker>()) };
228 parker.unpark(true);
229 mem::forget(parker);
230 }
231
drop(ptr: *const ())232 fn drop(ptr: *const ()) {
233 unsafe { mem::drop(Arc::from_raw(ptr.cast::<Parker>())) };
234 }
235
236 impl CurrentThreadSpawner {
new() -> Self237 pub(crate) fn new() -> Self {
238 let (handle, driver) = Driver::initialize();
239 Self {
240 scheduler: Default::default(),
241 driver,
242 handle,
243 }
244 }
245
get_parker(&self) -> Parker246 fn get_parker(&self) -> Parker {
247 Parker::new(self.driver.clone(), self.handle.clone())
248 }
249
spawn<T>(&self, builder: &TaskBuilder, task: T) -> JoinHandle<T::Output> where T: Future + Send + 'static, T::Output: Send + 'static,250 pub(crate) fn spawn<T>(&self, builder: &TaskBuilder, task: T) -> JoinHandle<T::Output>
251 where
252 T: Future + Send + 'static,
253 T::Output: Send + 'static,
254 {
255 let scheduler = Arc::downgrade(&self.scheduler);
256 let (task, handle) = Task::create_task(builder, scheduler, task, VirtualTableType::Ylong);
257
258 let mut queue = self.scheduler.inner.lock().unwrap();
259 queue.push_back(task);
260 #[cfg(feature = "metrics")]
261 self.scheduler.count.fetch_add(1, AcqRel);
262
263 let parker_list = self.scheduler.parker_list.lock().unwrap();
264 for parker in &*parker_list {
265 parker.unpark(false);
266 }
267 handle
268 }
269
block_on<T>(&self, future: T) -> T::Output where T: Future,270 pub(crate) fn block_on<T>(&self, future: T) -> T::Output
271 where
272 T: Future,
273 {
274 let parker = Arc::new(self.get_parker());
275 let mut parker_list = self.scheduler.parker_list.lock().unwrap();
276 parker_list.push(parker.clone());
277 mem::drop(parker_list);
278
279 let waker = waker(parker.clone());
280 let mut cx = Context::from_waker(&waker);
281
282 let mut future = future;
283 let mut future = unsafe { Pin::new_unchecked(&mut future) };
284
285 let mut wake = true;
286
287 loop {
288 if wake {
289 if let Poll::Ready(res) = future.as_mut().poll(&mut cx) {
290 return res;
291 }
292 }
293
294 while let Some(task) = self.scheduler.pop() {
295 task.run();
296 }
297
298 wake = parker.park();
299 }
300 }
301 }
302
#[cfg(test)]
mod test {
    // Local helper: gates every wrapped item on the `sync` feature.
    macro_rules! cfg_sync {
        ($($item:item)*) => {
            $(
                #[cfg(feature = "sync")]
                $item
            )*
        }
    }

    use crate::executor::current_thread::CurrentThreadSpawner;
    use crate::task::{yield_now, TaskBuilder};

    cfg_sync! {
        use crate::sync::Waiter;
        use std::sync::atomic::AtomicUsize;
        use std::sync::atomic::Ordering::{Acquire, Release};
        use std::sync::{Condvar, Mutex};
        use std::sync::Arc;

        // Test-only one-shot signal: `notify_one` sets a flag, and `notified`
        // blocks until the flag is set and then clears it.
        pub(crate) struct Parker {
            mutex: Mutex<bool>,
            condvar: Condvar,
        }

        impl Parker {
            fn new() -> Parker {
                Parker {
                    mutex: Mutex::new(false),
                    condvar: Condvar::new(),
                }
            }

            // Blocks until `notify_one` has been called, then resets the flag.
            fn notified(&self) {
                let mut guard = self.mutex.lock().unwrap();

                while !*guard {
                    guard = self.condvar.wait(guard).unwrap();
                }
                *guard = false;
            }

            // Sets the flag and wakes one thread waiting in `notified`.
            fn notify_one(&self) {
                let mut guard = self.mutex.lock().unwrap();
                *guard = true;
                drop(guard);
                self.condvar.notify_one();
            }
        }
    }

    cfg_net! {
        use std::net::SocketAddr;
        use crate::net::{TcpListener, TcpStream};
        use crate::io::{AsyncReadExt, AsyncWriteExt};

        // Port 0: let the OS pick a free port; the actual address is sent to the client.
        const ADDR: &str = "127.0.0.1:0";

        // Binds, sends the bound address through `tx`, accepts one connection,
        // then three times reads 100 bytes of 3s and answers 100 bytes of 2s.
        pub async fn ylong_tcp_server(tx: crate::sync::oneshot::Sender<SocketAddr>) {
            let tcp = TcpListener::bind(ADDR).await.unwrap();
            let addr = tcp.local_addr().unwrap();
            tx.send(addr).unwrap();
            let (mut stream, _) = tcp.accept().await.unwrap();
            for _ in 0..3 {
                let mut buf = [0; 100];
                stream.read_exact(&mut buf).await.unwrap();
                assert_eq!(buf, [3; 100]);

                let buf = [2; 100];
                stream.write_all(&buf).await.unwrap();
            }
        }

        // Connects to the address received on `rx` (retrying until connect
        // succeeds), then mirrors the server's three exchanges.
        pub async fn ylong_tcp_client(rx: crate::sync::oneshot::Receiver<SocketAddr>) {
            let addr = rx.await.unwrap();
            let mut tcp = TcpStream::connect(addr).await;
            while tcp.is_err() {
                tcp = TcpStream::connect(addr).await;
            }
            let mut tcp = tcp.unwrap();
            for _ in 0..3 {
                let buf = [3; 100];
                tcp.write_all(&buf).await.unwrap();

                let mut buf = [0; 100];
                tcp.read_exact(&mut buf).await.unwrap();
                assert_eq!(buf, [2; 100]);
            }
        }
    }

    /// UT test cases for `block_on()`.
    ///
    /// # Brief
    /// 1. Spawn two tasks, then block on a yielding future and check that the
    ///    run queue was drained while the future was pending.
    /// 2. Block on each task's join handle and check its output.
    #[test]
    fn ut_current_thread_block_on() {
        let spawner = CurrentThreadSpawner::new();
        let handle1 = spawner.spawn(&TaskBuilder::default(), async move { 1 });
        let handle2 = spawner.spawn(&TaskBuilder::default(), async move { 1 });
        spawner.block_on(yield_now());
        assert_eq!(spawner.scheduler.inner.lock().unwrap().len(), 0);
        assert_eq!(spawner.block_on(handle1).unwrap(), 1);
        assert_eq!(spawner.block_on(handle2).unwrap(), 1);
    }

    /// UT test cases for `spawn()` and `block_on()`.
    ///
    /// # Brief
    /// 1. Spawn two tasks before another thread starts `block_on` and check
    ///    that both run.
    /// 2. Spawn two more tasks while that thread is blocked and check that
    ///    both run as well.
    #[test]
    #[cfg(feature = "sync")]
    fn ut_current_thread_run_queue() {
        use crate::builder::RuntimeBuilder;
        let spawner = Arc::new(RuntimeBuilder::new_current_thread().build().unwrap());

        let finished = Arc::new(AtomicUsize::new(0));

        let finished_clone = finished.clone();
        let notify1 = Arc::new(Parker::new());
        let notify1_clone = notify1.clone();
        spawner.spawn(async move {
            finished_clone.fetch_add(1, Release);
            notify1_clone.notify_one();
        });

        let finished_clone = finished.clone();
        let notify2 = Arc::new(Parker::new());
        let notify2_clone = notify2.clone();
        spawner.spawn(async move {
            finished_clone.fetch_add(1, Release);
            notify2_clone.notify_one();
        });

        // The only thread that drains the run queue is the one blocked here on
        // a future that stays pending until `waiter.wake_one()` below.
        let waiter = Arc::new(Waiter::new());
        let waiter_clone = waiter.clone();
        let spawner_clone = spawner.clone();
        let join = std::thread::spawn(move || {
            spawner_clone.block_on(async move { waiter_clone.wait().await })
        });

        notify1.notified();
        notify2.notified();
        assert_eq!(finished.load(Acquire), 2);

        // These tasks are spawned while the other thread is already parked,
        // so the spawn must notify it.
        let finished_clone = finished.clone();
        let notify1 = Arc::new(Parker::new());
        let notify1_clone = notify1.clone();
        spawner.spawn(async move {
            finished_clone.fetch_add(1, Release);
            notify1_clone.notify_one();
        });

        let finished_clone = finished.clone();
        let notify2 = Arc::new(Parker::new());
        let notify2_clone = notify2.clone();
        spawner.spawn(async move {
            finished_clone.fetch_add(1, Release);
            notify2_clone.notify_one();
        });

        notify1.notified();
        notify2.notified();
        assert_eq!(finished.load(Acquire), 4);

        waiter.wake_one();
        join.join().unwrap();

        // NOTE(review): presumably clears the thread-local worker context left
        // on this thread so later tests start clean — confirm.
        #[cfg(feature = "net")]
        crate::executor::worker::CURRENT_WORKER.with(|ctx| {
            ctx.set(std::ptr::null());
        });
    }

    /// UT test cases for io tasks.
    ///
    /// # Brief
    /// 1. Spawn a TCP server task and block on a TCP client; they exchange
    ///    data three times.
    /// 2. Repeat on a fresh runtime with the roles swapped: spawn the client
    ///    and block on the server.
    #[test]
    #[cfg(feature = "net")]
    fn ut_current_thread_io() {
        use crate::builder::RuntimeBuilder;

        let spawner = RuntimeBuilder::new_current_thread().build().unwrap();
        let (tx, rx) = crate::sync::oneshot::channel();
        let join_handle = spawner.spawn(ylong_tcp_server(tx));

        spawner.block_on(ylong_tcp_client(rx));
        spawner.block_on(join_handle).unwrap();

        let spawner = RuntimeBuilder::new_current_thread().build().unwrap();
        let (tx, rx) = crate::sync::oneshot::channel();
        let join_handle = spawner.spawn(ylong_tcp_client(rx));
        spawner.block_on(ylong_tcp_server(tx));
        spawner.block_on(join_handle).unwrap();
    }
}
506