1 // Copyright 2019 The Marl Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #ifndef marl_scheduler_h
16 #define marl_scheduler_h
17
18 #include "debug.h"
19 #include "memory.h"
20 #include "sal.h"
21
22 #include <array>
23 #include <atomic>
24 #include <chrono>
25 #include <condition_variable>
26 #include <functional>
27 #include <map>
28 #include <mutex>
29 #include <queue>
30 #include <set>
31 #include <thread>
32 #include <unordered_map>
33 #include <unordered_set>
34
35 namespace marl {
36
// OSFiber is only forward-declared here; within this header it is referred
// to solely through Allocator::unique_ptr<OSFiber>.
class OSFiber;

// Task is a unit of work for the scheduler: a callable that takes no
// arguments and returns nothing.
using Task = std::function<void()>;
41
42 // Scheduler asynchronously processes Tasks.
43 // A scheduler can be bound to one or more threads using the bind() method.
44 // Once bound to a thread, that thread can call marl::schedule() to enqueue
45 // work tasks to be executed asynchronously.
46 // All threads must be unbound with unbind() before the scheduler is destructed.
47 // Scheduler are initially constructed in single-threaded mode.
48 // Call setWorkerThreadCount() to spawn dedicated worker threads.
49 class Scheduler {
50 class Worker;
51
52 public:
53 using TimePoint = std::chrono::system_clock::time_point;
54 using Predicate = std::function<bool()>;
55
56 Scheduler(Allocator* allocator = Allocator::Default);
57
58 // Destructor.
59 // Ensure that all threads are unbound before calling - failure to do so may
60 // result in leaked memory.
61 ~Scheduler();
62
63 // get() returns the scheduler bound to the current thread.
64 static Scheduler* get();
65
66 // bind() binds this scheduler to the current thread.
67 // There must be no existing scheduler bound to the thread prior to calling.
68 void bind();
69
70 // unbind() unbinds the scheduler currently bound to the current thread.
71 // There must be a existing scheduler bound to the thread prior to calling.
72 // unbind() flushes any enqueued tasks on the single-threaded worker before
73 // returning.
74 static void unbind();
75
76 // enqueue() queues the task for asynchronous execution.
77 void enqueue(Task&& task);
78
79 // setThreadInitializer() sets the worker thread initializer function which
80 // will be called for each new worker thread spawned.
81 // The initializer will only be called on newly created threads (call
82 // setThreadInitializer() before setWorkerThreadCount()).
83 void setThreadInitializer(const std::function<void()>& init);
84
85 // getThreadInitializer() returns the thread initializer function set by
86 // setThreadInitializer().
87 const std::function<void()>& getThreadInitializer();
88
89 // setWorkerThreadCount() adjusts the number of dedicated worker threads.
90 // A count of 0 puts the scheduler into single-threaded mode.
91 // Note: Currently the number of threads cannot be adjusted once tasks
92 // have been enqueued. This restriction may be lifted at a later time.
93 void setWorkerThreadCount(int count);
94
95 // getWorkerThreadCount() returns the number of worker threads.
96 int getWorkerThreadCount();
97
98 // Fibers expose methods to perform cooperative multitasking and are
99 // automatically created by the Scheduler.
100 //
101 // The currently executing Fiber can be obtained by calling Fiber::current().
102 //
103 // When execution becomes blocked, yield() can be called to suspend execution
104 // of the fiber and start executing other pending work. Once the block has
105 // been lifted, schedule() can be called to reschedule the Fiber on the same
106 // thread that previously executed it.
107 class Fiber {
108 public:
109 using Lock = std::unique_lock<std::mutex>;
110
111 // current() returns the currently executing fiber, or nullptr if called
112 // without a bound scheduler.
113 static Fiber* current();
114
115 // wait() suspends execution of this Fiber until the Fiber is woken up with
116 // a call to notify() and the predicate pred returns true.
117 // If the predicate pred does not return true when notify() is called, then
118 // the Fiber is automatically re-suspended, and will need to be woken with
119 // another call to notify().
120 // While the Fiber is suspended, the scheduler thread may continue executing
121 // other tasks.
122 // lock must be locked before calling, and is unlocked by wait() just before
123 // the Fiber is suspended, and re-locked before the fiber is resumed. lock
124 // will be locked before wait() returns.
125 // pred will be always be called with the lock held.
126 // wait() must only be called on the currently executing fiber.
127 _Requires_lock_held_(lock)
128 void wait(Lock& lock, const Predicate& pred);
129
130 // wait() suspends execution of this Fiber until the Fiber is woken up with
131 // a call to notify() and the predicate pred returns true, or sometime after
132 // the timeout is reached.
133 // If the predicate pred does not return true when notify() is called, then
134 // the Fiber is automatically re-suspended, and will need to be woken with
135 // another call to notify() or will be woken sometime after the timeout is
136 // reached.
137 // While the Fiber is suspended, the scheduler thread may continue executing
138 // other tasks.
139 // lock must be locked before calling, and is unlocked by wait() just before
140 // the Fiber is suspended, and re-locked before the fiber is resumed. lock
141 // will be locked before wait() returns.
142 // pred will be always be called with the lock held.
143 // wait() must only be called on the currently executing fiber.
144 _Requires_lock_held_(lock)
145 template <typename Clock, typename Duration>
146 inline bool wait(Lock& lock,
147 const std::chrono::time_point<Clock, Duration>& timeout,
148 const Predicate& pred);
149
150 // notify() reschedules the suspended Fiber for execution.
151 // notify() is usually only called when the predicate for one or more wait()
152 // calls will likely return true.
153 void notify();
154
155 // id is the thread-unique identifier of the Fiber.
156 uint32_t const id;
157
158 private:
159 friend class Allocator;
160 friend class Scheduler;
161
162 enum class State {
163 // Idle: the Fiber is currently unused, and sits in Worker::idleFibers,
164 // ready to be recycled.
165 Idle,
166
167 // Yielded: the Fiber is currently blocked on a wait() call with no
168 // timeout.
169 Yielded,
170
171 // Waiting: the Fiber is currently blocked on a wait() call with a
172 // timeout. The fiber is stilling in the Worker::Work::waiting queue.
173 Waiting,
174
175 // Queued: the Fiber is currently queued for execution in the
176 // Worker::Work::fibers queue.
177 Queued,
178
179 // Running: the Fiber is currently executing.
180 Running,
181 };
182
183 Fiber(Allocator::unique_ptr<OSFiber>&&, uint32_t id);
184
185 // switchTo() switches execution to the given fiber.
186 // switchTo() must only be called on the currently executing fiber.
187 void switchTo(Fiber*);
188
189 // create() constructs and returns a new fiber with the given identifier,
190 // stack size that will executed func when switched to.
191 static Allocator::unique_ptr<Fiber> create(
192 Allocator* allocator,
193 uint32_t id,
194 size_t stackSize,
195 const std::function<void()>& func);
196
197 // createFromCurrentThread() constructs and returns a new fiber with the
198 // given identifier for the current thread.
199 static Allocator::unique_ptr<Fiber> createFromCurrentThread(
200 Allocator* allocator,
201 uint32_t id);
202
203 // toString() returns a string representation of the given State.
204 // Used for debugging.
205 static const char* toString(State state);
206
207 Allocator::unique_ptr<OSFiber> const impl;
208 Worker* const worker;
209 State state = State::Running; // Guarded by Worker's work.mutex.
210 };
211
212 private:
213 Scheduler(const Scheduler&) = delete;
214 Scheduler(Scheduler&&) = delete;
215 Scheduler& operator=(const Scheduler&) = delete;
216 Scheduler& operator=(Scheduler&&) = delete;
217
218 // Stack size in bytes of a new fiber.
219 // TODO: Make configurable so the default size can be reduced.
220 static constexpr size_t FiberStackSize = 1024 * 1024;
221
222 // Maximum number of worker threads.
223 static constexpr size_t MaxWorkerThreads = 256;
224
225 // WaitingFibers holds all the fibers waiting on a timeout.
226 struct WaitingFibers {
227 // operator bool() returns true iff there are any wait fibers.
228 inline operator bool() const;
229
230 // take() returns the next fiber that has exceeded its timeout, or nullptr
231 // if there are no fibers that have yet exceeded their timeouts.
232 inline Fiber* take(const TimePoint& timepoint);
233
234 // next() returns the timepoint of the next fiber to timeout.
235 // next() can only be called if operator bool() returns true.
236 inline TimePoint next() const;
237
238 // add() adds another fiber and timeout to the list of waiting fibers.
239 inline void add(const TimePoint& timeout, Fiber* fiber);
240
241 // erase() removes the fiber from the waiting list.
242 inline void erase(Fiber* fiber);
243
244 // contains() returns true if fiber is waiting.
245 inline bool contains(Fiber* fiber) const;
246
247 private:
248 struct Timeout {
249 TimePoint timepoint;
250 Fiber* fiber;
251 inline bool operator<(const Timeout&) const;
252 };
253 std::set<Timeout> timeouts;
254 std::unordered_map<Fiber*, TimePoint> fibers;
255 };
256
257 // TODO: Implement a queue that recycles elements to reduce number of
258 // heap allocations.
259 using TaskQueue = std::queue<Task>;
260 using FiberQueue = std::queue<Fiber*>;
261 using FiberSet = std::unordered_set<Fiber*>;
262
263 // Workers executes Tasks on a single thread.
264 // Once a task is started, it may yield to other tasks on the same Worker.
265 // Tasks are always resumed by the same Worker.
266 class Worker {
267 public:
268 enum class Mode {
269 // Worker will spawn a background thread to process tasks.
270 MultiThreaded,
271
272 // Worker will execute tasks whenever it yields.
273 SingleThreaded,
274 };
275
276 Worker(Scheduler* scheduler, Mode mode, uint32_t id);
277
278 // start() begins execution of the worker.
279 void start();
280
281 // stop() ceases execution of the worker, blocking until all pending
282 // tasks have fully finished.
283 void stop();
284
285 // wait() suspends execution of the current task until the predicate pred
286 // returns true.
287 // See Fiber::wait() for more information.
288 _Requires_lock_held_(lock)
289 bool wait(Fiber::Lock& lock,
290 const TimePoint* timeout,
291 const Predicate& pred);
292
293 // suspend() suspends the currenetly executing Fiber until the fiber is
294 // woken with a call to enqueue(Fiber*), or automatically sometime after the
295 // optional timeout.
296 _Requires_lock_held_(work.mutex)
297 void suspend(const TimePoint* timeout);
298
299 // enqueue(Fiber*) enqueues resuming of a suspended fiber.
300 void enqueue(Fiber* fiber);
301
302 // enqueue(Task&&) enqueues a new, unstarted task.
303 void enqueue(Task&& task);
304
305 // tryLock() attempts to lock the worker for task enqueing.
306 // If the lock was successful then true is returned, and the caller must
307 // call enqueueAndUnlock().
308 _When_(return == true, _Acquires_lock_(work.mutex))
309 bool tryLock();
310
311 // enqueueAndUnlock() enqueues the task and unlocks the worker.
312 // Must only be called after a call to tryLock() which returned true.
313 _Requires_lock_held_(work.mutex)
314 _Releases_lock_(work.mutex)
315 void enqueueAndUnlock(Task&& task);
316
317 // flush() processes all pending tasks before returning.
318 void flush();
319
320 // dequeue() attempts to take a Task from the worker. Returns true if
321 // a task was taken and assigned to out, otherwise false.
322 bool dequeue(Task& out);
323
324 // getCurrent() returns the Worker currently bound to the current
325 // thread.
326 static inline Worker* getCurrent();
327
328 // getCurrentFiber() returns the Fiber currently being executed.
329 inline Fiber* getCurrentFiber() const;
330
331 // Unique identifier of the Worker.
332 const uint32_t id;
333
334 private:
335 // run() is the task processing function for the worker.
336 // If the worker was constructed in Mode::MultiThreaded, run() will
337 // continue to process tasks until stop() is called.
338 // If the worker was constructed in Mode::SingleThreaded, run() call
339 // flush() and return.
340 void run();
341
342 // createWorkerFiber() creates a new fiber that when executed calls
343 // run().
344 Fiber* createWorkerFiber();
345
346 // switchToFiber() switches execution to the given fiber. The fiber
347 // must belong to this worker.
348 void switchToFiber(Fiber*);
349
350 // runUntilIdle() executes all pending tasks and then returns.
351 _Requires_lock_held_(work.mutex)
352 void runUntilIdle();
353
354 // waitForWork() blocks until new work is available, potentially calling
355 // spinForWork().
356 _Requires_lock_held_(work.mutex)
357 void waitForWork();
358
359 // spinForWork() attempts to steal work from another Worker, and keeps
360 // the thread awake for a short duration. This reduces overheads of
361 // frequently putting the thread to sleep and re-waking.
362 void spinForWork();
363
364 // enqueueFiberTimeouts() enqueues all the fibers that have finished
365 // waiting.
366 _Requires_lock_held_(work.mutex)
367 void enqueueFiberTimeouts();
368
369 _Requires_lock_held_(work.mutex)
370 inline void changeFiberState(Fiber* fiber,
371 Fiber::State from,
372 Fiber::State to) const;
373
374 _Requires_lock_held_(work.mutex)
375 inline void setFiberState(Fiber* fiber, Fiber::State to) const;
376
377 // numBlockedFibers() returns the number of fibers currently blocked and
378 // held externally.
numBlockedFibers()379 inline size_t numBlockedFibers() const {
380 return workerFibers.size() - idleFibers.size();
381 }
382
383 // Work holds tasks and fibers that are enqueued on the Worker.
384 struct Work {
385 std::atomic<uint64_t> num = {0}; // tasks.size() + fibers.size()
386 _Guarded_by_(mutex) TaskQueue tasks;
387 _Guarded_by_(mutex) FiberQueue fibers;
388 _Guarded_by_(mutex) WaitingFibers waiting;
389 std::condition_variable added;
390 std::mutex mutex;
391 };
392
393 // https://en.wikipedia.org/wiki/Xorshift
394 class FastRnd {
395 public:
operator()396 inline uint64_t operator()() {
397 x ^= x << 13;
398 x ^= x >> 7;
399 x ^= x << 17;
400 return x;
401 }
402
403 private:
404 uint64_t x = std::chrono::system_clock::now().time_since_epoch().count();
405 };
406
407 // The current worker bound to the current thread.
408 static thread_local Worker* current;
409
410 Mode const mode;
411 Scheduler* const scheduler;
412 Allocator::unique_ptr<Fiber> mainFiber;
413 Fiber* currentFiber = nullptr;
414 std::thread thread;
415 Work work;
416 FiberSet idleFibers; // Fibers that have completed which can be reused.
417 std::vector<Allocator::unique_ptr<Fiber>>
418 workerFibers; // All fibers created by this worker.
419 FastRnd rng;
420 std::atomic<bool> shutdown = {false};
421 };
422
423 // stealWork() attempts to steal a task from the worker with the given id.
424 // Returns true if a task was stolen and assigned to out, otherwise false.
425 bool stealWork(Worker* thief, uint64_t from, Task& out);
426
427 // onBeginSpinning() is called when a Worker calls spinForWork().
428 // The scheduler will prioritize this worker for new tasks to try to prevent
429 // it going to sleep.
430 void onBeginSpinning(int workerId);
431
432 // The scheduler currently bound to the current thread.
433 static thread_local Scheduler* bound;
434
435 Allocator* const allocator;
436
437 std::function<void()> threadInitFunc;
438 std::mutex threadInitFuncMutex;
439
440 std::array<std::atomic<int>, 8> spinningWorkers;
441 std::atomic<unsigned int> nextSpinningWorkerIdx = {0x8000000};
442
443 // TODO: Make this lot thread-safe so setWorkerThreadCount() can be called
444 // during execution of tasks.
445 std::atomic<unsigned int> nextEnqueueIndex = {0};
446 unsigned int numWorkerThreads = 0;
447 std::array<Worker*, MaxWorkerThreads> workerThreads;
448
449 std::mutex singleThreadedWorkerMutex;
450 std::unordered_map<std::thread::id, Allocator::unique_ptr<Worker>>
451 singleThreadedWorkers;
452 };
453
// Timed overload of Fiber::wait().
// Converts timeout to the scheduler's TimePoint representation and forwards
// the call to this fiber's worker, returning whatever Worker::wait() returns.
// std::chrono::time_point_cast only converts the duration type, not the
// clock, so timeout is expected to be measured against TimePoint::clock.
_Requires_lock_held_(lock)
template <typename Clock, typename Duration>
bool Scheduler::Fiber::wait(
    Lock& lock,
    const std::chrono::time_point<Clock, Duration>& timeout,
    const Predicate& pred) {
  using SchedulerClock = TimePoint::clock;
  using SchedulerDuration = TimePoint::duration;
  TimePoint deadline =
      std::chrono::time_point_cast<SchedulerDuration, SchedulerClock>(timeout);
  return worker->wait(lock, &deadline, pred);
}
465
getCurrent()466 Scheduler::Worker* Scheduler::Worker::getCurrent() {
467 return Worker::current;
468 }
469
getCurrentFiber()470 Scheduler::Fiber* Scheduler::Worker::getCurrentFiber() const {
471 return currentFiber;
472 }
473
474 // schedule() schedules the function f to be asynchronously called with the
475 // given arguments using the currently bound scheduler.
476 template <typename Function, typename... Args>
schedule(Function && f,Args &&...args)477 inline void schedule(Function&& f, Args&&... args) {
478 MARL_ASSERT_HAS_BOUND_SCHEDULER("marl::schedule");
479 auto scheduler = Scheduler::get();
480 scheduler->enqueue(
481 std::bind(std::forward<Function>(f), std::forward<Args>(args)...));
482 }
483
484 // schedule() schedules the function f to be asynchronously called using the
485 // currently bound scheduler.
486 template <typename Function>
schedule(Function && f)487 inline void schedule(Function&& f) {
488 MARL_ASSERT_HAS_BOUND_SCHEDULER("marl::schedule");
489 auto scheduler = Scheduler::get();
490 scheduler->enqueue(std::forward<Function>(f));
491 }
492
493 } // namespace marl
494
495 #endif // marl_scheduler_h
496