• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2019 The Marl Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef marl_scheduler_h
16 #define marl_scheduler_h
17 
18 #include "debug.h"
19 #include "memory.h"
20 #include "sal.h"
21 
#include <array>
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <map>
#include <mutex>
#include <queue>
#include <set>
#include <thread>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
34 
35 namespace marl {
36 
// Opaque forward declaration; presumably the platform-specific fiber
// implementation, defined elsewhere.
class OSFiber;

// Task is the unit of work handed to the scheduler: a type-erased callable that
// takes no arguments and returns nothing.
using Task = std::function<void(void)>;
41 
42 // Scheduler asynchronously processes Tasks.
43 // A scheduler can be bound to one or more threads using the bind() method.
44 // Once bound to a thread, that thread can call marl::schedule() to enqueue
45 // work tasks to be executed asynchronously.
46 // All threads must be unbound with unbind() before the scheduler is destructed.
47 // Scheduler are initially constructed in single-threaded mode.
48 // Call setWorkerThreadCount() to spawn dedicated worker threads.
49 class Scheduler {
50   class Worker;
51 
52  public:
53   using TimePoint = std::chrono::system_clock::time_point;
54   using Predicate = std::function<bool()>;
55 
56   Scheduler(Allocator* allocator = Allocator::Default);
57 
58   // Destructor.
59   // Ensure that all threads are unbound before calling - failure to do so may
60   // result in leaked memory.
61   ~Scheduler();
62 
63   // get() returns the scheduler bound to the current thread.
64   static Scheduler* get();
65 
66   // bind() binds this scheduler to the current thread.
67   // There must be no existing scheduler bound to the thread prior to calling.
68   void bind();
69 
70   // unbind() unbinds the scheduler currently bound to the current thread.
71   // There must be a existing scheduler bound to the thread prior to calling.
72   // unbind() flushes any enqueued tasks on the single-threaded worker before
73   // returning.
74   static void unbind();
75 
76   // enqueue() queues the task for asynchronous execution.
77   void enqueue(Task&& task);
78 
79   // setThreadInitializer() sets the worker thread initializer function which
80   // will be called for each new worker thread spawned.
81   // The initializer will only be called on newly created threads (call
82   // setThreadInitializer() before setWorkerThreadCount()).
83   void setThreadInitializer(const std::function<void()>& init);
84 
85   // getThreadInitializer() returns the thread initializer function set by
86   // setThreadInitializer().
87   const std::function<void()>& getThreadInitializer();
88 
89   // setWorkerThreadCount() adjusts the number of dedicated worker threads.
90   // A count of 0 puts the scheduler into single-threaded mode.
91   // Note: Currently the number of threads cannot be adjusted once tasks
92   // have been enqueued. This restriction may be lifted at a later time.
93   void setWorkerThreadCount(int count);
94 
95   // getWorkerThreadCount() returns the number of worker threads.
96   int getWorkerThreadCount();
97 
98   // Fibers expose methods to perform cooperative multitasking and are
99   // automatically created by the Scheduler.
100   //
101   // The currently executing Fiber can be obtained by calling Fiber::current().
102   //
103   // When execution becomes blocked, yield() can be called to suspend execution
104   // of the fiber and start executing other pending work. Once the block has
105   // been lifted, schedule() can be called to reschedule the Fiber on the same
106   // thread that previously executed it.
107   class Fiber {
108    public:
109     using Lock = std::unique_lock<std::mutex>;
110 
111     // current() returns the currently executing fiber, or nullptr if called
112     // without a bound scheduler.
113     static Fiber* current();
114 
115     // wait() suspends execution of this Fiber until the Fiber is woken up with
116     // a call to notify() and the predicate pred returns true.
117     // If the predicate pred does not return true when notify() is called, then
118     // the Fiber is automatically re-suspended, and will need to be woken with
119     // another call to notify().
120     // While the Fiber is suspended, the scheduler thread may continue executing
121     // other tasks.
122     // lock must be locked before calling, and is unlocked by wait() just before
123     // the Fiber is suspended, and re-locked before the fiber is resumed. lock
124     // will be locked before wait() returns.
125     // pred will be always be called with the lock held.
126     // wait() must only be called on the currently executing fiber.
127     _Requires_lock_held_(lock)
128     void wait(Lock& lock, const Predicate& pred);
129 
130     // wait() suspends execution of this Fiber until the Fiber is woken up with
131     // a call to notify() and the predicate pred returns true, or sometime after
132     // the timeout is reached.
133     // If the predicate pred does not return true when notify() is called, then
134     // the Fiber is automatically re-suspended, and will need to be woken with
135     // another call to notify() or will be woken sometime after the timeout is
136     // reached.
137     // While the Fiber is suspended, the scheduler thread may continue executing
138     // other tasks.
139     // lock must be locked before calling, and is unlocked by wait() just before
140     // the Fiber is suspended, and re-locked before the fiber is resumed. lock
141     // will be locked before wait() returns.
142     // pred will be always be called with the lock held.
143     // wait() must only be called on the currently executing fiber.
144     _Requires_lock_held_(lock)
145     template <typename Clock, typename Duration>
146     inline bool wait(Lock& lock,
147                      const std::chrono::time_point<Clock, Duration>& timeout,
148                      const Predicate& pred);
149 
150     // notify() reschedules the suspended Fiber for execution.
151     // notify() is usually only called when the predicate for one or more wait()
152     // calls will likely return true.
153     void notify();
154 
155     // id is the thread-unique identifier of the Fiber.
156     uint32_t const id;
157 
158    private:
159     friend class Allocator;
160     friend class Scheduler;
161 
162     enum class State {
163       // Idle: the Fiber is currently unused, and sits in Worker::idleFibers,
164       // ready to be recycled.
165       Idle,
166 
167       // Yielded: the Fiber is currently blocked on a wait() call with no
168       // timeout.
169       Yielded,
170 
171       // Waiting: the Fiber is currently blocked on a wait() call with a
172       // timeout. The fiber is stilling in the Worker::Work::waiting queue.
173       Waiting,
174 
175       // Queued: the Fiber is currently queued for execution in the
176       // Worker::Work::fibers queue.
177       Queued,
178 
179       // Running: the Fiber is currently executing.
180       Running,
181     };
182 
183     Fiber(Allocator::unique_ptr<OSFiber>&&, uint32_t id);
184 
185     // switchTo() switches execution to the given fiber.
186     // switchTo() must only be called on the currently executing fiber.
187     void switchTo(Fiber*);
188 
189     // create() constructs and returns a new fiber with the given identifier,
190     // stack size that will executed func when switched to.
191     static Allocator::unique_ptr<Fiber> create(
192         Allocator* allocator,
193         uint32_t id,
194         size_t stackSize,
195         const std::function<void()>& func);
196 
197     // createFromCurrentThread() constructs and returns a new fiber with the
198     // given identifier for the current thread.
199     static Allocator::unique_ptr<Fiber> createFromCurrentThread(
200         Allocator* allocator,
201         uint32_t id);
202 
203     // toString() returns a string representation of the given State.
204     // Used for debugging.
205     static const char* toString(State state);
206 
207     Allocator::unique_ptr<OSFiber> const impl;
208     Worker* const worker;
209     State state = State::Running;  // Guarded by Worker's work.mutex.
210   };
211 
212  private:
213   Scheduler(const Scheduler&) = delete;
214   Scheduler(Scheduler&&) = delete;
215   Scheduler& operator=(const Scheduler&) = delete;
216   Scheduler& operator=(Scheduler&&) = delete;
217 
218   // Stack size in bytes of a new fiber.
219   // TODO: Make configurable so the default size can be reduced.
220   static constexpr size_t FiberStackSize = 1024 * 1024;
221 
222   // Maximum number of worker threads.
223   static constexpr size_t MaxWorkerThreads = 256;
224 
225   // WaitingFibers holds all the fibers waiting on a timeout.
226   struct WaitingFibers {
227     // operator bool() returns true iff there are any wait fibers.
228     inline operator bool() const;
229 
230     // take() returns the next fiber that has exceeded its timeout, or nullptr
231     // if there are no fibers that have yet exceeded their timeouts.
232     inline Fiber* take(const TimePoint& timepoint);
233 
234     // next() returns the timepoint of the next fiber to timeout.
235     // next() can only be called if operator bool() returns true.
236     inline TimePoint next() const;
237 
238     // add() adds another fiber and timeout to the list of waiting fibers.
239     inline void add(const TimePoint& timeout, Fiber* fiber);
240 
241     // erase() removes the fiber from the waiting list.
242     inline void erase(Fiber* fiber);
243 
244     // contains() returns true if fiber is waiting.
245     inline bool contains(Fiber* fiber) const;
246 
247    private:
248     struct Timeout {
249       TimePoint timepoint;
250       Fiber* fiber;
251       inline bool operator<(const Timeout&) const;
252     };
253     std::set<Timeout> timeouts;
254     std::unordered_map<Fiber*, TimePoint> fibers;
255   };
256 
257   // TODO: Implement a queue that recycles elements to reduce number of
258   // heap allocations.
259   using TaskQueue = std::queue<Task>;
260   using FiberQueue = std::queue<Fiber*>;
261   using FiberSet = std::unordered_set<Fiber*>;
262 
263   // Workers executes Tasks on a single thread.
264   // Once a task is started, it may yield to other tasks on the same Worker.
265   // Tasks are always resumed by the same Worker.
266   class Worker {
267    public:
268     enum class Mode {
269       // Worker will spawn a background thread to process tasks.
270       MultiThreaded,
271 
272       // Worker will execute tasks whenever it yields.
273       SingleThreaded,
274     };
275 
276     Worker(Scheduler* scheduler, Mode mode, uint32_t id);
277 
278     // start() begins execution of the worker.
279     void start();
280 
281     // stop() ceases execution of the worker, blocking until all pending
282     // tasks have fully finished.
283     void stop();
284 
285     // wait() suspends execution of the current task until the predicate pred
286     // returns true.
287     // See Fiber::wait() for more information.
288     _Requires_lock_held_(lock)
289     bool wait(Fiber::Lock& lock,
290               const TimePoint* timeout,
291               const Predicate& pred);
292 
293     // suspend() suspends the currenetly executing Fiber until the fiber is
294     // woken with a call to enqueue(Fiber*), or automatically sometime after the
295     // optional timeout.
296     _Requires_lock_held_(work.mutex)
297     void suspend(const TimePoint* timeout);
298 
299     // enqueue(Fiber*) enqueues resuming of a suspended fiber.
300     void enqueue(Fiber* fiber);
301 
302     // enqueue(Task&&) enqueues a new, unstarted task.
303     void enqueue(Task&& task);
304 
305     // tryLock() attempts to lock the worker for task enqueing.
306     // If the lock was successful then true is returned, and the caller must
307     // call enqueueAndUnlock().
308     _When_(return == true, _Acquires_lock_(work.mutex))
309     bool tryLock();
310 
311     // enqueueAndUnlock() enqueues the task and unlocks the worker.
312     // Must only be called after a call to tryLock() which returned true.
313     _Requires_lock_held_(work.mutex)
314     _Releases_lock_(work.mutex)
315     void enqueueAndUnlock(Task&& task);
316 
317     // flush() processes all pending tasks before returning.
318     void flush();
319 
320     // dequeue() attempts to take a Task from the worker. Returns true if
321     // a task was taken and assigned to out, otherwise false.
322     bool dequeue(Task& out);
323 
324     // getCurrent() returns the Worker currently bound to the current
325     // thread.
326     static inline Worker* getCurrent();
327 
328     // getCurrentFiber() returns the Fiber currently being executed.
329     inline Fiber* getCurrentFiber() const;
330 
331     // Unique identifier of the Worker.
332     const uint32_t id;
333 
334    private:
335     // run() is the task processing function for the worker.
336     // If the worker was constructed in Mode::MultiThreaded, run() will
337     // continue to process tasks until stop() is called.
338     // If the worker was constructed in Mode::SingleThreaded, run() call
339     // flush() and return.
340     void run();
341 
342     // createWorkerFiber() creates a new fiber that when executed calls
343     // run().
344     Fiber* createWorkerFiber();
345 
346     // switchToFiber() switches execution to the given fiber. The fiber
347     // must belong to this worker.
348     void switchToFiber(Fiber*);
349 
350     // runUntilIdle() executes all pending tasks and then returns.
351     _Requires_lock_held_(work.mutex)
352     void runUntilIdle();
353 
354     // waitForWork() blocks until new work is available, potentially calling
355     // spinForWork().
356     _Requires_lock_held_(work.mutex)
357     void waitForWork();
358 
359     // spinForWork() attempts to steal work from another Worker, and keeps
360     // the thread awake for a short duration. This reduces overheads of
361     // frequently putting the thread to sleep and re-waking.
362     void spinForWork();
363 
364     // enqueueFiberTimeouts() enqueues all the fibers that have finished
365     // waiting.
366     _Requires_lock_held_(work.mutex)
367     void enqueueFiberTimeouts();
368 
369     _Requires_lock_held_(work.mutex)
370     inline void changeFiberState(Fiber* fiber,
371                                  Fiber::State from,
372                                  Fiber::State to) const;
373 
374     _Requires_lock_held_(work.mutex)
375     inline void setFiberState(Fiber* fiber, Fiber::State to) const;
376 
377     // numBlockedFibers() returns the number of fibers currently blocked and
378     // held externally.
numBlockedFibers()379     inline size_t numBlockedFibers() const {
380       return workerFibers.size() - idleFibers.size();
381     }
382 
383     // Work holds tasks and fibers that are enqueued on the Worker.
384     struct Work {
385       std::atomic<uint64_t> num = {0};  // tasks.size() + fibers.size()
386       _Guarded_by_(mutex) TaskQueue tasks;
387       _Guarded_by_(mutex) FiberQueue fibers;
388       _Guarded_by_(mutex) WaitingFibers waiting;
389       std::condition_variable added;
390       std::mutex mutex;
391     };
392 
393     // https://en.wikipedia.org/wiki/Xorshift
394     class FastRnd {
395      public:
operator()396       inline uint64_t operator()() {
397         x ^= x << 13;
398         x ^= x >> 7;
399         x ^= x << 17;
400         return x;
401       }
402 
403      private:
404       uint64_t x = std::chrono::system_clock::now().time_since_epoch().count();
405     };
406 
407     // The current worker bound to the current thread.
408     static thread_local Worker* current;
409 
410     Mode const mode;
411     Scheduler* const scheduler;
412     Allocator::unique_ptr<Fiber> mainFiber;
413     Fiber* currentFiber = nullptr;
414     std::thread thread;
415     Work work;
416     FiberSet idleFibers;  // Fibers that have completed which can be reused.
417     std::vector<Allocator::unique_ptr<Fiber>>
418         workerFibers;  // All fibers created by this worker.
419     FastRnd rng;
420     std::atomic<bool> shutdown = {false};
421   };
422 
423   // stealWork() attempts to steal a task from the worker with the given id.
424   // Returns true if a task was stolen and assigned to out, otherwise false.
425   bool stealWork(Worker* thief, uint64_t from, Task& out);
426 
427   // onBeginSpinning() is called when a Worker calls spinForWork().
428   // The scheduler will prioritize this worker for new tasks to try to prevent
429   // it going to sleep.
430   void onBeginSpinning(int workerId);
431 
432   // The scheduler currently bound to the current thread.
433   static thread_local Scheduler* bound;
434 
435   Allocator* const allocator;
436 
437   std::function<void()> threadInitFunc;
438   std::mutex threadInitFuncMutex;
439 
440   std::array<std::atomic<int>, 8> spinningWorkers;
441   std::atomic<unsigned int> nextSpinningWorkerIdx = {0x8000000};
442 
443   // TODO: Make this lot thread-safe so setWorkerThreadCount() can be called
444   // during execution of tasks.
445   std::atomic<unsigned int> nextEnqueueIndex = {0};
446   unsigned int numWorkerThreads = 0;
447   std::array<Worker*, MaxWorkerThreads> workerThreads;
448 
449   std::mutex singleThreadedWorkerMutex;
450   std::unordered_map<std::thread::id, Allocator::unique_ptr<Worker>>
451       singleThreadedWorkers;
452 };
453 
namespace detail {

// convertTimePoint() converts tp into a ToTimePoint when tp already uses
// ToTimePoint's clock: only the duration (precision) needs to be cast.
template <typename ToTimePoint, typename Clock, typename Duration>
inline ToTimePoint convertTimePoint(
    const std::chrono::time_point<Clock, Duration>& tp,
    std::true_type /* same clock */) {
  return std::chrono::time_point_cast<typename ToTimePoint::duration>(tp);
}

// convertTimePoint() converts tp into a ToTimePoint when tp uses a different
// clock. Clocks do not share an epoch, so the time remaining until tp is
// measured on Clock and re-applied relative to ToTimePoint's clock.
template <typename ToTimePoint, typename Clock, typename Duration>
inline ToTimePoint convertTimePoint(
    const std::chrono::time_point<Clock, Duration>& tp,
    std::false_type /* different clocks */) {
  using ToClock = typename ToTimePoint::clock;
  using ToDuration = typename ToTimePoint::duration;
  return std::chrono::time_point_cast<ToDuration>(
      ToClock::now() +
      std::chrono::duration_cast<ToDuration>(tp - Clock::now()));
}

}  // namespace detail

// Timed Fiber::wait(): converts timeout into the scheduler's TimePoint and
// forwards to the owning Worker. Any std::chrono clock is accepted.
// time_point_cast<ToDuration, ToClock>() would pin the *source* clock to
// TimePoint's clock, which only compiles for system_clock.
template <typename Clock, typename Duration>
_Requires_lock_held_(lock)
bool Scheduler::Fiber::wait(
    Lock& lock,
    const std::chrono::time_point<Clock, Duration>& timeout,
    const Predicate& pred) {
  using ToClock = typename TimePoint::clock;
  TimePoint tp = detail::convertTimePoint<TimePoint>(
      timeout, typename std::is_same<Clock, ToClock>::type());
  return worker->wait(lock, &tp, pred);
}
465 
getCurrent()466 Scheduler::Worker* Scheduler::Worker::getCurrent() {
467   return Worker::current;
468 }
469 
getCurrentFiber()470 Scheduler::Fiber* Scheduler::Worker::getCurrentFiber() const {
471   return currentFiber;
472 }
473 
474 // schedule() schedules the function f to be asynchronously called with the
475 // given arguments using the currently bound scheduler.
476 template <typename Function, typename... Args>
schedule(Function && f,Args &&...args)477 inline void schedule(Function&& f, Args&&... args) {
478   MARL_ASSERT_HAS_BOUND_SCHEDULER("marl::schedule");
479   auto scheduler = Scheduler::get();
480   scheduler->enqueue(
481       std::bind(std::forward<Function>(f), std::forward<Args>(args)...));
482 }
483 
484 // schedule() schedules the function f to be asynchronously called using the
485 // currently bound scheduler.
486 template <typename Function>
schedule(Function && f)487 inline void schedule(Function&& f) {
488   MARL_ASSERT_HAS_BOUND_SCHEDULER("marl::schedule");
489   auto scheduler = Scheduler::get();
490   scheduler->enqueue(std::forward<Function>(f));
491 }
492 
493 }  // namespace marl
494 
495 #endif  // marl_scheduler_h
496