1 // Copyright 2019 The Marl Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #ifndef marl_scheduler_h
16 #define marl_scheduler_h
17
18 #include "containers.h"
19 #include "debug.h"
20 #include "deprecated.h"
21 #include "export.h"
22 #include "memory.h"
23 #include "mutex.h"
24 #include "task.h"
25 #include "thread.h"
26
27 #include <array>
28 #include <atomic>
29 #include <chrono>
30 #include <condition_variable>
31 #include <functional>
32 #include <thread>
33
34 namespace marl {
35
36 class OSFiber;
37
38 // Scheduler asynchronously processes Tasks.
39 // A scheduler can be bound to one or more threads using the bind() method.
40 // Once bound to a thread, that thread can call marl::schedule() to enqueue
41 // work tasks to be executed asynchronously.
42 // Scheduler are initially constructed in single-threaded mode.
43 // Call setWorkerThreadCount() to spawn dedicated worker threads.
44 class Scheduler {
45 class Worker;
46
47 public:
48 using TimePoint = std::chrono::system_clock::time_point;
49 using Predicate = std::function<bool()>;
50 using ThreadInitializer = std::function<void(int workerId)>;
51
52 // Config holds scheduler configuration settings that can be passed to the
53 // Scheduler constructor.
54 struct Config {
55 static constexpr size_t DefaultFiberStackSize = 1024 * 1024;
56
57 // Per-worker-thread settings.
58 struct WorkerThread {
59 // Total number of dedicated worker threads to spawn for the scheduler.
60 int count = 0;
61
62 // Initializer function to call after thread creation and before any work
63 // is run by the thread.
64 ThreadInitializer initializer;
65
66 // Thread affinity policy to use for worker threads.
67 std::shared_ptr<Thread::Affinity::Policy> affinityPolicy;
68 };
69
70 WorkerThread workerThread;
71
72 // Memory allocator to use for the scheduler and internal allocations.
73 Allocator* allocator = Allocator::Default;
74
75 // Size of each fiber stack. This may be rounded up to the nearest
76 // allocation granularity for the given platform.
77 size_t fiberStackSize = DefaultFiberStackSize;
78
79 // allCores() returns a Config with a worker thread for each of the logical
80 // cpus available to the process.
81 MARL_EXPORT
82 static Config allCores();
83
84 // Fluent setters that return this Config so set calls can be chained.
85 MARL_NO_EXPORT inline Config& setAllocator(Allocator*);
86 MARL_NO_EXPORT inline Config& setFiberStackSize(size_t);
87 MARL_NO_EXPORT inline Config& setWorkerThreadCount(int);
88 MARL_NO_EXPORT inline Config& setWorkerThreadInitializer(
89 const ThreadInitializer&);
90 MARL_NO_EXPORT inline Config& setWorkerThreadAffinityPolicy(
91 const std::shared_ptr<Thread::Affinity::Policy>&);
92 };
93
94 // Constructor.
95 MARL_EXPORT
96 Scheduler(const Config&);
97
98 // Destructor.
99 // Blocks until the scheduler is unbound from all threads before returning.
100 MARL_EXPORT
101 ~Scheduler();
102
103 // get() returns the scheduler bound to the current thread.
104 MARL_EXPORT
105 static Scheduler* get();
106
107 // bind() binds this scheduler to the current thread.
108 // There must be no existing scheduler bound to the thread prior to calling.
109 MARL_EXPORT
110 void bind();
111
112 // unbind() unbinds the scheduler currently bound to the current thread.
113 // There must be a existing scheduler bound to the thread prior to calling.
114 // unbind() flushes any enqueued tasks on the single-threaded worker before
115 // returning.
116 MARL_EXPORT
117 static void unbind();
118
119 // enqueue() queues the task for asynchronous execution.
120 MARL_EXPORT
121 void enqueue(Task&& task);
122
123 // config() returns the Config that was used to build the schededuler.
124 MARL_EXPORT
125 const Config& config() const;
126
127 // Fibers expose methods to perform cooperative multitasking and are
128 // automatically created by the Scheduler.
129 //
130 // The currently executing Fiber can be obtained by calling Fiber::current().
131 //
132 // When execution becomes blocked, yield() can be called to suspend execution
133 // of the fiber and start executing other pending work. Once the block has
134 // been lifted, schedule() can be called to reschedule the Fiber on the same
135 // thread that previously executed it.
136 class Fiber {
137 public:
138 // current() returns the currently executing fiber, or nullptr if called
139 // without a bound scheduler.
140 MARL_EXPORT
141 static Fiber* current();
142
143 // wait() suspends execution of this Fiber until the Fiber is woken up with
144 // a call to notify() and the predicate pred returns true.
145 // If the predicate pred does not return true when notify() is called, then
146 // the Fiber is automatically re-suspended, and will need to be woken with
147 // another call to notify().
148 // While the Fiber is suspended, the scheduler thread may continue executing
149 // other tasks.
150 // lock must be locked before calling, and is unlocked by wait() just before
151 // the Fiber is suspended, and re-locked before the fiber is resumed. lock
152 // will be locked before wait() returns.
153 // pred will be always be called with the lock held.
154 // wait() must only be called on the currently executing fiber.
155 MARL_EXPORT
156 void wait(marl::lock& lock, const Predicate& pred);
157
158 // wait() suspends execution of this Fiber until the Fiber is woken up with
159 // a call to notify() and the predicate pred returns true, or sometime after
160 // the timeout is reached.
161 // If the predicate pred does not return true when notify() is called, then
162 // the Fiber is automatically re-suspended, and will need to be woken with
163 // another call to notify() or will be woken sometime after the timeout is
164 // reached.
165 // While the Fiber is suspended, the scheduler thread may continue executing
166 // other tasks.
167 // lock must be locked before calling, and is unlocked by wait() just before
168 // the Fiber is suspended, and re-locked before the fiber is resumed. lock
169 // will be locked before wait() returns.
170 // pred will be always be called with the lock held.
171 // wait() must only be called on the currently executing fiber.
172 template <typename Clock, typename Duration>
173 MARL_NO_EXPORT inline bool wait(
174 marl::lock& lock,
175 const std::chrono::time_point<Clock, Duration>& timeout,
176 const Predicate& pred);
177
178 // wait() suspends execution of this Fiber until the Fiber is woken up with
179 // a call to notify().
180 // While the Fiber is suspended, the scheduler thread may continue executing
181 // other tasks.
182 // wait() must only be called on the currently executing fiber.
183 //
184 // Warning: Unlike wait() overloads that take a lock and predicate, this
185 // form of wait() offers no safety for notify() signals that occur before
186 // the fiber is suspended, when signalling between different threads. In
187 // this scenario you may deadlock. For this reason, it is only ever
188 // recommended to use this overload if you can guarantee that the calls to
189 // wait() and notify() are made by the same thread.
190 //
191 // Use with extreme caution.
192 MARL_NO_EXPORT inline void wait();
193
194 // wait() suspends execution of this Fiber until the Fiber is woken up with
195 // a call to notify(), or sometime after the timeout is reached.
196 // While the Fiber is suspended, the scheduler thread may continue executing
197 // other tasks.
198 // wait() must only be called on the currently executing fiber.
199 //
200 // Warning: Unlike wait() overloads that take a lock and predicate, this
201 // form of wait() offers no safety for notify() signals that occur before
202 // the fiber is suspended, when signalling between different threads. For
203 // this reason, it is only ever recommended to use this overload if you can
204 // guarantee that the calls to wait() and notify() are made by the same
205 // thread.
206 //
207 // Use with extreme caution.
208 template <typename Clock, typename Duration>
209 MARL_NO_EXPORT inline bool wait(
210 const std::chrono::time_point<Clock, Duration>& timeout);
211
212 // notify() reschedules the suspended Fiber for execution.
213 // notify() is usually only called when the predicate for one or more wait()
214 // calls will likely return true.
215 MARL_EXPORT
216 void notify();
217
218 // id is the thread-unique identifier of the Fiber.
219 uint32_t const id;
220
221 private:
222 friend class Allocator;
223 friend class Scheduler;
224
225 enum class State {
226 // Idle: the Fiber is currently unused, and sits in Worker::idleFibers,
227 // ready to be recycled.
228 Idle,
229
230 // Yielded: the Fiber is currently blocked on a wait() call with no
231 // timeout.
232 Yielded,
233
234 // Waiting: the Fiber is currently blocked on a wait() call with a
235 // timeout. The fiber is stilling in the Worker::Work::waiting queue.
236 Waiting,
237
238 // Queued: the Fiber is currently queued for execution in the
239 // Worker::Work::fibers queue.
240 Queued,
241
242 // Running: the Fiber is currently executing.
243 Running,
244 };
245
246 Fiber(Allocator::unique_ptr<OSFiber>&&, uint32_t id);
247
248 // switchTo() switches execution to the given fiber.
249 // switchTo() must only be called on the currently executing fiber.
250 void switchTo(Fiber*);
251
252 // create() constructs and returns a new fiber with the given identifier,
253 // stack size that will executed func when switched to.
254 static Allocator::unique_ptr<Fiber> create(
255 Allocator* allocator,
256 uint32_t id,
257 size_t stackSize,
258 const std::function<void()>& func);
259
260 // createFromCurrentThread() constructs and returns a new fiber with the
261 // given identifier for the current thread.
262 static Allocator::unique_ptr<Fiber> createFromCurrentThread(
263 Allocator* allocator,
264 uint32_t id);
265
266 // toString() returns a string representation of the given State.
267 // Used for debugging.
268 static const char* toString(State state);
269
270 Allocator::unique_ptr<OSFiber> const impl;
271 Worker* const worker;
272 State state = State::Running; // Guarded by Worker's work.mutex.
273 };
274
275 private:
276 Scheduler(const Scheduler&) = delete;
277 Scheduler(Scheduler&&) = delete;
278 Scheduler& operator=(const Scheduler&) = delete;
279 Scheduler& operator=(Scheduler&&) = delete;
280
281 // Maximum number of worker threads.
282 static constexpr size_t MaxWorkerThreads = 256;
283
284 // WaitingFibers holds all the fibers waiting on a timeout.
285 struct WaitingFibers {
286 inline WaitingFibers(Allocator*);
287
288 // operator bool() returns true iff there are any wait fibers.
289 inline operator bool() const;
290
291 // take() returns the next fiber that has exceeded its timeout, or nullptr
292 // if there are no fibers that have yet exceeded their timeouts.
293 inline Fiber* take(const TimePoint& timeout);
294
295 // next() returns the timepoint of the next fiber to timeout.
296 // next() can only be called if operator bool() returns true.
297 inline TimePoint next() const;
298
299 // add() adds another fiber and timeout to the list of waiting fibers.
300 inline void add(const TimePoint& timeout, Fiber* fiber);
301
302 // erase() removes the fiber from the waiting list.
303 inline void erase(Fiber* fiber);
304
305 // contains() returns true if fiber is waiting.
306 inline bool contains(Fiber* fiber) const;
307
308 private:
309 struct Timeout {
310 TimePoint timepoint;
311 Fiber* fiber;
312 inline bool operator<(const Timeout&) const;
313 };
314 containers::set<Timeout, std::less<Timeout>> timeouts;
315 containers::unordered_map<Fiber*, TimePoint> fibers;
316 };
317
318 // TODO: Implement a queue that recycles elements to reduce number of
319 // heap allocations.
320 using TaskQueue = containers::deque<Task>;
321 using FiberQueue = containers::deque<Fiber*>;
322 using FiberSet = containers::unordered_set<Fiber*>;
323
324 // Workers executes Tasks on a single thread.
325 // Once a task is started, it may yield to other tasks on the same Worker.
326 // Tasks are always resumed by the same Worker.
327 class Worker {
328 public:
329 enum class Mode {
330 // Worker will spawn a background thread to process tasks.
331 MultiThreaded,
332
333 // Worker will execute tasks whenever it yields.
334 SingleThreaded,
335 };
336
337 Worker(Scheduler* scheduler, Mode mode, uint32_t id);
338
339 // start() begins execution of the worker.
340 void start() EXCLUDES(work.mutex);
341
342 // stop() ceases execution of the worker, blocking until all pending
343 // tasks have fully finished.
344 void stop() EXCLUDES(work.mutex);
345
346 // wait() suspends execution of the current task until the predicate pred
347 // returns true or the optional timeout is reached.
348 // See Fiber::wait() for more information.
349 MARL_EXPORT
350 bool wait(marl::lock& lock, const TimePoint* timeout, const Predicate& pred)
351 EXCLUDES(work.mutex);
352
353 // wait() suspends execution of the current task until the fiber is
354 // notified, or the optional timeout is reached.
355 // See Fiber::wait() for more information.
356 MARL_EXPORT
357 bool wait(const TimePoint* timeout) EXCLUDES(work.mutex);
358
359 // suspend() suspends the currenetly executing Fiber until the fiber is
360 // woken with a call to enqueue(Fiber*), or automatically sometime after the
361 // optional timeout.
362 void suspend(const TimePoint* timeout) REQUIRES(work.mutex);
363
364 // enqueue(Fiber*) enqueues resuming of a suspended fiber.
365 void enqueue(Fiber* fiber) EXCLUDES(work.mutex);
366
367 // enqueue(Task&&) enqueues a new, unstarted task.
368 void enqueue(Task&& task) EXCLUDES(work.mutex);
369
370 // tryLock() attempts to lock the worker for task enqueing.
371 // If the lock was successful then true is returned, and the caller must
372 // call enqueueAndUnlock().
373 bool tryLock() EXCLUDES(work.mutex) TRY_ACQUIRE(true, work.mutex);
374
375 // enqueueAndUnlock() enqueues the task and unlocks the worker.
376 // Must only be called after a call to tryLock() which returned true.
377 // _Releases_lock_(work.mutex)
378 void enqueueAndUnlock(Task&& task) REQUIRES(work.mutex) RELEASE(work.mutex);
379
380 // runUntilShutdown() processes all tasks and fibers until there are no more
381 // and shutdown is true, upon runUntilShutdown() returns.
382 void runUntilShutdown() REQUIRES(work.mutex);
383
384 // steal() attempts to steal a Task from the worker for another worker.
385 // Returns true if a task was taken and assigned to out, otherwise false.
386 bool steal(Task& out) EXCLUDES(work.mutex);
387
388 // getCurrent() returns the Worker currently bound to the current
389 // thread.
390 static inline Worker* getCurrent();
391
392 // getCurrentFiber() returns the Fiber currently being executed.
393 inline Fiber* getCurrentFiber() const;
394
395 // Unique identifier of the Worker.
396 const uint32_t id;
397
398 private:
399 // run() is the task processing function for the worker.
400 // run() processes tasks until stop() is called.
401 void run() REQUIRES(work.mutex);
402
403 // createWorkerFiber() creates a new fiber that when executed calls
404 // run().
405 Fiber* createWorkerFiber() REQUIRES(work.mutex);
406
407 // switchToFiber() switches execution to the given fiber. The fiber
408 // must belong to this worker.
409 void switchToFiber(Fiber*) REQUIRES(work.mutex);
410
411 // runUntilIdle() executes all pending tasks and then returns.
412 void runUntilIdle() REQUIRES(work.mutex);
413
414 // waitForWork() blocks until new work is available, potentially calling
415 // spinForWork().
416 void waitForWork() REQUIRES(work.mutex);
417
418 // spinForWork() attempts to steal work from another Worker, and keeps
419 // the thread awake for a short duration. This reduces overheads of
420 // frequently putting the thread to sleep and re-waking.
421 void spinForWork();
422
423 // enqueueFiberTimeouts() enqueues all the fibers that have finished
424 // waiting.
425 void enqueueFiberTimeouts() REQUIRES(work.mutex);
426
427 inline void changeFiberState(Fiber* fiber,
428 Fiber::State from,
429 Fiber::State to) const REQUIRES(work.mutex);
430
431 inline void setFiberState(Fiber* fiber, Fiber::State to) const
432 REQUIRES(work.mutex);
433
434 // Work holds tasks and fibers that are enqueued on the Worker.
435 struct Work {
436 inline Work(Allocator*);
437
438 std::atomic<uint64_t> num = {0}; // tasks.size() + fibers.size()
439 GUARDED_BY(mutex) uint64_t numBlockedFibers = 0;
440 GUARDED_BY(mutex) TaskQueue tasks;
441 GUARDED_BY(mutex) FiberQueue fibers;
442 GUARDED_BY(mutex) WaitingFibers waiting;
443 GUARDED_BY(mutex) bool notifyAdded = true;
444 std::condition_variable added;
445 marl::mutex mutex;
446
447 template <typename F>
448 inline void wait(F&&) REQUIRES(mutex);
449 };
450
451 // https://en.wikipedia.org/wiki/Xorshift
452 class FastRnd {
453 public:
operator()454 inline uint64_t operator()() {
455 x ^= x << 13;
456 x ^= x >> 7;
457 x ^= x << 17;
458 return x;
459 }
460
461 private:
462 uint64_t x = std::chrono::system_clock::now().time_since_epoch().count();
463 };
464
465 // The current worker bound to the current thread.
466 static thread_local Worker* current;
467
468 Mode const mode;
469 Scheduler* const scheduler;
470 Allocator::unique_ptr<Fiber> mainFiber;
471 Fiber* currentFiber = nullptr;
472 Thread thread;
473 Work work;
474 FiberSet idleFibers; // Fibers that have completed which can be reused.
475 containers::vector<Allocator::unique_ptr<Fiber>, 16>
476 workerFibers; // All fibers created by this worker.
477 FastRnd rng;
478 bool shutdown = false;
479 };
480
481 // stealWork() attempts to steal a task from the worker with the given id.
482 // Returns true if a task was stolen and assigned to out, otherwise false.
483 bool stealWork(Worker* thief, uint64_t from, Task& out);
484
485 // onBeginSpinning() is called when a Worker calls spinForWork().
486 // The scheduler will prioritize this worker for new tasks to try to prevent
487 // it going to sleep.
488 void onBeginSpinning(int workerId);
489
490 // The scheduler currently bound to the current thread.
491 static thread_local Scheduler* bound;
492
493 // The immutable configuration used to build the scheduler.
494 const Config cfg;
495
496 std::array<std::atomic<int>, 8> spinningWorkers;
497 std::atomic<unsigned int> nextSpinningWorkerIdx = {0x8000000};
498
499 std::atomic<unsigned int> nextEnqueueIndex = {0};
500 std::array<Worker*, MaxWorkerThreads> workerThreads;
501
502 struct SingleThreadedWorkers {
503 inline SingleThreadedWorkers(Allocator*);
504
505 using WorkerByTid =
506 containers::unordered_map<std::thread::id,
507 Allocator::unique_ptr<Worker>>;
508 marl::mutex mutex;
509 GUARDED_BY(mutex) std::condition_variable unbind;
510 GUARDED_BY(mutex) WorkerByTid byTid;
511 };
512 SingleThreadedWorkers singleThreadedWorkers;
513 };
514
515 ////////////////////////////////////////////////////////////////////////////////
516 // Scheduler::Config
517 ////////////////////////////////////////////////////////////////////////////////
setAllocator(Allocator * alloc)518 Scheduler::Config& Scheduler::Config::setAllocator(Allocator* alloc) {
519 allocator = alloc;
520 return *this;
521 }
522
setFiberStackSize(size_t size)523 Scheduler::Config& Scheduler::Config::setFiberStackSize(size_t size) {
524 fiberStackSize = size;
525 return *this;
526 }
527
setWorkerThreadCount(int count)528 Scheduler::Config& Scheduler::Config::setWorkerThreadCount(int count) {
529 workerThread.count = count;
530 return *this;
531 }
532
setWorkerThreadInitializer(const ThreadInitializer & initializer)533 Scheduler::Config& Scheduler::Config::setWorkerThreadInitializer(
534 const ThreadInitializer& initializer) {
535 workerThread.initializer = initializer;
536 return *this;
537 }
538
setWorkerThreadAffinityPolicy(const std::shared_ptr<Thread::Affinity::Policy> & policy)539 Scheduler::Config& Scheduler::Config::setWorkerThreadAffinityPolicy(
540 const std::shared_ptr<Thread::Affinity::Policy>& policy) {
541 workerThread.affinityPolicy = policy;
542 return *this;
543 }
544
545 ////////////////////////////////////////////////////////////////////////////////
546 // Scheduler::Fiber
547 ////////////////////////////////////////////////////////////////////////////////
548 template <typename Clock, typename Duration>
wait(marl::lock & lock,const std::chrono::time_point<Clock,Duration> & timeout,const Predicate & pred)549 bool Scheduler::Fiber::wait(
550 marl::lock& lock,
551 const std::chrono::time_point<Clock, Duration>& timeout,
552 const Predicate& pred) {
553 using ToDuration = typename TimePoint::duration;
554 using ToClock = typename TimePoint::clock;
555 auto tp = std::chrono::time_point_cast<ToDuration, ToClock>(timeout);
556 return worker->wait(lock, &tp, pred);
557 }
558
wait()559 void Scheduler::Fiber::wait() {
560 worker->wait(nullptr);
561 }
562
563 template <typename Clock, typename Duration>
wait(const std::chrono::time_point<Clock,Duration> & timeout)564 bool Scheduler::Fiber::wait(
565 const std::chrono::time_point<Clock, Duration>& timeout) {
566 using ToDuration = typename TimePoint::duration;
567 using ToClock = typename TimePoint::clock;
568 auto tp = std::chrono::time_point_cast<ToDuration, ToClock>(timeout);
569 return worker->wait(&tp);
570 }
571
getCurrent()572 Scheduler::Worker* Scheduler::Worker::getCurrent() {
573 return Worker::current;
574 }
575
getCurrentFiber()576 Scheduler::Fiber* Scheduler::Worker::getCurrentFiber() const {
577 return currentFiber;
578 }
579
580 // schedule() schedules the task T to be asynchronously called using the
581 // currently bound scheduler.
schedule(Task && t)582 inline void schedule(Task&& t) {
583 MARL_ASSERT_HAS_BOUND_SCHEDULER("marl::schedule");
584 auto scheduler = Scheduler::get();
585 scheduler->enqueue(std::move(t));
586 }
587
588 // schedule() schedules the function f to be asynchronously called with the
589 // given arguments using the currently bound scheduler.
590 template <typename Function, typename... Args>
schedule(Function && f,Args &&...args)591 inline void schedule(Function&& f, Args&&... args) {
592 MARL_ASSERT_HAS_BOUND_SCHEDULER("marl::schedule");
593 auto scheduler = Scheduler::get();
594 scheduler->enqueue(
595 Task(std::bind(std::forward<Function>(f), std::forward<Args>(args)...)));
596 }
597
598 // schedule() schedules the function f to be asynchronously called using the
599 // currently bound scheduler.
600 template <typename Function>
schedule(Function && f)601 inline void schedule(Function&& f) {
602 MARL_ASSERT_HAS_BOUND_SCHEDULER("marl::schedule");
603 auto scheduler = Scheduler::get();
604 scheduler->enqueue(Task(std::forward<Function>(f)));
605 }
606
607 } // namespace marl
608
609 #endif // marl_scheduler_h
610