// Copyright 2019 The Marl Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef marl_scheduler_h
#define marl_scheduler_h

#include "containers.h"
#include "debug.h"
#include "deprecated.h"
#include "export.h"
#include "memory.h"
#include "mutex.h"
#include "sanitizers.h"
#include "task.h"
#include "thread.h"
#include "thread_local.h"

#include <array>
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <functional>
#include <thread>

namespace marl {

class OSFiber;

// Scheduler asynchronously processes Tasks.
// A scheduler can be bound to one or more threads using the bind() method.
// Once bound to a thread, that thread can call marl::schedule() to enqueue
// work tasks to be executed asynchronously.
// Schedulers are initially constructed in single-threaded mode.
// Call setWorkerThreadCount() to spawn dedicated worker threads.
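//
// Example usage (a minimal sketch; defer() comes from marl/defer.h and is
// optional, it simply ensures unbind() runs before the scheduler is
// destructed):
//
//   marl::Scheduler scheduler(marl::Scheduler::Config::allCores());
//   scheduler.bind();
//   defer(scheduler.unbind());
//
//   marl::schedule([] {
//     // This task runs asynchronously on a worker thread.
//   });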
class Scheduler {
  class Worker;

 public:
  using TimePoint = std::chrono::system_clock::time_point;
  using Predicate = std::function<bool()>;
  using ThreadInitializer = std::function<void(int workerId)>;

  // Config holds scheduler configuration settings that can be passed to the
  // Scheduler constructor.
  struct Config {
    static constexpr size_t DefaultFiberStackSize = 1024 * 1024;

    // Per-worker-thread settings.
    struct WorkerThread {
      // Total number of dedicated worker threads to spawn for the scheduler.
      int count = 0;

      // Initializer function to call after thread creation and before any work
      // is run by the thread.
      ThreadInitializer initializer;

      // Thread affinity policy to use for worker threads.
      std::shared_ptr<Thread::Affinity::Policy> affinityPolicy;
    };

    WorkerThread workerThread;

    // Memory allocator to use for the scheduler and internal allocations.
    Allocator* allocator = Allocator::Default;

    // Size of each fiber stack. This may be rounded up to the nearest
    // allocation granularity for the given platform.
    size_t fiberStackSize = DefaultFiberStackSize;

    // allCores() returns a Config with a worker thread for each of the logical
    // cpus available to the process.
    MARL_EXPORT
    static Config allCores();

    // Fluent setters that return this Config so set calls can be chained.
    MARL_NO_EXPORT inline Config& setAllocator(Allocator*);
    MARL_NO_EXPORT inline Config& setFiberStackSize(size_t);
    MARL_NO_EXPORT inline Config& setWorkerThreadCount(int);
    MARL_NO_EXPORT inline Config& setWorkerThreadInitializer(
        const ThreadInitializer&);
    MARL_NO_EXPORT inline Config& setWorkerThreadAffinityPolicy(
        const std::shared_ptr<Thread::Affinity::Policy>&);
  };
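  // Example of chaining the fluent setters above (illustrative; the worker
  // count and stack size are arbitrary values):
  //
  //   marl::Scheduler::Config cfg;
  //   cfg.setWorkerThreadCount(4).setFiberStackSize(512 * 1024);
  //   marl::Scheduler scheduler(cfg);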

  // Constructor.
  MARL_EXPORT
  Scheduler(const Config&);

  // Destructor.
  // Blocks until the scheduler is unbound from all threads before returning.
  MARL_EXPORT
  ~Scheduler();

  // get() returns the scheduler bound to the current thread.
  MARL_EXPORT
  static Scheduler* get();

  // bind() binds this scheduler to the current thread.
  // There must be no existing scheduler bound to the thread prior to calling.
  MARL_EXPORT
  void bind();

  // unbind() unbinds the scheduler currently bound to the current thread.
  // There must be an existing scheduler bound to the thread prior to calling.
  // unbind() flushes any enqueued tasks on the single-threaded worker before
  // returning.
  MARL_EXPORT
  static void unbind();

  // enqueue() queues the task for asynchronous execution.
  MARL_EXPORT
  void enqueue(Task&& task);

  // config() returns the Config that was used to build the scheduler.
  MARL_EXPORT
  const Config& config() const;

  // Fibers expose methods to perform cooperative multitasking and are
  // automatically created by the Scheduler.
  //
  // The currently executing Fiber can be obtained by calling Fiber::current().
  //
  // When execution becomes blocked, wait() can be called to suspend execution
  // of the fiber and start executing other pending work. Once the block has
  // been lifted, notify() can be called to reschedule the Fiber on the same
  // thread that previously executed it.
  class Fiber {
   public:
    // current() returns the currently executing fiber, or nullptr if called
    // without a bound scheduler.
    MARL_EXPORT
    static Fiber* current();

    // wait() suspends execution of this Fiber until the Fiber is woken up with
    // a call to notify() and the predicate pred returns true.
    // If the predicate pred does not return true when notify() is called, then
    // the Fiber is automatically re-suspended, and will need to be woken with
    // another call to notify().
    // While the Fiber is suspended, the scheduler thread may continue executing
    // other tasks.
    // lock must be locked before calling, and is unlocked by wait() just before
    // the Fiber is suspended, and re-locked before the fiber is resumed. lock
    // will be locked before wait() returns.
    // pred will always be called with the lock held.
    // wait() must only be called on the currently executing fiber.
    MARL_EXPORT
    void wait(marl::lock& lock, const Predicate& pred);
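    // Example of a guarded wait (illustrative sketch; the queue and mutex are
    // hypothetical caller-owned state):
    //
    //   marl::mutex mutex;
    //   std::queue<int> queue;  // guarded by mutex
    //
    //   marl::lock lock(mutex);
    //   Fiber::current()->wait(lock, [&] { return !queue.empty(); });
    //   int item = queue.front();  // lock is held again here.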

    // wait() suspends execution of this Fiber until the Fiber is woken up with
    // a call to notify() and the predicate pred returns true, or sometime after
    // the timeout is reached.
    // If the predicate pred does not return true when notify() is called, then
    // the Fiber is automatically re-suspended, and will need to be woken with
    // another call to notify() or will be woken sometime after the timeout is
    // reached.
    // While the Fiber is suspended, the scheduler thread may continue executing
    // other tasks.
    // lock must be locked before calling, and is unlocked by wait() just before
    // the Fiber is suspended, and re-locked before the fiber is resumed. lock
    // will be locked before wait() returns.
    // pred will always be called with the lock held.
    // wait() must only be called on the currently executing fiber.
    template <typename Clock, typename Duration>
    MARL_NO_EXPORT inline bool wait(
        marl::lock& lock,
        const std::chrono::time_point<Clock, Duration>& timeout,
        const Predicate& pred);

    // wait() suspends execution of this Fiber until the Fiber is woken up with
    // a call to notify().
    // While the Fiber is suspended, the scheduler thread may continue executing
    // other tasks.
    // wait() must only be called on the currently executing fiber.
    //
    // Warning: Unlike wait() overloads that take a lock and predicate, this
    // form of wait() offers no safety for notify() signals that occur before
    // the fiber is suspended, when signalling between different threads. In
    // this scenario you may deadlock. For this reason, it is only ever
    // recommended to use this overload if you can guarantee that the calls to
    // wait() and notify() are made by the same thread.
    //
    // Use with extreme caution.
    MARL_NO_EXPORT inline void wait();

    // wait() suspends execution of this Fiber until the Fiber is woken up with
    // a call to notify(), or sometime after the timeout is reached.
    // While the Fiber is suspended, the scheduler thread may continue executing
    // other tasks.
    // wait() must only be called on the currently executing fiber.
    //
    // Warning: Unlike wait() overloads that take a lock and predicate, this
    // form of wait() offers no safety for notify() signals that occur before
    // the fiber is suspended, when signalling between different threads. For
    // this reason, it is only ever recommended to use this overload if you can
    // guarantee that the calls to wait() and notify() are made by the same
    // thread.
    //
    // Use with extreme caution.
    template <typename Clock, typename Duration>
    MARL_NO_EXPORT inline bool wait(
        const std::chrono::time_point<Clock, Duration>& timeout);
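    // Example of a timed, guarded wait using the lock-and-predicate overload
    // above (illustrative sketch; assumes the returned bool indicates whether
    // the predicate was satisfied before the deadline):
    //
    //   auto deadline = std::chrono::system_clock::now() +
    //                   std::chrono::milliseconds(10);
    //   marl::lock lock(mutex);
    //   bool ok = Fiber::current()->wait(lock, deadline,
    //                                    [&] { return !queue.empty(); });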

    // notify() reschedules the suspended Fiber for execution.
    // notify() is usually only called when the predicate for one or more wait()
    // calls will likely return true.
    MARL_EXPORT
    void notify();
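    // Example pairing with the guarded wait() above (illustrative; the
    // producer runs on another thread or fiber):
    //
    //   {
    //     marl::lock lock(mutex);
    //     queue.push(42);
    //   }
    //   fiber->notify();  // Wakes the waiter; its predicate now returns true.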

    // id is the thread-unique identifier of the Fiber.
    uint32_t const id;

   private:
    friend class Allocator;
    friend class Scheduler;

    enum class State {
      // Idle: the Fiber is currently unused, and sits in Worker::idleFibers,
      // ready to be recycled.
      Idle,

      // Yielded: the Fiber is currently blocked on a wait() call with no
      // timeout.
      Yielded,

      // Waiting: the Fiber is currently blocked on a wait() call with a
      // timeout. The fiber is sitting in the Worker::Work::waiting queue.
      Waiting,

      // Queued: the Fiber is currently queued for execution in the
      // Worker::Work::fibers queue.
      Queued,

      // Running: the Fiber is currently executing.
      Running,
    };
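    // Informal sketch of the typical state transitions (derived from the
    // descriptions above):
    //
    //   Idle -> Queued -> Running -> (Yielded | Waiting) -> Queued -> ...
    //
    // A fiber returns to Idle once its work completes and it is recycled.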

    Fiber(Allocator::unique_ptr<OSFiber>&&, uint32_t id);

    // switchTo() switches execution to the given fiber.
    // switchTo() must only be called on the currently executing fiber.
    void switchTo(Fiber*);

    // create() constructs and returns a new fiber with the given identifier,
    // stack size and func that will be executed when switched to.
    static Allocator::unique_ptr<Fiber> create(
        Allocator* allocator,
        uint32_t id,
        size_t stackSize,
        const std::function<void()>& func);

    // createFromCurrentThread() constructs and returns a new fiber with the
    // given identifier for the current thread.
    static Allocator::unique_ptr<Fiber> createFromCurrentThread(
        Allocator* allocator,
        uint32_t id);

    // toString() returns a string representation of the given State.
    // Used for debugging.
    static const char* toString(State state);

    Allocator::unique_ptr<OSFiber> const impl;
    Worker* const worker;
    State state = State::Running;  // Guarded by Worker's work.mutex.
  };

 private:
  Scheduler(const Scheduler&) = delete;
  Scheduler(Scheduler&&) = delete;
  Scheduler& operator=(const Scheduler&) = delete;
  Scheduler& operator=(Scheduler&&) = delete;

  // Maximum number of worker threads.
  static constexpr size_t MaxWorkerThreads = 256;

  // WaitingFibers holds all the fibers waiting on a timeout.
  struct WaitingFibers {
    inline WaitingFibers(Allocator*);

    // operator bool() returns true iff there are any waiting fibers.
    inline operator bool() const;

    // take() returns the next fiber that has exceeded its timeout, or nullptr
    // if there are no fibers that have yet exceeded their timeouts.
    inline Fiber* take(const TimePoint& timeout);

    // next() returns the timepoint of the next fiber to timeout.
    // next() can only be called if operator bool() returns true.
    inline TimePoint next() const;

    // add() adds another fiber and timeout to the list of waiting fibers.
    inline void add(const TimePoint& timeout, Fiber* fiber);

    // erase() removes the fiber from the waiting list.
    inline void erase(Fiber* fiber);

    // contains() returns true if fiber is waiting.
    inline bool contains(Fiber* fiber) const;

   private:
    struct Timeout {
      TimePoint timepoint;
      Fiber* fiber;
      inline bool operator<(const Timeout&) const;
    };
    containers::set<Timeout, std::less<Timeout>> timeouts;
    containers::unordered_map<Fiber*, TimePoint> fibers;
  };
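  // Illustrative sketch of how the two containers above cooperate: the
  // ordered set yields the earliest timeout, while the map allows erase()
  // by fiber. take() could be written like this (not necessarily the exact
  // implementation):
  //
  //   Fiber* WaitingFibers::take(const TimePoint& timeout) {
  //     auto it = timeouts.begin();
  //     if (it == timeouts.end() || timeout < it->timepoint) {
  //       return nullptr;  // Nothing has expired yet.
  //     }
  //     auto fiber = it->fiber;
  //     timeouts.erase(it);
  //     fibers.erase(fiber);
  //     return fiber;
  //   }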

  // TODO: Implement a queue that recycles elements to reduce number of
  // heap allocations.
  using TaskQueue = containers::deque<Task>;
  using FiberQueue = containers::deque<Fiber*>;
  using FiberSet = containers::unordered_set<Fiber*>;

  // Workers execute Tasks on a single thread.
  // Once a task is started, it may yield to other tasks on the same Worker.
  // Tasks are always resumed by the same Worker.
  class Worker {
   public:
    enum class Mode {
      // Worker will spawn a background thread to process tasks.
      MultiThreaded,

      // Worker will execute tasks whenever it yields.
      SingleThreaded,
    };

    Worker(Scheduler* scheduler, Mode mode, uint32_t id);

    // start() begins execution of the worker.
    void start() EXCLUDES(work.mutex);

    // stop() ceases execution of the worker, blocking until all pending
    // tasks have fully finished.
    void stop() EXCLUDES(work.mutex);

    // wait() suspends execution of the current task until the predicate pred
    // returns true or the optional timeout is reached.
    // See Fiber::wait() for more information.
    MARL_EXPORT
    bool wait(marl::lock& lock, const TimePoint* timeout, const Predicate& pred)
        EXCLUDES(work.mutex);

    // wait() suspends execution of the current task until the fiber is
    // notified, or the optional timeout is reached.
    // See Fiber::wait() for more information.
    MARL_EXPORT
    bool wait(const TimePoint* timeout) EXCLUDES(work.mutex);

    // suspend() suspends the currently executing Fiber until the fiber is
    // woken with a call to enqueue(Fiber*), or automatically sometime after the
    // optional timeout.
    void suspend(const TimePoint* timeout) REQUIRES(work.mutex);

    // enqueue(Fiber*) enqueues resuming of a suspended fiber.
    void enqueue(Fiber* fiber) EXCLUDES(work.mutex);

    // enqueue(Task&&) enqueues a new, unstarted task.
    void enqueue(Task&& task) EXCLUDES(work.mutex);

    // tryLock() attempts to lock the worker for task enqueuing.
    // If the lock was successful then true is returned, and the caller must
    // call enqueueAndUnlock().
    bool tryLock() EXCLUDES(work.mutex) TRY_ACQUIRE(true, work.mutex);

    // enqueueAndUnlock() enqueues the task and unlocks the worker.
    // Must only be called after a call to tryLock() which returned true.
    // _Releases_lock_(work.mutex)
    void enqueueAndUnlock(Task&& task) REQUIRES(work.mutex) RELEASE(work.mutex);
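    // Illustrative sketch of the tryLock() / enqueueAndUnlock() contract
    // (hypothetical caller code, not the scheduler's actual enqueue path):
    //
    //   if (worker->tryLock()) {
    //     worker->enqueueAndUnlock(std::move(task));  // releases work.mutex
    //   } else {
    //     // Worker is busy; the caller may try another worker or fall back
    //     // to the blocking enqueue(Task&&).
    //   }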

    // runUntilShutdown() processes all tasks and fibers until there are no
    // more and shutdown is true, at which point runUntilShutdown() returns.
    void runUntilShutdown() REQUIRES(work.mutex);

    // steal() attempts to steal a Task from the worker for another worker.
    // Returns true if a task was taken and assigned to out, otherwise false.
    bool steal(Task& out) EXCLUDES(work.mutex);

    // getCurrent() returns the Worker currently bound to the current
    // thread.
    static inline Worker* getCurrent();

    // getCurrentFiber() returns the Fiber currently being executed.
    inline Fiber* getCurrentFiber() const;

    // Unique identifier of the Worker.
    const uint32_t id;

   private:
    // run() is the task processing function for the worker.
    // run() processes tasks until stop() is called.
    void run() REQUIRES(work.mutex);

    // createWorkerFiber() creates a new fiber that when executed calls
    // run().
    Fiber* createWorkerFiber() REQUIRES(work.mutex);

    // switchToFiber() switches execution to the given fiber. The fiber
    // must belong to this worker.
    void switchToFiber(Fiber*) REQUIRES(work.mutex);

    // runUntilIdle() executes all pending tasks and then returns.
    void runUntilIdle() REQUIRES(work.mutex);

    // waitForWork() blocks until new work is available, potentially calling
    // spinForWork().
    void waitForWork() REQUIRES(work.mutex);

    // spinForWork() attempts to steal work from another Worker, and keeps
    // the thread awake for a short duration. This reduces the overhead of
    // frequently putting the thread to sleep and re-waking.
    void spinForWork();

    // enqueueFiberTimeouts() enqueues all the fibers that have finished
    // waiting.
    void enqueueFiberTimeouts() REQUIRES(work.mutex);

    inline void changeFiberState(Fiber* fiber,
                                 Fiber::State from,
                                 Fiber::State to) const REQUIRES(work.mutex);

    inline void setFiberState(Fiber* fiber, Fiber::State to) const
        REQUIRES(work.mutex);

    // Work holds tasks and fibers that are enqueued on the Worker.
    struct Work {
      inline Work(Allocator*);

      std::atomic<uint64_t> num = {0};  // tasks.size() + fibers.size()
      GUARDED_BY(mutex) uint64_t numBlockedFibers = 0;
      GUARDED_BY(mutex) TaskQueue tasks;
      GUARDED_BY(mutex) FiberQueue fibers;
      GUARDED_BY(mutex) WaitingFibers waiting;
      GUARDED_BY(mutex) bool notifyAdded = true;
      std::condition_variable added;
      marl::mutex mutex;

      template <typename F>
      inline void wait(F&&) REQUIRES(mutex);
    };

    // https://en.wikipedia.org/wiki/Xorshift
    class FastRnd {
     public:
      inline uint64_t operator()() {
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        return x;
      }

     private:
      uint64_t x = std::chrono::system_clock::now().time_since_epoch().count();
    };

    // The current worker bound to the current thread.
    MARL_DECLARE_THREAD_LOCAL(Worker*, current);

    Mode const mode;
    Scheduler* const scheduler;
    Allocator::unique_ptr<Fiber> mainFiber;
    Fiber* currentFiber = nullptr;
    Thread thread;
    Work work;
    FiberSet idleFibers;  // Fibers that have completed and can be reused.
    containers::vector<Allocator::unique_ptr<Fiber>, 16>
        workerFibers;  // All fibers created by this worker.
    FastRnd rng;
    bool shutdown = false;
  };

  // stealWork() attempts to steal a task from the worker with the given id.
  // Returns true if a task was stolen and assigned to out, otherwise false.
  bool stealWork(Worker* thief, uint64_t from, Task& out);

  // onBeginSpinning() is called when a Worker calls spinForWork().
  // The scheduler will prioritize this worker for new tasks to try to prevent
  // it from going to sleep.
  void onBeginSpinning(int workerId);

  // setBound() sets the scheduler bound to the current thread.
  static void setBound(Scheduler* scheduler);

  // The scheduler currently bound to the current thread.
  MARL_DECLARE_THREAD_LOCAL(Scheduler*, bound);

  // The immutable configuration used to build the scheduler.
  const Config cfg;

  std::array<std::atomic<int>, 8> spinningWorkers;
  std::atomic<unsigned int> nextSpinningWorkerIdx = {0x8000000};

  std::atomic<unsigned int> nextEnqueueIndex = {0};
  std::array<Worker*, MaxWorkerThreads> workerThreads;

  struct SingleThreadedWorkers {
    inline SingleThreadedWorkers(Allocator*);

    using WorkerByTid =
        containers::unordered_map<std::thread::id,
                                  Allocator::unique_ptr<Worker>>;
    marl::mutex mutex;
    GUARDED_BY(mutex) std::condition_variable unbind;
    GUARDED_BY(mutex) WorkerByTid byTid;
  };
  SingleThreadedWorkers singleThreadedWorkers;
};

////////////////////////////////////////////////////////////////////////////////
// Scheduler::Config
////////////////////////////////////////////////////////////////////////////////
Scheduler::Config& Scheduler::Config::setAllocator(Allocator* alloc) {
  allocator = alloc;
  return *this;
}

Scheduler::Config& Scheduler::Config::setFiberStackSize(size_t size) {
  fiberStackSize = size;
  return *this;
}

Scheduler::Config& Scheduler::Config::setWorkerThreadCount(int count) {
  workerThread.count = count;
  return *this;
}

Scheduler::Config& Scheduler::Config::setWorkerThreadInitializer(
    const ThreadInitializer& initializer) {
  workerThread.initializer = initializer;
  return *this;
}

Scheduler::Config& Scheduler::Config::setWorkerThreadAffinityPolicy(
    const std::shared_ptr<Thread::Affinity::Policy>& policy) {
  workerThread.affinityPolicy = policy;
  return *this;
}

////////////////////////////////////////////////////////////////////////////////
// Scheduler::Fiber
////////////////////////////////////////////////////////////////////////////////
template <typename Clock, typename Duration>
bool Scheduler::Fiber::wait(
    marl::lock& lock,
    const std::chrono::time_point<Clock, Duration>& timeout,
    const Predicate& pred) {
  using ToDuration = typename TimePoint::duration;
  using ToClock = typename TimePoint::clock;
  auto tp = std::chrono::time_point_cast<ToDuration, ToClock>(timeout);
  return worker->wait(lock, &tp, pred);
}

void Scheduler::Fiber::wait() {
  worker->wait(nullptr);
}

template <typename Clock, typename Duration>
bool Scheduler::Fiber::wait(
    const std::chrono::time_point<Clock, Duration>& timeout) {
  using ToDuration = typename TimePoint::duration;
  using ToClock = typename TimePoint::clock;
  auto tp = std::chrono::time_point_cast<ToDuration, ToClock>(timeout);
  return worker->wait(&tp);
}

Scheduler::Worker* Scheduler::Worker::getCurrent() {
  return Worker::current;
}

Scheduler::Fiber* Scheduler::Worker::getCurrentFiber() const {
  return currentFiber;
}

// schedule() schedules the task t to be asynchronously called using the
// currently bound scheduler.
inline void schedule(Task&& t) {
  MARL_ASSERT_HAS_BOUND_SCHEDULER("marl::schedule");
  auto scheduler = Scheduler::get();
  scheduler->enqueue(std::move(t));
}

// schedule() schedules the function f to be asynchronously called with the
// given arguments using the currently bound scheduler.
template <typename Function, typename... Args>
inline void schedule(Function&& f, Args&&... args) {
  MARL_ASSERT_HAS_BOUND_SCHEDULER("marl::schedule");
  auto scheduler = Scheduler::get();
  scheduler->enqueue(
      Task(std::bind(std::forward<Function>(f), std::forward<Args>(args)...)));
}

// schedule() schedules the function f to be asynchronously called using the
// currently bound scheduler.
template <typename Function>
inline void schedule(Function&& f) {
  MARL_ASSERT_HAS_BOUND_SCHEDULER("marl::schedule");
  auto scheduler = Scheduler::get();
  scheduler->enqueue(Task(std::forward<Function>(f)));
}
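// Example usage (illustrative; requires a scheduler bound to the calling
// thread, as asserted above; doWork() and doWorkWith() are hypothetical):
//
//   marl::schedule([] { doWork(); });   // callable with no arguments
//   marl::schedule(&doWorkWith, 1, 2);  // function with bound arguments
//
// Note that the argument-binding overload uses std::bind, so the callable
// and its arguments must be copyable.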

}  // namespace marl

#endif  // marl_scheduler_h