// Copyright 2019 The Marl Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "osfiber.h"  // Must come first. See osfiber_ucontext.h.

#include "marl/scheduler.h"

#include "marl/debug.h"
#include "marl/thread.h"
#include "marl/trace.h"

#if defined(_WIN32)
#include <intrin.h>  // __nop()
#endif

// Enable to trace scheduler events.
#define ENABLE_TRACE_EVENTS 0

// Enable to print verbose debug logging.
#define ENABLE_DEBUG_LOGGING 0

#if ENABLE_TRACE_EVENTS
#define TRACE(...) MARL_SCOPED_EVENT(__VA_ARGS__)
#else
#define TRACE(...)
#endif

#if ENABLE_DEBUG_LOGGING
#define DBG_LOG(msg, ...) \
  printf("%.3x " msg "\n", (int)threadID() & 0xfff, __VA_ARGS__)
#else
#define DBG_LOG(msg, ...)
#endif
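
// Note: when ENABLE_DEBUG_LOGGING is set to 1, DBG_LOG() prefixes each
// message with the low 12 bits of a hash of the calling thread's id, so
// interleaved output from multiple workers can be attributed to a thread.
// For example (output sketch; the actual hash prefix is platform dependent):
//   DBG_LOG("%d: WAIT(%d)", (int)id, (int)currentFiber->id);
// might print: "a3f 1: WAIT(2)".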

#define ASSERT_FIBER_STATE(FIBER, STATE)                                   \
  MARL_ASSERT(FIBER->state == STATE,                                       \
              "fiber %d was in state %s, but expected %s", (int)FIBER->id, \
              Fiber::toString(FIBER->state), Fiber::toString(STATE))

namespace {

#if ENABLE_DEBUG_LOGGING
// threadID() returns a uint64_t representing the currently executing thread.
// threadID() is only intended to be used for debugging purposes.
inline uint64_t threadID() {
  auto id = std::this_thread::get_id();
  return std::hash<std::thread::id>()(id);
}
#endif

inline void nop() {
#if defined(_WIN32)
  __nop();
#else
  __asm__ __volatile__("nop");
#endif
}

inline marl::Scheduler::Config setConfigDefaults(
    const marl::Scheduler::Config& cfgIn) {
  marl::Scheduler::Config cfg{cfgIn};
  if (cfg.workerThread.count > 0 && !cfg.workerThread.affinityPolicy) {
    cfg.workerThread.affinityPolicy = marl::Thread::Affinity::Policy::anyOf(
        marl::Thread::Affinity::all(cfg.allocator), cfg.allocator);
  }
  return cfg;
}
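
// For illustration, a config that setConfigDefaults() would leave unchanged,
// because it explicitly supplies a worker-thread affinity policy (a sketch
// using the same marl::Thread::Affinity calls as above):
//
//   marl::Scheduler::Config cfg;
//   cfg.setWorkerThreadCount(4);
//   cfg.workerThread.affinityPolicy = marl::Thread::Affinity::Policy::anyOf(
//       marl::Thread::Affinity::all(cfg.allocator), cfg.allocator);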

}  // anonymous namespace

namespace marl {

////////////////////////////////////////////////////////////////////////////////
// Scheduler
////////////////////////////////////////////////////////////////////////////////
MARL_INSTANTIATE_THREAD_LOCAL(Scheduler*, Scheduler::bound, nullptr);

Scheduler* Scheduler::get() {
  return bound;
}

void Scheduler::setBound(Scheduler* scheduler) {
  bound = scheduler;
}

void Scheduler::bind() {
  MARL_ASSERT(get() == nullptr, "Scheduler already bound");
  setBound(this);
  {
    marl::lock lock(singleThreadedWorkers.mutex);
    auto worker = cfg.allocator->make_unique<Worker>(
        this, Worker::Mode::SingleThreaded, -1);
    worker->start();
    auto tid = std::this_thread::get_id();
    singleThreadedWorkers.byTid.emplace(tid, std::move(worker));
  }
}

void Scheduler::unbind() {
  MARL_ASSERT(get() != nullptr, "No scheduler bound");
  auto worker = Worker::getCurrent();
  worker->stop();
  {
    marl::lock lock(get()->singleThreadedWorkers.mutex);
    auto tid = std::this_thread::get_id();
    auto it = get()->singleThreadedWorkers.byTid.find(tid);
    MARL_ASSERT(it != get()->singleThreadedWorkers.byTid.end(),
                "singleThreadedWorker not found");
    MARL_ASSERT(it->second.get() == worker, "worker is not bound?");
    get()->singleThreadedWorkers.byTid.erase(it);
    if (get()->singleThreadedWorkers.byTid.empty()) {
      get()->singleThreadedWorkers.unbind.notify_one();
    }
  }
  setBound(nullptr);
}
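
// Typical usage (a sketch, not part of this file): construct a Scheduler
// once, bind it on each thread that will call into marl, and unbind before
// the thread exits:
//
//   marl::Scheduler scheduler(marl::Scheduler::Config::allCores());
//   scheduler.bind();
//   // ... schedule and wait on tasks ...
//   scheduler.unbind();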

Scheduler::Scheduler(const Config& config)
    : cfg(setConfigDefaults(config)),
      workerThreads{},
      singleThreadedWorkers(config.allocator) {
  for (size_t i = 0; i < spinningWorkers.size(); i++) {
    spinningWorkers[i] = -1;
  }
  for (int i = 0; i < cfg.workerThread.count; i++) {
    workerThreads[i] =
        cfg.allocator->create<Worker>(this, Worker::Mode::MultiThreaded, i);
  }
  for (int i = 0; i < cfg.workerThread.count; i++) {
    workerThreads[i]->start();
  }
}

Scheduler::~Scheduler() {
  {
    // Wait until all the single threaded workers have been unbound.
    marl::lock lock(singleThreadedWorkers.mutex);
    lock.wait(singleThreadedWorkers.unbind,
              [this]() REQUIRES(singleThreadedWorkers.mutex) {
                return singleThreadedWorkers.byTid.empty();
              });
  }

  // Release all worker threads.
  // This will wait for all in-flight tasks to complete before returning.
  for (int i = cfg.workerThread.count - 1; i >= 0; i--) {
    workerThreads[i]->stop();
  }
  for (int i = cfg.workerThread.count - 1; i >= 0; i--) {
    cfg.allocator->destroy(workerThreads[i]);
  }
}

void Scheduler::enqueue(Task&& task) {
  if (task.is(Task::Flags::SameThread)) {
    Worker::getCurrent()->enqueue(std::move(task));
    return;
  }
  if (cfg.workerThread.count > 0) {
    while (true) {
      // Prioritize workers that have recently started spinning.
      auto i = --nextSpinningWorkerIdx % spinningWorkers.size();
      auto idx = spinningWorkers[i].exchange(-1);
      if (idx < 0) {
        // If a spinning worker couldn't be found, round-robin the workers.
        idx = nextEnqueueIndex++ % cfg.workerThread.count;
      }

      auto worker = workerThreads[idx];
      if (worker->tryLock()) {
        worker->enqueueAndUnlock(std::move(task));
        return;
      }
    }
  } else {
    if (auto worker = Worker::getCurrent()) {
      worker->enqueue(std::move(task));
    } else {
      MARL_FATAL(
          "singleThreadedWorker not found. Did you forget to call "
          "marl::Scheduler::bind()?");
    }
  }
}
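
// Tasks normally reach Scheduler::enqueue() through marl::schedule() on the
// scheduler bound to the current thread, e.g. (usage sketch):
//
//   marl::schedule([] {
//     // Runs on one of the scheduler's workers.
//   });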

const Scheduler::Config& Scheduler::config() const {
  return cfg;
}

bool Scheduler::stealWork(Worker* thief, uint64_t from, Task& out) {
  if (cfg.workerThread.count > 0) {
    auto thread = workerThreads[from % cfg.workerThread.count];
    if (thread != thief) {
      if (thread->steal(out)) {
        return true;
      }
    }
  }
  return false;
}

void Scheduler::onBeginSpinning(int workerId) {
  auto idx = nextSpinningWorkerIdx++ % spinningWorkers.size();
  spinningWorkers[idx] = workerId;
}
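
// spinningWorkers behaves as a small ring buffer of ids of workers that have
// recently run out of work: onBeginSpinning() publishes an id, and enqueue()
// consumes one (via exchange(-1)) so that new tasks are handed to workers
// that can pick them up with the lowest latency.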

////////////////////////////////////////////////////////////////////////////////
// Scheduler::Config
////////////////////////////////////////////////////////////////////////////////
Scheduler::Config Scheduler::Config::allCores() {
  return Config().setWorkerThreadCount(Thread::numLogicalCPUs());
}

////////////////////////////////////////////////////////////////////////////////
// Scheduler::Fiber
////////////////////////////////////////////////////////////////////////////////
Scheduler::Fiber::Fiber(Allocator::unique_ptr<OSFiber>&& impl, uint32_t id)
    : id(id), impl(std::move(impl)), worker(Worker::getCurrent()) {
  MARL_ASSERT(worker != nullptr, "No Scheduler::Worker bound");
}

Scheduler::Fiber* Scheduler::Fiber::current() {
  auto worker = Worker::getCurrent();
  return worker != nullptr ? worker->getCurrentFiber() : nullptr;
}

void Scheduler::Fiber::notify() {
  worker->enqueue(this);
}

void Scheduler::Fiber::wait(marl::lock& lock, const Predicate& pred) {
  MARL_ASSERT(worker == Worker::getCurrent(),
              "Scheduler::Fiber::wait() must only be called on the currently "
              "executing fiber");
  worker->wait(lock, nullptr, pred);
}
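
// Fiber::wait()/notify() are the primitives beneath marl's blocking
// synchronization types (such as marl::ConditionVariable). A direct use would
// look roughly like this (sketch; `mutex` and `signalled` are illustrative,
// and some other fiber must set `signalled` and call notify()):
//
//   marl::mutex mutex;
//   bool signalled = false;
//   {
//     marl::lock lock(mutex);
//     Scheduler::Fiber::current()->wait(lock, [&] { return signalled; });
//   }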

void Scheduler::Fiber::switchTo(Fiber* to) {
  MARL_ASSERT(worker == Worker::getCurrent(),
              "Scheduler::Fiber::switchTo() must only be called on the "
              "currently executing fiber");
  if (to != this) {
    impl->switchTo(to->impl.get());
  }
}

Allocator::unique_ptr<Scheduler::Fiber> Scheduler::Fiber::create(
    Allocator* allocator,
    uint32_t id,
    size_t stackSize,
    const std::function<void()>& func) {
  return allocator->make_unique<Fiber>(
      OSFiber::createFiber(allocator, stackSize, func), id);
}

Allocator::unique_ptr<Scheduler::Fiber>
Scheduler::Fiber::createFromCurrentThread(Allocator* allocator, uint32_t id) {
  return allocator->make_unique<Fiber>(
      OSFiber::createFiberFromCurrentThread(allocator), id);
}

const char* Scheduler::Fiber::toString(State state) {
  switch (state) {
    case State::Idle:
      return "Idle";
    case State::Yielded:
      return "Yielded";
    case State::Queued:
      return "Queued";
    case State::Running:
      return "Running";
    case State::Waiting:
      return "Waiting";
  }
  MARL_ASSERT(false, "bad fiber state");
  return "<unknown>";
}

////////////////////////////////////////////////////////////////////////////////
// Scheduler::WaitingFibers
////////////////////////////////////////////////////////////////////////////////
Scheduler::WaitingFibers::WaitingFibers(Allocator* allocator)
    : timeouts(allocator), fibers(allocator) {}

Scheduler::WaitingFibers::operator bool() const {
  return !fibers.empty();
}

Scheduler::Fiber* Scheduler::WaitingFibers::take(const TimePoint& timeout) {
  if (!*this) {
    return nullptr;
  }
  auto it = timeouts.begin();
  if (timeout < it->timepoint) {
    return nullptr;
  }
  auto fiber = it->fiber;
  timeouts.erase(it);
  auto deleted = fibers.erase(fiber) != 0;
  (void)deleted;
  MARL_ASSERT(deleted, "WaitingFibers::take() maps out of sync");
  return fiber;
}
Scheduler::TimePoint Scheduler::WaitingFibers::next() const {
  MARL_ASSERT(*this,
              "WaitingFibers::next() called when there are no waiting fibers");
  return timeouts.begin()->timepoint;
}

void Scheduler::WaitingFibers::add(const TimePoint& timeout, Fiber* fiber) {
  timeouts.emplace(Timeout{timeout, fiber});
  bool added = fibers.emplace(fiber, timeout).second;
  (void)added;
  MARL_ASSERT(added, "WaitingFibers::add() fiber already waiting");
}

void Scheduler::WaitingFibers::erase(Fiber* fiber) {
  auto it = fibers.find(fiber);
  if (it != fibers.end()) {
    auto timeout = it->second;
    auto erased = timeouts.erase(Timeout{timeout, fiber}) != 0;
    (void)erased;
    MARL_ASSERT(erased, "WaitingFibers::erase() maps out of sync");
    fibers.erase(it);
  }
}

bool Scheduler::WaitingFibers::contains(Fiber* fiber) const {
  return fibers.count(fiber) != 0;
}

bool Scheduler::WaitingFibers::Timeout::operator<(const Timeout& o) const {
  if (timepoint != o.timepoint) {
    return timepoint < o.timepoint;
  }
  return fiber < o.fiber;
}
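
// Ordering first by timepoint and then by fiber pointer gives Timeout a
// strict weak ordering, so `timeouts` stays sorted with the earliest deadline
// first while still permitting distinct fibers to share the same timepoint.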

////////////////////////////////////////////////////////////////////////////////
// Scheduler::Worker
////////////////////////////////////////////////////////////////////////////////
MARL_INSTANTIATE_THREAD_LOCAL(Scheduler::Worker*,
                              Scheduler::Worker::current,
                              nullptr);

Scheduler::Worker::Worker(Scheduler* scheduler, Mode mode, uint32_t id)
    : id(id),
      mode(mode),
      scheduler(scheduler),
      work(scheduler->cfg.allocator),
      idleFibers(scheduler->cfg.allocator) {}

void Scheduler::Worker::start() {
  switch (mode) {
    case Mode::MultiThreaded: {
      auto allocator = scheduler->cfg.allocator;
      auto& affinityPolicy = scheduler->cfg.workerThread.affinityPolicy;
      auto affinity = affinityPolicy->get(id, allocator);
      thread = Thread(std::move(affinity), [=] {
        Thread::setName("Thread<%.2d>", int(id));

        if (auto const& initFunc = scheduler->cfg.workerThread.initializer) {
          initFunc(id);
        }

        Scheduler::setBound(scheduler);
        Worker::current = this;
        mainFiber = Fiber::createFromCurrentThread(scheduler->cfg.allocator, 0);
        currentFiber = mainFiber.get();
        {
          marl::lock lock(work.mutex);
          run();
        }
        mainFiber.reset();
        Worker::current = nullptr;
      });
      break;
    }
    case Mode::SingleThreaded: {
      Worker::current = this;
      mainFiber = Fiber::createFromCurrentThread(scheduler->cfg.allocator, 0);
      currentFiber = mainFiber.get();
      break;
    }
    default:
      MARL_ASSERT(false, "Unknown mode: %d", int(mode));
  }
}

void Scheduler::Worker::stop() {
  switch (mode) {
    case Mode::MultiThreaded: {
      enqueue(Task([this] { shutdown = true; }, Task::Flags::SameThread));
      thread.join();
      break;
    }
    case Mode::SingleThreaded: {
      marl::lock lock(work.mutex);
      shutdown = true;
      runUntilShutdown();
      Worker::current = nullptr;
      break;
    }
    default:
      MARL_ASSERT(false, "Unknown mode: %d", int(mode));
  }
}

bool Scheduler::Worker::wait(const TimePoint* timeout) {
  DBG_LOG("%d: WAIT(%d)", (int)id, (int)currentFiber->id);
  {
    marl::lock lock(work.mutex);
    suspend(timeout);
  }
  return timeout == nullptr || std::chrono::system_clock::now() < *timeout;
}

bool Scheduler::Worker::wait(lock& waitLock,
                             const TimePoint* timeout,
                             const Predicate& pred) {
  DBG_LOG("%d: WAIT(%d)", (int)id, (int)currentFiber->id);
  while (!pred()) {
    // Lock the work mutex to call suspend().
    work.mutex.lock();

    // Unlock the wait mutex with the work mutex lock held.
    // Order is important here: we must ensure that the fiber is not enqueued
    // (via Fiber::notify()) between the waitLock.unlock() and the fiber
    // switch, otherwise the Fiber::notify() call may be ignored and the fiber
    // never woken.
    waitLock.unlock_no_tsa();

    // Suspend the fiber.
    suspend(timeout);

    // Fiber resumed. We no longer need the work mutex locked.
    work.mutex.unlock();

    // Re-lock to either return due to timeout, or call pred().
    waitLock.lock_no_tsa();

    // Check timeout.
    if (timeout != nullptr && std::chrono::system_clock::now() >= *timeout) {
      return false;
    }

    // Spurious wake up. Spin again.
  }
  return true;
}

void Scheduler::Worker::suspend(
    const std::chrono::system_clock::time_point* timeout) {
  // The current fiber is yielding because it is blocked.
  if (timeout != nullptr) {
    changeFiberState(currentFiber, Fiber::State::Running,
                     Fiber::State::Waiting);
    work.waiting.add(*timeout, currentFiber);
  } else {
    changeFiberState(currentFiber, Fiber::State::Running,
                     Fiber::State::Yielded);
  }

  // First wait until there's something else this worker can do.
  waitForWork();

  work.numBlockedFibers++;

  if (!work.fibers.empty()) {
    // Another fiber has become unblocked; resume that.
    work.num--;
    auto to = containers::take(work.fibers);
    ASSERT_FIBER_STATE(to, Fiber::State::Queued);
    switchToFiber(to);
  } else if (!idleFibers.empty()) {
    // There's an old fiber we can reuse; resume that.
    auto to = containers::take(idleFibers);
    ASSERT_FIBER_STATE(to, Fiber::State::Idle);
    switchToFiber(to);
  } else {
    // There are tasks to process and no existing fibers to resume.
    // Spawn a new fiber.
    switchToFiber(createWorkerFiber());
  }

  work.numBlockedFibers--;

  setFiberState(currentFiber, Fiber::State::Running);
}

bool Scheduler::Worker::tryLock() {
  return work.mutex.try_lock();
}

void Scheduler::Worker::enqueue(Fiber* fiber) {
  bool notify = false;
  {
    marl::lock lock(work.mutex);
    DBG_LOG("%d: ENQUEUE(%d %s)", (int)id, (int)fiber->id,
            Fiber::toString(fiber->state));
    switch (fiber->state) {
      case Fiber::State::Running:
      case Fiber::State::Queued:
        return;  // Nothing to do here - the fiber is already queued or running.
      case Fiber::State::Waiting:
        work.waiting.erase(fiber);
        break;
      case Fiber::State::Idle:
      case Fiber::State::Yielded:
        break;
    }
    notify = work.notifyAdded;
    work.fibers.push_back(fiber);
    MARL_ASSERT(!work.waiting.contains(fiber),
                "fiber is unexpectedly in the waiting list");
    setFiberState(fiber, Fiber::State::Queued);
    work.num++;
  }

  if (notify) {
    work.added.notify_one();
  }
}

void Scheduler::Worker::enqueue(Task&& task) {
  work.mutex.lock();
  enqueueAndUnlock(std::move(task));
}

void Scheduler::Worker::enqueueAndUnlock(Task&& task) {
  auto notify = work.notifyAdded;
  work.tasks.push_back(std::move(task));
  work.num++;
  work.mutex.unlock();
  if (notify) {
    work.added.notify_one();
  }
}

bool Scheduler::Worker::steal(Task& out) {
  if (work.num.load() == 0) {
    return false;
  }
  if (!work.mutex.try_lock()) {
    return false;
  }
  if (work.tasks.empty() || work.tasks.front().is(Task::Flags::SameThread)) {
    work.mutex.unlock();
    return false;
  }
  work.num--;
  out = containers::take(work.tasks);
  work.mutex.unlock();
  return true;
}
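
// steal() is called by other workers via Scheduler::stealWork() and is
// deliberately non-blocking: if the mutex is contended, or only
// Task::Flags::SameThread work remains, it fails fast so the thief can try
// another victim instead of stalling.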

void Scheduler::Worker::run() {
  if (mode == Mode::MultiThreaded) {
    MARL_NAME_THREAD("Thread<%.2d> Fiber<%.2d>", int(id), Fiber::current()->id);
    // This is the entry point for a multi-threaded worker.
    // Start with a regular condition-variable wait for work. This avoids
    // starting the thread with a spinForWork().
    work.wait([this]() REQUIRES(work.mutex) {
      return work.num > 0 || work.waiting || shutdown;
    });
  }
  ASSERT_FIBER_STATE(currentFiber, Fiber::State::Running);
  runUntilShutdown();
  switchToFiber(mainFiber.get());
}

void Scheduler::Worker::runUntilShutdown() {
  while (!shutdown || work.num > 0 || work.numBlockedFibers > 0U) {
    waitForWork();
    runUntilIdle();
  }
}

void Scheduler::Worker::waitForWork() {
  MARL_ASSERT(work.num == work.fibers.size() + work.tasks.size(),
              "work.num out of sync");
  if (work.num > 0) {
    return;
  }

  if (mode == Mode::MultiThreaded) {
    scheduler->onBeginSpinning(id);
    work.mutex.unlock();
    spinForWork();
    work.mutex.lock();
  }

  work.wait([this]() REQUIRES(work.mutex) {
    return work.num > 0 || (shutdown && work.numBlockedFibers == 0U);
  });
  if (work.waiting) {
    enqueueFiberTimeouts();
  }
}

void Scheduler::Worker::enqueueFiberTimeouts() {
  auto now = std::chrono::system_clock::now();
  while (auto fiber = work.waiting.take(now)) {
    changeFiberState(fiber, Fiber::State::Waiting, Fiber::State::Queued);
    DBG_LOG("%d: TIMEOUT(%d)", (int)id, (int)fiber->id);
    work.fibers.push_back(fiber);
    work.num++;
  }
}

void Scheduler::Worker::changeFiberState(Fiber* fiber,
                                         Fiber::State from,
                                         Fiber::State to) const {
  (void)from;  // Unused parameter when ENABLE_DEBUG_LOGGING is disabled.
  DBG_LOG("%d: CHANGE_FIBER_STATE(%d %s -> %s)", (int)id, (int)fiber->id,
          Fiber::toString(from), Fiber::toString(to));
  ASSERT_FIBER_STATE(fiber, from);
  fiber->state = to;
}

void Scheduler::Worker::setFiberState(Fiber* fiber, Fiber::State to) const {
  DBG_LOG("%d: SET_FIBER_STATE(%d %s -> %s)", (int)id, (int)fiber->id,
          Fiber::toString(fiber->state), Fiber::toString(to));
  fiber->state = to;
}

void Scheduler::Worker::spinForWork() {
  TRACE("SPIN");
  Task stolen;

  constexpr auto duration = std::chrono::milliseconds(1);
  auto start = std::chrono::high_resolution_clock::now();
  while (std::chrono::high_resolution_clock::now() - start < duration) {
    for (int i = 0; i < 256; i++)  // Empirically picked magic number!
    {
      // clang-format off
      nop(); nop(); nop(); nop(); nop(); nop(); nop(); nop();
      nop(); nop(); nop(); nop(); nop(); nop(); nop(); nop();
      nop(); nop(); nop(); nop(); nop(); nop(); nop(); nop();
      nop(); nop(); nop(); nop(); nop(); nop(); nop(); nop();
      // clang-format on
      if (work.num > 0) {
        return;
      }
    }

    if (scheduler->stealWork(this, rng(), stolen)) {
      marl::lock lock(work.mutex);
      work.tasks.emplace_back(std::move(stolen));
      work.num++;
      return;
    }

    std::this_thread::yield();
  }
}

void Scheduler::Worker::runUntilIdle() {
  ASSERT_FIBER_STATE(currentFiber, Fiber::State::Running);
  MARL_ASSERT(work.num == work.fibers.size() + work.tasks.size(),
              "work.num out of sync");
  while (!work.fibers.empty() || !work.tasks.empty()) {
    // Note: we cannot take and store on the stack more than a single fiber
    // or task at a time, as the Fiber may yield and these items may get
    // held on a suspended fiber's stack.

    while (!work.fibers.empty()) {
      work.num--;
      auto fiber = containers::take(work.fibers);
      // Sanity checks.
      MARL_ASSERT(idleFibers.count(fiber) == 0, "dequeued fiber is idle");
      MARL_ASSERT(fiber != currentFiber, "dequeued fiber is currently running");
      ASSERT_FIBER_STATE(fiber, Fiber::State::Queued);

      changeFiberState(currentFiber, Fiber::State::Running, Fiber::State::Idle);
      auto added = idleFibers.emplace(currentFiber).second;
      (void)added;
      MARL_ASSERT(added, "fiber already idle");

      switchToFiber(fiber);
      changeFiberState(currentFiber, Fiber::State::Idle, Fiber::State::Running);
    }

    if (!work.tasks.empty()) {
      work.num--;
      auto task = containers::take(work.tasks);
      work.mutex.unlock();

      // Run the task.
      task();

      // std::function<> can carry arguments with complex destructors.
      // Ensure these are destructed outside of the lock.
      task = Task();

      work.mutex.lock();
    }
  }
}

Scheduler::Fiber* Scheduler::Worker::createWorkerFiber() {
  auto fiberId = static_cast<uint32_t>(workerFibers.size() + 1);
  DBG_LOG("%d: CREATE(%d)", (int)id, (int)fiberId);
  auto fiber = Fiber::create(scheduler->cfg.allocator, fiberId,
                             scheduler->cfg.fiberStackSize,
                             [&]() REQUIRES(work.mutex) { run(); });
  auto ptr = fiber.get();
  workerFibers.emplace_back(std::move(fiber));
  return ptr;
}

void Scheduler::Worker::switchToFiber(Fiber* to) {
  DBG_LOG("%d: SWITCH(%d -> %d)", (int)id, (int)currentFiber->id, (int)to->id);
  MARL_ASSERT(to == mainFiber.get() || idleFibers.count(to) == 0,
              "switching to idle fiber");
  auto from = currentFiber;
  currentFiber = to;
  from->switchTo(to);
}

////////////////////////////////////////////////////////////////////////////////
// Scheduler::Worker::Work
////////////////////////////////////////////////////////////////////////////////
Scheduler::Worker::Work::Work(Allocator* allocator)
    : tasks(allocator), fibers(allocator), waiting(allocator) {}

template <typename F>
void Scheduler::Worker::Work::wait(F&& f) {
  notifyAdded = true;
  if (waiting) {
    mutex.wait_until_locked(added, waiting.next(), f);
  } else {
    mutex.wait_locked(added, f);
  }
  notifyAdded = false;
}
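
// notifyAdded is only set while a worker is actually blocked in wait(), which
// lets enqueue() and enqueueAndUnlock() skip the condition-variable notify
// when no one is listening.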

////////////////////////////////////////////////////////////////////////////////
// Scheduler::SingleThreadedWorkers
////////////////////////////////////////////////////////////////////////////////
Scheduler::SingleThreadedWorkers::SingleThreadedWorkers(Allocator* allocator)
    : byTid(allocator) {}

}  // namespace marl