1 // Copyright 2019 The Marl Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef marl_scheduler_h
16 #define marl_scheduler_h
17 
#include "containers.h"
#include "debug.h"
#include "deprecated.h"
#include "export.h"
#include "memory.h"
#include "mutex.h"
#include "sanitizers.h"
#include "task.h"
#include "thread.h"
#include "thread_local.h"

#include <array>
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <functional>
#include <thread>
#include <type_traits>
35 
36 namespace marl {
37 
38 class OSFiber;
39 
40 // Scheduler asynchronously processes Tasks.
41 // A scheduler can be bound to one or more threads using the bind() method.
42 // Once bound to a thread, that thread can call marl::schedule() to enqueue
43 // work tasks to be executed asynchronously.
44 // Scheduler are initially constructed in single-threaded mode.
45 // Call setWorkerThreadCount() to spawn dedicated worker threads.
46 class Scheduler {
47   class Worker;
48 
49  public:
50   using TimePoint = std::chrono::system_clock::time_point;
51   using Predicate = std::function<bool()>;
52   using ThreadInitializer = std::function<void(int workerId)>;
53 
54   // Config holds scheduler configuration settings that can be passed to the
55   // Scheduler constructor.
56   struct Config {
57     static constexpr size_t DefaultFiberStackSize = 1024 * 1024;
58 
59     // Per-worker-thread settings.
60     struct WorkerThread {
61       // Total number of dedicated worker threads to spawn for the scheduler.
62       int count = 0;
63 
64       // Initializer function to call after thread creation and before any work
65       // is run by the thread.
66       ThreadInitializer initializer;
67 
68       // Thread affinity policy to use for worker threads.
69       std::shared_ptr<Thread::Affinity::Policy> affinityPolicy;
70     };
71 
72     WorkerThread workerThread;
73 
74     // Memory allocator to use for the scheduler and internal allocations.
75     Allocator* allocator = Allocator::Default;
76 
77     // Size of each fiber stack. This may be rounded up to the nearest
78     // allocation granularity for the given platform.
79     size_t fiberStackSize = DefaultFiberStackSize;
80 
81     // allCores() returns a Config with a worker thread for each of the logical
82     // cpus available to the process.
83     MARL_EXPORT
84     static Config allCores();
85 
86     // Fluent setters that return this Config so set calls can be chained.
87     MARL_NO_EXPORT inline Config& setAllocator(Allocator*);
88     MARL_NO_EXPORT inline Config& setFiberStackSize(size_t);
89     MARL_NO_EXPORT inline Config& setWorkerThreadCount(int);
90     MARL_NO_EXPORT inline Config& setWorkerThreadInitializer(
91         const ThreadInitializer&);
92     MARL_NO_EXPORT inline Config& setWorkerThreadAffinityPolicy(
93         const std::shared_ptr<Thread::Affinity::Policy>&);
94   };
95 
96   // Constructor.
97   MARL_EXPORT
98   Scheduler(const Config&);
99 
100   // Destructor.
101   // Blocks until the scheduler is unbound from all threads before returning.
102   MARL_EXPORT
103   ~Scheduler();
104 
105   // get() returns the scheduler bound to the current thread.
106   MARL_EXPORT
107   static Scheduler* get();
108 
109   // bind() binds this scheduler to the current thread.
110   // There must be no existing scheduler bound to the thread prior to calling.
111   MARL_EXPORT
112   void bind();
113 
114   // unbind() unbinds the scheduler currently bound to the current thread.
115   // There must be an existing scheduler bound to the thread prior to calling.
116   // unbind() flushes any enqueued tasks on the single-threaded worker before
117   // returning.
118   MARL_EXPORT
119   static void unbind();
120 
121   // enqueue() queues the task for asynchronous execution.
122   MARL_EXPORT
123   void enqueue(Task&& task);
124 
125   // config() returns the Config that was used to build the scheduler.
126   MARL_EXPORT
127   const Config& config() const;
128 
129   // Fibers expose methods to perform cooperative multitasking and are
130   // automatically created by the Scheduler.
131   //
132   // The currently executing Fiber can be obtained by calling Fiber::current().
133   //
134   // When execution becomes blocked, yield() can be called to suspend execution
135   // of the fiber and start executing other pending work. Once the block has
136   // been lifted, schedule() can be called to reschedule the Fiber on the same
137   // thread that previously executed it.
138   class Fiber {
139    public:
140     // current() returns the currently executing fiber, or nullptr if called
141     // without a bound scheduler.
142     MARL_EXPORT
143     static Fiber* current();
144 
145     // wait() suspends execution of this Fiber until the Fiber is woken up with
146     // a call to notify() and the predicate pred returns true.
147     // If the predicate pred does not return true when notify() is called, then
148     // the Fiber is automatically re-suspended, and will need to be woken with
149     // another call to notify().
150     // While the Fiber is suspended, the scheduler thread may continue executing
151     // other tasks.
152     // lock must be locked before calling, and is unlocked by wait() just before
153     // the Fiber is suspended, and re-locked before the fiber is resumed. lock
154     // will be locked before wait() returns.
155     // pred will be always be called with the lock held.
156     // wait() must only be called on the currently executing fiber.
157     MARL_EXPORT
158     void wait(marl::lock& lock, const Predicate& pred);
159 
160     // wait() suspends execution of this Fiber until the Fiber is woken up with
161     // a call to notify() and the predicate pred returns true, or sometime after
162     // the timeout is reached.
163     // If the predicate pred does not return true when notify() is called, then
164     // the Fiber is automatically re-suspended, and will need to be woken with
165     // another call to notify() or will be woken sometime after the timeout is
166     // reached.
167     // While the Fiber is suspended, the scheduler thread may continue executing
168     // other tasks.
169     // lock must be locked before calling, and is unlocked by wait() just before
170     // the Fiber is suspended, and re-locked before the fiber is resumed. lock
171     // will be locked before wait() returns.
172     // pred will be always be called with the lock held.
173     // wait() must only be called on the currently executing fiber.
174     template <typename Clock, typename Duration>
175     MARL_NO_EXPORT inline bool wait(
176         marl::lock& lock,
177         const std::chrono::time_point<Clock, Duration>& timeout,
178         const Predicate& pred);
179 
180     // wait() suspends execution of this Fiber until the Fiber is woken up with
181     // a call to notify().
182     // While the Fiber is suspended, the scheduler thread may continue executing
183     // other tasks.
184     // wait() must only be called on the currently executing fiber.
185     //
186     // Warning: Unlike wait() overloads that take a lock and predicate, this
187     // form of wait() offers no safety for notify() signals that occur before
188     // the fiber is suspended, when signalling between different threads. In
189     // this scenario you may deadlock. For this reason, it is only ever
190     // recommended to use this overload if you can guarantee that the calls to
191     // wait() and notify() are made by the same thread.
192     //
193     // Use with extreme caution.
194     MARL_NO_EXPORT inline void wait();
195 
196     // wait() suspends execution of this Fiber until the Fiber is woken up with
197     // a call to notify(), or sometime after the timeout is reached.
198     // While the Fiber is suspended, the scheduler thread may continue executing
199     // other tasks.
200     // wait() must only be called on the currently executing fiber.
201     //
202     // Warning: Unlike wait() overloads that take a lock and predicate, this
203     // form of wait() offers no safety for notify() signals that occur before
204     // the fiber is suspended, when signalling between different threads. For
205     // this reason, it is only ever recommended to use this overload if you can
206     // guarantee that the calls to wait() and notify() are made by the same
207     // thread.
208     //
209     // Use with extreme caution.
210     template <typename Clock, typename Duration>
211     MARL_NO_EXPORT inline bool wait(
212         const std::chrono::time_point<Clock, Duration>& timeout);
213 
214     // notify() reschedules the suspended Fiber for execution.
215     // notify() is usually only called when the predicate for one or more wait()
216     // calls will likely return true.
217     MARL_EXPORT
218     void notify();
219 
220     // id is the thread-unique identifier of the Fiber.
221     uint32_t const id;
222 
223    private:
224     friend class Allocator;
225     friend class Scheduler;
226 
227     enum class State {
228       // Idle: the Fiber is currently unused, and sits in Worker::idleFibers,
229       // ready to be recycled.
230       Idle,
231 
232       // Yielded: the Fiber is currently blocked on a wait() call with no
233       // timeout.
234       Yielded,
235 
236       // Waiting: the Fiber is currently blocked on a wait() call with a
237       // timeout. The fiber is stilling in the Worker::Work::waiting queue.
238       Waiting,
239 
240       // Queued: the Fiber is currently queued for execution in the
241       // Worker::Work::fibers queue.
242       Queued,
243 
244       // Running: the Fiber is currently executing.
245       Running,
246     };
247 
248     Fiber(Allocator::unique_ptr<OSFiber>&&, uint32_t id);
249 
250     // switchTo() switches execution to the given fiber.
251     // switchTo() must only be called on the currently executing fiber.
252     void switchTo(Fiber*);
253 
254     // create() constructs and returns a new fiber with the given identifier,
255     // stack size and func that will be executed when switched to.
256     static Allocator::unique_ptr<Fiber> create(
257         Allocator* allocator,
258         uint32_t id,
259         size_t stackSize,
260         const std::function<void()>& func);
261 
262     // createFromCurrentThread() constructs and returns a new fiber with the
263     // given identifier for the current thread.
264     static Allocator::unique_ptr<Fiber> createFromCurrentThread(
265         Allocator* allocator,
266         uint32_t id);
267 
268     // toString() returns a string representation of the given State.
269     // Used for debugging.
270     static const char* toString(State state);
271 
272     Allocator::unique_ptr<OSFiber> const impl;
273     Worker* const worker;
274     State state = State::Running;  // Guarded by Worker's work.mutex.
275   };
276 
277  private:
278   Scheduler(const Scheduler&) = delete;
279   Scheduler(Scheduler&&) = delete;
280   Scheduler& operator=(const Scheduler&) = delete;
281   Scheduler& operator=(Scheduler&&) = delete;
282 
283   // Maximum number of worker threads.
284   static constexpr size_t MaxWorkerThreads = 256;
285 
286   // WaitingFibers holds all the fibers waiting on a timeout.
287   struct WaitingFibers {
288     inline WaitingFibers(Allocator*);
289 
290     // operator bool() returns true iff there are any wait fibers.
291     inline operator bool() const;
292 
293     // take() returns the next fiber that has exceeded its timeout, or nullptr
294     // if there are no fibers that have yet exceeded their timeouts.
295     inline Fiber* take(const TimePoint& timeout);
296 
297     // next() returns the timepoint of the next fiber to timeout.
298     // next() can only be called if operator bool() returns true.
299     inline TimePoint next() const;
300 
301     // add() adds another fiber and timeout to the list of waiting fibers.
302     inline void add(const TimePoint& timeout, Fiber* fiber);
303 
304     // erase() removes the fiber from the waiting list.
305     inline void erase(Fiber* fiber);
306 
307     // contains() returns true if fiber is waiting.
308     inline bool contains(Fiber* fiber) const;
309 
310    private:
311     struct Timeout {
312       TimePoint timepoint;
313       Fiber* fiber;
314       inline bool operator<(const Timeout&) const;
315     };
316     containers::set<Timeout, std::less<Timeout>> timeouts;
317     containers::unordered_map<Fiber*, TimePoint> fibers;
318   };
319 
320   // TODO: Implement a queue that recycles elements to reduce number of
321   // heap allocations.
322   using TaskQueue = containers::deque<Task>;
323   using FiberQueue = containers::deque<Fiber*>;
324   using FiberSet = containers::unordered_set<Fiber*>;
325 
326   // Workers execute Tasks on a single thread.
327   // Once a task is started, it may yield to other tasks on the same Worker.
328   // Tasks are always resumed by the same Worker.
329   class Worker {
330    public:
331     enum class Mode {
332       // Worker will spawn a background thread to process tasks.
333       MultiThreaded,
334 
335       // Worker will execute tasks whenever it yields.
336       SingleThreaded,
337     };
338 
339     Worker(Scheduler* scheduler, Mode mode, uint32_t id);
340 
341     // start() begins execution of the worker.
342     void start() EXCLUDES(work.mutex);
343 
344     // stop() ceases execution of the worker, blocking until all pending
345     // tasks have fully finished.
346     void stop() EXCLUDES(work.mutex);
347 
348     // wait() suspends execution of the current task until the predicate pred
349     // returns true or the optional timeout is reached.
350     // See Fiber::wait() for more information.
351     MARL_EXPORT
352     bool wait(marl::lock& lock, const TimePoint* timeout, const Predicate& pred)
353         EXCLUDES(work.mutex);
354 
355     // wait() suspends execution of the current task until the fiber is
356     // notified, or the optional timeout is reached.
357     // See Fiber::wait() for more information.
358     MARL_EXPORT
359     bool wait(const TimePoint* timeout) EXCLUDES(work.mutex);
360 
361     // suspend() suspends the currently executing Fiber until the fiber is
362     // woken with a call to enqueue(Fiber*), or automatically sometime after the
363     // optional timeout.
364     void suspend(const TimePoint* timeout) REQUIRES(work.mutex);
365 
366     // enqueue(Fiber*) enqueues resuming of a suspended fiber.
367     void enqueue(Fiber* fiber) EXCLUDES(work.mutex);
368 
369     // enqueue(Task&&) enqueues a new, unstarted task.
370     void enqueue(Task&& task) EXCLUDES(work.mutex);
371 
372     // tryLock() attempts to lock the worker for task enqueuing.
373     // If the lock was successful then true is returned, and the caller must
374     // call enqueueAndUnlock().
375     bool tryLock() EXCLUDES(work.mutex) TRY_ACQUIRE(true, work.mutex);
376 
377     // enqueueAndUnlock() enqueues the task and unlocks the worker.
378     // Must only be called after a call to tryLock() which returned true.
379     // _Releases_lock_(work.mutex)
380     void enqueueAndUnlock(Task&& task) REQUIRES(work.mutex) RELEASE(work.mutex);
381 
382     // runUntilShutdown() processes all tasks and fibers until there are no more
383     // and shutdown is true, upon runUntilShutdown() returns.
384     void runUntilShutdown() REQUIRES(work.mutex);
385 
386     // steal() attempts to steal a Task from the worker for another worker.
387     // Returns true if a task was taken and assigned to out, otherwise false.
388     bool steal(Task& out) EXCLUDES(work.mutex);
389 
390     // getCurrent() returns the Worker currently bound to the current
391     // thread.
392     static inline Worker* getCurrent();
393 
394     // getCurrentFiber() returns the Fiber currently being executed.
395     inline Fiber* getCurrentFiber() const;
396 
397     // Unique identifier of the Worker.
398     const uint32_t id;
399 
400    private:
401     // run() is the task processing function for the worker.
402     // run() processes tasks until stop() is called.
403     void run() REQUIRES(work.mutex);
404 
405     // createWorkerFiber() creates a new fiber that when executed calls
406     // run().
407     Fiber* createWorkerFiber() REQUIRES(work.mutex);
408 
409     // switchToFiber() switches execution to the given fiber. The fiber
410     // must belong to this worker.
411     void switchToFiber(Fiber*) REQUIRES(work.mutex);
412 
413     // runUntilIdle() executes all pending tasks and then returns.
414     void runUntilIdle() REQUIRES(work.mutex);
415 
416     // waitForWork() blocks until new work is available, potentially calling
417     // spinForWork().
418     void waitForWork() REQUIRES(work.mutex);
419 
420     // spinForWork() attempts to steal work from another Worker, and keeps
421     // the thread awake for a short duration. This reduces overheads of
422     // frequently putting the thread to sleep and re-waking.
423     void spinForWork();
424 
425     // enqueueFiberTimeouts() enqueues all the fibers that have finished
426     // waiting.
427     void enqueueFiberTimeouts() REQUIRES(work.mutex);
428 
429     inline void changeFiberState(Fiber* fiber,
430                                  Fiber::State from,
431                                  Fiber::State to) const REQUIRES(work.mutex);
432 
433     inline void setFiberState(Fiber* fiber, Fiber::State to) const
434         REQUIRES(work.mutex);
435 
436     // Work holds tasks and fibers that are enqueued on the Worker.
437     struct Work {
438       inline Work(Allocator*);
439 
440       std::atomic<uint64_t> num = {0};  // tasks.size() + fibers.size()
441       GUARDED_BY(mutex) uint64_t numBlockedFibers = 0;
442       GUARDED_BY(mutex) TaskQueue tasks;
443       GUARDED_BY(mutex) FiberQueue fibers;
444       GUARDED_BY(mutex) WaitingFibers waiting;
445       GUARDED_BY(mutex) bool notifyAdded = true;
446       std::condition_variable added;
447       marl::mutex mutex;
448 
449       template <typename F>
450       inline void wait(F&&) REQUIRES(mutex);
451     };
452 
453     // https://en.wikipedia.org/wiki/Xorshift
454     class FastRnd {
455      public:
operator()456       inline uint64_t operator()() {
457         x ^= x << 13;
458         x ^= x >> 7;
459         x ^= x << 17;
460         return x;
461       }
462 
463      private:
464       uint64_t x = std::chrono::system_clock::now().time_since_epoch().count();
465     };
466 
467     // The current worker bound to the current thread.
468     MARL_DECLARE_THREAD_LOCAL(Worker*, current);
469 
470     Mode const mode;
471     Scheduler* const scheduler;
472     Allocator::unique_ptr<Fiber> mainFiber;
473     Fiber* currentFiber = nullptr;
474     Thread thread;
475     Work work;
476     FiberSet idleFibers;  // Fibers that have completed which can be reused.
477     containers::vector<Allocator::unique_ptr<Fiber>, 16>
478         workerFibers;  // All fibers created by this worker.
479     FastRnd rng;
480     bool shutdown = false;
481   };
482 
483   // stealWork() attempts to steal a task from the worker with the given id.
484   // Returns true if a task was stolen and assigned to out, otherwise false.
485   bool stealWork(Worker* thief, uint64_t from, Task& out);
486 
487   // onBeginSpinning() is called when a Worker calls spinForWork().
488   // The scheduler will prioritize this worker for new tasks to try to prevent
489   // it going to sleep.
490   void onBeginSpinning(int workerId);
491 
492   // setBound() sets the scheduler bound to the current thread.
493   static void setBound(Scheduler* scheduler);
494 
495   // The scheduler currently bound to the current thread.
496   MARL_DECLARE_THREAD_LOCAL(Scheduler*, bound);
497 
498   // The immutable configuration used to build the scheduler.
499   const Config cfg;
500 
501   std::array<std::atomic<int>, 8> spinningWorkers;
502   std::atomic<unsigned int> nextSpinningWorkerIdx = {0x8000000};
503 
504   std::atomic<unsigned int> nextEnqueueIndex = {0};
505   std::array<Worker*, MaxWorkerThreads> workerThreads;
506 
507   struct SingleThreadedWorkers {
508     inline SingleThreadedWorkers(Allocator*);
509 
510     using WorkerByTid =
511         containers::unordered_map<std::thread::id,
512                                   Allocator::unique_ptr<Worker>>;
513     marl::mutex mutex;
514     GUARDED_BY(mutex) std::condition_variable unbind;
515     GUARDED_BY(mutex) WorkerByTid byTid;
516   };
517   SingleThreadedWorkers singleThreadedWorkers;
518 };
519 
520 ////////////////////////////////////////////////////////////////////////////////
521 // Scheduler::Config
522 ////////////////////////////////////////////////////////////////////////////////
setAllocator(Allocator * alloc)523 Scheduler::Config& Scheduler::Config::setAllocator(Allocator* alloc) {
524   allocator = alloc;
525   return *this;
526 }
527 
setFiberStackSize(size_t size)528 Scheduler::Config& Scheduler::Config::setFiberStackSize(size_t size) {
529   fiberStackSize = size;
530   return *this;
531 }
532 
setWorkerThreadCount(int count)533 Scheduler::Config& Scheduler::Config::setWorkerThreadCount(int count) {
534   workerThread.count = count;
535   return *this;
536 }
537 
setWorkerThreadInitializer(const ThreadInitializer & initializer)538 Scheduler::Config& Scheduler::Config::setWorkerThreadInitializer(
539     const ThreadInitializer& initializer) {
540   workerThread.initializer = initializer;
541   return *this;
542 }
543 
setWorkerThreadAffinityPolicy(const std::shared_ptr<Thread::Affinity::Policy> & policy)544 Scheduler::Config& Scheduler::Config::setWorkerThreadAffinityPolicy(
545     const std::shared_ptr<Thread::Affinity::Policy>& policy) {
546   workerThread.affinityPolicy = policy;
547   return *this;
548 }
549 
550 ////////////////////////////////////////////////////////////////////////////////
551 // Scheduler::Fiber
552 ////////////////////////////////////////////////////////////////////////////////
553 template <typename Clock, typename Duration>
wait(marl::lock & lock,const std::chrono::time_point<Clock,Duration> & timeout,const Predicate & pred)554 bool Scheduler::Fiber::wait(
555     marl::lock& lock,
556     const std::chrono::time_point<Clock, Duration>& timeout,
557     const Predicate& pred) {
558   using ToDuration = typename TimePoint::duration;
559   using ToClock = typename TimePoint::clock;
560   auto tp = std::chrono::time_point_cast<ToDuration, ToClock>(timeout);
561   return worker->wait(lock, &tp, pred);
562 }
563 
wait()564 void Scheduler::Fiber::wait() {
565   worker->wait(nullptr);
566 }
567 
568 template <typename Clock, typename Duration>
wait(const std::chrono::time_point<Clock,Duration> & timeout)569 bool Scheduler::Fiber::wait(
570     const std::chrono::time_point<Clock, Duration>& timeout) {
571   using ToDuration = typename TimePoint::duration;
572   using ToClock = typename TimePoint::clock;
573   auto tp = std::chrono::time_point_cast<ToDuration, ToClock>(timeout);
574   return worker->wait(&tp);
575 }
576 
getCurrent()577 Scheduler::Worker* Scheduler::Worker::getCurrent() {
578   return Worker::current;
579 }
580 
getCurrentFiber()581 Scheduler::Fiber* Scheduler::Worker::getCurrentFiber() const {
582   return currentFiber;
583 }
584 
585 // schedule() schedules the task T to be asynchronously called using the
586 // currently bound scheduler.
schedule(Task && t)587 inline void schedule(Task&& t) {
588   MARL_ASSERT_HAS_BOUND_SCHEDULER("marl::schedule");
589   auto scheduler = Scheduler::get();
590   scheduler->enqueue(std::move(t));
591 }
592 
593 // schedule() schedules the function f to be asynchronously called with the
594 // given arguments using the currently bound scheduler.
595 template <typename Function, typename... Args>
schedule(Function && f,Args &&...args)596 inline void schedule(Function&& f, Args&&... args) {
597   MARL_ASSERT_HAS_BOUND_SCHEDULER("marl::schedule");
598   auto scheduler = Scheduler::get();
599   scheduler->enqueue(
600       Task(std::bind(std::forward<Function>(f), std::forward<Args>(args)...)));
601 }
602 
603 // schedule() schedules the function f to be asynchronously called using the
604 // currently bound scheduler.
605 template <typename Function>
schedule(Function && f)606 inline void schedule(Function&& f) {
607   MARL_ASSERT_HAS_BOUND_SCHEDULER("marl::schedule");
608   auto scheduler = Scheduler::get();
609   scheduler->enqueue(Task(std::forward<Function>(f)));
610 }
611 
612 }  // namespace marl
613 
614 #endif  // marl_scheduler_h
615