// Copyright 2019 The Marl Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "osfiber.h"  // Must come first. See osfiber_ucontext.h.

#include "marl/scheduler.h"

#include "marl/debug.h"
#include "marl/sanitizers.h"
#include "marl/thread.h"
#include "marl/trace.h"

#if defined(_WIN32)
#include <intrin.h>  // __nop()
#endif

// Enable to trace scheduler events.
#define ENABLE_TRACE_EVENTS 0

// Enable to print verbose debug logging.
#define ENABLE_DEBUG_LOGGING 0

#if ENABLE_TRACE_EVENTS
#define TRACE(...) MARL_SCOPED_EVENT(__VA_ARGS__)
#else
#define TRACE(...)
#endif

#if ENABLE_DEBUG_LOGGING
#define DBG_LOG(msg, ...) \
  printf("%.3x " msg "\n", (int)threadID() & 0xfff, __VA_ARGS__)
#else
#define DBG_LOG(msg, ...)
#endif

#define ASSERT_FIBER_STATE(FIBER, STATE) \
  MARL_ASSERT(FIBER->state == STATE, \
              "fiber %d was in state %s, but expected %s", (int)FIBER->id, \
              Fiber::toString(FIBER->state), Fiber::toString(STATE))

namespace {

#if ENABLE_DEBUG_LOGGING
// threadID() returns a uint64_t representing the currently executing thread.
// threadID() is only intended to be used for debugging purposes.
inline uint64_t threadID() {
  auto id = std::this_thread::get_id();
  return std::hash<std::thread::id>()(id);
}
#endif

inline void nop() {
#if defined(_WIN32)
  __nop();
#else
  __asm__ __volatile__("nop");
#endif
}

inline marl::Scheduler::Config setConfigDefaults(
    const marl::Scheduler::Config& cfgIn) {
  marl::Scheduler::Config cfg{cfgIn};
  if (cfg.workerThread.count > 0 && !cfg.workerThread.affinityPolicy) {
    cfg.workerThread.affinityPolicy = marl::Thread::Affinity::Policy::anyOf(
        marl::Thread::Affinity::all(cfg.allocator), cfg.allocator);
  }
  return cfg;
}

}  // anonymous namespace

namespace marl {

////////////////////////////////////////////////////////////////////////////////
// Scheduler
////////////////////////////////////////////////////////////////////////////////
thread_local Scheduler* Scheduler::bound = nullptr;

Scheduler* Scheduler::get() {
  return bound;
}

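// Example usage (a sketch; see marl's README for the full pattern):
//   marl::Scheduler scheduler(marl::Scheduler::Config::allCores());
//   scheduler.bind();
//   defer(scheduler.unbind());  // from marl/defer.h
//   // marl::schedule() may now be called on this thread.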
void Scheduler::bind() {
#if !MEMORY_SANITIZER_ENABLED
  // thread_local variables in shared libraries are initialized at load-time,
  // but this is not observed by MemorySanitizer if the loader itself was not
  // instrumented, leading to false-positive uninitialized variable errors.
  // See https://github.com/google/marl/issues/184
  MARL_ASSERT(bound == nullptr, "Scheduler already bound");
#endif
  bound = this;
  {
    marl::lock lock(singleThreadedWorkers.mutex);
    auto worker = cfg.allocator->make_unique<Worker>(
        this, Worker::Mode::SingleThreaded, -1);
    worker->start();
    auto tid = std::this_thread::get_id();
    singleThreadedWorkers.byTid.emplace(tid, std::move(worker));
  }
}

void Scheduler::unbind() {
  MARL_ASSERT(bound != nullptr, "No scheduler bound");
  auto worker = Worker::getCurrent();
  worker->stop();
  {
    marl::lock lock(bound->singleThreadedWorkers.mutex);
    auto tid = std::this_thread::get_id();
    auto it = bound->singleThreadedWorkers.byTid.find(tid);
    MARL_ASSERT(it != bound->singleThreadedWorkers.byTid.end(),
                "singleThreadedWorker not found");
    MARL_ASSERT(it->second.get() == worker, "worker is not bound?");
    bound->singleThreadedWorkers.byTid.erase(it);
    if (bound->singleThreadedWorkers.byTid.empty()) {
      bound->singleThreadedWorkers.unbind.notify_one();
    }
  }
  bound = nullptr;
}

Scheduler::Scheduler(const Config& config)
    : cfg(setConfigDefaults(config)),
      workerThreads{},
      singleThreadedWorkers(config.allocator) {
  for (size_t i = 0; i < spinningWorkers.size(); i++) {
    spinningWorkers[i] = -1;
  }
  for (int i = 0; i < cfg.workerThread.count; i++) {
    workerThreads[i] =
        cfg.allocator->create<Worker>(this, Worker::Mode::MultiThreaded, i);
  }
  for (int i = 0; i < cfg.workerThread.count; i++) {
    workerThreads[i]->start();
  }
}

Scheduler::~Scheduler() {
  {
    // Wait until all the single threaded workers have been unbound.
    marl::lock lock(singleThreadedWorkers.mutex);
    lock.wait(singleThreadedWorkers.unbind,
              [this]() REQUIRES(singleThreadedWorkers.mutex) {
                return singleThreadedWorkers.byTid.empty();
              });
  }

  // Release all worker threads.
  // This will wait for all in-flight tasks to complete before returning.
  for (int i = cfg.workerThread.count - 1; i >= 0; i--) {
    workerThreads[i]->stop();
  }
  for (int i = cfg.workerThread.count - 1; i >= 0; i--) {
    cfg.allocator->destroy(workerThreads[i]);
  }
}

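// enqueue() hands the task to a worker. Workers that recently started
// spinning for work are preferred (the spinningWorkers ring acts as a hint);
// otherwise a worker is picked round-robin. tryLock() ensures a busy worker
// never blocks the enqueue; we simply move on to the next candidate.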
void Scheduler::enqueue(Task&& task) {
  if (task.is(Task::Flags::SameThread)) {
    Worker::getCurrent()->enqueue(std::move(task));
    return;
  }
  if (cfg.workerThread.count > 0) {
    while (true) {
      // Prioritize workers that have recently started spinning.
      auto i = --nextSpinningWorkerIdx % spinningWorkers.size();
      auto idx = spinningWorkers[i].exchange(-1);
      if (idx < 0) {
        // If a spinning worker couldn't be found, round-robin the workers.
        idx = nextEnqueueIndex++ % cfg.workerThread.count;
      }

      auto worker = workerThreads[idx];
      if (worker->tryLock()) {
        worker->enqueueAndUnlock(std::move(task));
        return;
      }
    }
  } else {
    if (auto worker = Worker::getCurrent()) {
      worker->enqueue(std::move(task));
    } else {
      MARL_FATAL(
          "singleThreadedWorker not found. Did you forget to call "
          "marl::Scheduler::bind()?");
    }
  }
}

const Scheduler::Config& Scheduler::config() const {
  return cfg;
}

bool Scheduler::stealWork(Worker* thief, uint64_t from, Task& out) {
  if (cfg.workerThread.count > 0) {
    auto thread = workerThreads[from % cfg.workerThread.count];
    if (thread != thief) {
      if (thread->steal(out)) {
        return true;
      }
    }
  }
  return false;
}

void Scheduler::onBeginSpinning(int workerId) {
  auto idx = nextSpinningWorkerIdx++ % spinningWorkers.size();
  spinningWorkers[idx] = workerId;
}

////////////////////////////////////////////////////////////////////////////////
// Scheduler::Config
////////////////////////////////////////////////////////////////////////////////
Scheduler::Config Scheduler::Config::allCores() {
  return Config().setWorkerThreadCount(Thread::numLogicalCPUs());
}
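
// Example (a sketch): a config that uses four worker threads instead of one
// per logical CPU:
//   marl::Scheduler::Config cfg;
//   cfg.setWorkerThreadCount(4);
//   marl::Scheduler scheduler(cfg);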

////////////////////////////////////////////////////////////////////////////////
// Scheduler::Fiber
////////////////////////////////////////////////////////////////////////////////
Scheduler::Fiber::Fiber(Allocator::unique_ptr<OSFiber>&& impl, uint32_t id)
    : id(id), impl(std::move(impl)), worker(Worker::getCurrent()) {
  MARL_ASSERT(worker != nullptr, "No Scheduler::Worker bound");
}

Scheduler::Fiber* Scheduler::Fiber::current() {
  auto worker = Worker::getCurrent();
  return worker != nullptr ? worker->getCurrentFiber() : nullptr;
}

void Scheduler::Fiber::notify() {
  worker->enqueue(this);
}

void Scheduler::Fiber::wait(marl::lock& lock, const Predicate& pred) {
  MARL_ASSERT(worker == Worker::getCurrent(),
              "Scheduler::Fiber::wait() must only be called on the currently "
              "executing fiber");
  worker->wait(lock, nullptr, pred);
}

void Scheduler::Fiber::switchTo(Fiber* to) {
  MARL_ASSERT(worker == Worker::getCurrent(),
              "Scheduler::Fiber::switchTo() must only be called on the "
              "currently executing fiber");
  if (to != this) {
    impl->switchTo(to->impl.get());
  }
}

Allocator::unique_ptr<Scheduler::Fiber> Scheduler::Fiber::create(
    Allocator* allocator,
    uint32_t id,
    size_t stackSize,
    const std::function<void()>& func) {
  return allocator->make_unique<Fiber>(
      OSFiber::createFiber(allocator, stackSize, func), id);
}

Allocator::unique_ptr<Scheduler::Fiber>
Scheduler::Fiber::createFromCurrentThread(Allocator* allocator, uint32_t id) {
  return allocator->make_unique<Fiber>(
      OSFiber::createFiberFromCurrentThread(allocator), id);
}

const char* Scheduler::Fiber::toString(State state) {
  switch (state) {
    case State::Idle:
      return "Idle";
    case State::Yielded:
      return "Yielded";
    case State::Queued:
      return "Queued";
    case State::Running:
      return "Running";
    case State::Waiting:
      return "Waiting";
  }
  MARL_ASSERT(false, "bad fiber state");
  return "<unknown>";
}

////////////////////////////////////////////////////////////////////////////////
// Scheduler::WaitingFibers
////////////////////////////////////////////////////////////////////////////////
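// WaitingFibers keeps two containers in sync: 'timeouts', a set of
// (timepoint, fiber) pairs ordered by earliest deadline, and 'fibers', a map
// from fiber to its timepoint for reverse lookups. Every mutation must update
// both; the asserts below verify that the two never drift apart.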
Scheduler::WaitingFibers::WaitingFibers(Allocator* allocator)
    : timeouts(allocator), fibers(allocator) {}

Scheduler::WaitingFibers::operator bool() const {
  return !fibers.empty();
}

Scheduler::Fiber* Scheduler::WaitingFibers::take(const TimePoint& timeout) {
  if (!*this) {
    return nullptr;
  }
  auto it = timeouts.begin();
  if (timeout < it->timepoint) {
    return nullptr;
  }
  auto fiber = it->fiber;
  timeouts.erase(it);
  auto deleted = fibers.erase(fiber) != 0;
  (void)deleted;
  MARL_ASSERT(deleted, "WaitingFibers::take() maps out of sync");
  return fiber;
}

Scheduler::TimePoint Scheduler::WaitingFibers::next() const {
  MARL_ASSERT(*this,
              "WaitingFibers::next() called when there are no waiting fibers");
  return timeouts.begin()->timepoint;
}

void Scheduler::WaitingFibers::add(const TimePoint& timeout, Fiber* fiber) {
  timeouts.emplace(Timeout{timeout, fiber});
  bool added = fibers.emplace(fiber, timeout).second;
  (void)added;
  MARL_ASSERT(added, "WaitingFibers::add() fiber already waiting");
}

void Scheduler::WaitingFibers::erase(Fiber* fiber) {
  auto it = fibers.find(fiber);
  if (it != fibers.end()) {
    auto timeout = it->second;
    auto erased = timeouts.erase(Timeout{timeout, fiber}) != 0;
    (void)erased;
    MARL_ASSERT(erased, "WaitingFibers::erase() maps out of sync");
    fibers.erase(it);
  }
}

bool Scheduler::WaitingFibers::contains(Fiber* fiber) const {
  return fibers.count(fiber) != 0;
}

bool Scheduler::WaitingFibers::Timeout::operator<(const Timeout& o) const {
  if (timepoint != o.timepoint) {
    return timepoint < o.timepoint;
  }
  return fiber < o.fiber;
}

////////////////////////////////////////////////////////////////////////////////
// Scheduler::Worker
////////////////////////////////////////////////////////////////////////////////
thread_local Scheduler::Worker* Scheduler::Worker::current = nullptr;

Scheduler::Worker::Worker(Scheduler* scheduler, Mode mode, uint32_t id)
    : id(id),
      mode(mode),
      scheduler(scheduler),
      work(scheduler->cfg.allocator),
      idleFibers(scheduler->cfg.allocator) {}

void Scheduler::Worker::start() {
  switch (mode) {
    case Mode::MultiThreaded: {
      auto allocator = scheduler->cfg.allocator;
      auto& affinityPolicy = scheduler->cfg.workerThread.affinityPolicy;
      auto affinity = affinityPolicy->get(id, allocator);
      thread = Thread(std::move(affinity), [=] {
        Thread::setName("Thread<%.2d>", int(id));

        if (auto const& initFunc = scheduler->cfg.workerThread.initializer) {
          initFunc(id);
        }

        Scheduler::bound = scheduler;
        Worker::current = this;
        mainFiber = Fiber::createFromCurrentThread(scheduler->cfg.allocator, 0);
        currentFiber = mainFiber.get();
        {
          marl::lock lock(work.mutex);
          run();
        }
        mainFiber.reset();
        Worker::current = nullptr;
      });
      break;
    }
    case Mode::SingleThreaded: {
      Worker::current = this;
      mainFiber = Fiber::createFromCurrentThread(scheduler->cfg.allocator, 0);
      currentFiber = mainFiber.get();
      break;
    }
    default:
      MARL_ASSERT(false, "Unknown mode: %d", int(mode));
  }
}

void Scheduler::Worker::stop() {
  switch (mode) {
    case Mode::MultiThreaded: {
      enqueue(Task([this] { shutdown = true; }, Task::Flags::SameThread));
      thread.join();
      break;
    }
    case Mode::SingleThreaded: {
      marl::lock lock(work.mutex);
      shutdown = true;
      runUntilShutdown();
      Worker::current = nullptr;
      break;
    }
    default:
      MARL_ASSERT(false, "Unknown mode: %d", int(mode));
  }
}

bool Scheduler::Worker::wait(const TimePoint* timeout) {
  DBG_LOG("%d: WAIT(%d)", (int)id, (int)currentFiber->id);
  {
    marl::lock lock(work.mutex);
    suspend(timeout);
  }
  return timeout == nullptr || std::chrono::system_clock::now() < *timeout;
}

bool Scheduler::Worker::wait(lock& waitLock,
                             const TimePoint* timeout,
                             const Predicate& pred) {
  DBG_LOG("%d: WAIT(%d)", (int)id, (int)currentFiber->id);
  while (!pred()) {
    // Lock the work mutex to call suspend().
    work.mutex.lock();

    // Unlock the wait mutex with the work mutex lock held.
    // Order is important here: we must ensure that the fiber is not enqueued
    // (via Fiber::notify()) between the waitLock.unlock() and the fiber
    // switch, otherwise the Fiber::notify() call would be missed and the
    // fiber never woken.
    waitLock.unlock_no_tsa();

    // Suspend the fiber.
    suspend(timeout);

    // Fiber resumed. We don't need the work mutex locked any more.
    work.mutex.unlock();

    // Re-lock to either return due to timeout, or call pred().
    waitLock.lock_no_tsa();

    // Check timeout.
    if (timeout != nullptr && std::chrono::system_clock::now() >= *timeout) {
      return false;
    }

    // Spurious wake up. Spin again.
  }
  return true;
}

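// suspend() parks the current fiber until there is other work to do,
// switching to a queued fiber, a reusable idle fiber, or a freshly created
// one. It must be called with work.mutex held; by convention the lock is
// carried across the fiber switch and reacquired when this fiber resumes.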
void Scheduler::Worker::suspend(
    const std::chrono::system_clock::time_point* timeout) {
  // Current fiber is yielding as it is blocked.
  if (timeout != nullptr) {
    changeFiberState(currentFiber, Fiber::State::Running,
                     Fiber::State::Waiting);
    work.waiting.add(*timeout, currentFiber);
  } else {
    changeFiberState(currentFiber, Fiber::State::Running,
                     Fiber::State::Yielded);
  }

  // First wait until there's something else this worker can do.
  waitForWork();

  work.numBlockedFibers++;

  if (!work.fibers.empty()) {
    // There's another fiber that has become unblocked, resume that.
    work.num--;
    auto to = containers::take(work.fibers);
    ASSERT_FIBER_STATE(to, Fiber::State::Queued);
    switchToFiber(to);
  } else if (!idleFibers.empty()) {
    // There's an old fiber we can reuse, resume that.
    auto to = containers::take(idleFibers);
    ASSERT_FIBER_STATE(to, Fiber::State::Idle);
    switchToFiber(to);
  } else {
    // There are tasks to process and no existing fibers to resume.
    // Spawn a new fiber.
    switchToFiber(createWorkerFiber());
  }

  work.numBlockedFibers--;

  setFiberState(currentFiber, Fiber::State::Running);
}

bool Scheduler::Worker::tryLock() {
  return work.mutex.try_lock();
}

void Scheduler::Worker::enqueue(Fiber* fiber) {
  bool notify = false;
  {
    marl::lock lock(work.mutex);
    DBG_LOG("%d: ENQUEUE(%d %s)", (int)id, (int)fiber->id,
            Fiber::toString(fiber->state));
    switch (fiber->state) {
      case Fiber::State::Running:
      case Fiber::State::Queued:
        return;  // Nothing to do here - task is already queued or running.
      case Fiber::State::Waiting:
        work.waiting.erase(fiber);
        break;
      case Fiber::State::Idle:
      case Fiber::State::Yielded:
        break;
    }
    notify = work.notifyAdded;
    work.fibers.push_back(fiber);
    MARL_ASSERT(!work.waiting.contains(fiber),
                "fiber is unexpectedly in the waiting list");
    setFiberState(fiber, Fiber::State::Queued);
    work.num++;
  }

  if (notify) {
    work.added.notify_one();
  }
}

void Scheduler::Worker::enqueue(Task&& task) {
  work.mutex.lock();
  enqueueAndUnlock(std::move(task));
}

void Scheduler::Worker::enqueueAndUnlock(Task&& task) {
  auto notify = work.notifyAdded;
  work.tasks.push_back(std::move(task));
  work.num++;
  work.mutex.unlock();
  if (notify) {
    work.added.notify_one();
  }
}

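// steal() attempts to take a task from this worker on behalf of another.
// The unlocked work.num check and the try_lock() keep the common "nothing to
// steal" path cheap: a thief never blocks a busy victim, it just fails fast.
// Tasks flagged SameThread are pinned to this worker and are never stolen.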
bool Scheduler::Worker::steal(Task& out) {
  if (work.num.load() == 0) {
    return false;
  }
  if (!work.mutex.try_lock()) {
    return false;
  }
  if (work.tasks.empty() || work.tasks.front().is(Task::Flags::SameThread)) {
    work.mutex.unlock();
    return false;
  }
  work.num--;
  out = containers::take(work.tasks);
  work.mutex.unlock();
  return true;
}

void Scheduler::Worker::run() {
  if (mode == Mode::MultiThreaded) {
    MARL_NAME_THREAD("Thread<%.2d> Fiber<%.2d>", int(id), Fiber::current()->id);
    // This is the entry point for a multi-threaded worker.
    // Start with a regular condition-variable wait for work. This avoids
    // starting the thread with a spinForWork().
    work.wait([this]() REQUIRES(work.mutex) {
      return work.num > 0 || work.waiting || shutdown;
    });
  }
  ASSERT_FIBER_STATE(currentFiber, Fiber::State::Running);
  runUntilShutdown();
  switchToFiber(mainFiber.get());
}

void Scheduler::Worker::runUntilShutdown() {
  while (!shutdown || work.num > 0 || work.numBlockedFibers > 0U) {
    waitForWork();
    runUntilIdle();
  }
}

void Scheduler::Worker::waitForWork() {
  MARL_ASSERT(work.num == work.fibers.size() + work.tasks.size(),
              "work.num out of sync");
  if (work.num > 0) {
    return;
  }

  if (mode == Mode::MultiThreaded) {
    scheduler->onBeginSpinning(id);
    work.mutex.unlock();
    spinForWork();
    work.mutex.lock();
  }

  work.wait([this]() REQUIRES(work.mutex) {
    return work.num > 0 || (shutdown && work.numBlockedFibers == 0U);
  });
  if (work.waiting) {
    enqueueFiberTimeouts();
  }
}

void Scheduler::Worker::enqueueFiberTimeouts() {
  auto now = std::chrono::system_clock::now();
  while (auto fiber = work.waiting.take(now)) {
    changeFiberState(fiber, Fiber::State::Waiting, Fiber::State::Queued);
    DBG_LOG("%d: TIMEOUT(%d)", (int)id, (int)fiber->id);
    work.fibers.push_back(fiber);
    work.num++;
  }
}

void Scheduler::Worker::changeFiberState(Fiber* fiber,
                                         Fiber::State from,
                                         Fiber::State to) const {
  (void)from;  // Unused parameter when ENABLE_DEBUG_LOGGING is disabled.
  DBG_LOG("%d: CHANGE_FIBER_STATE(%d %s -> %s)", (int)id, (int)fiber->id,
          Fiber::toString(from), Fiber::toString(to));
  ASSERT_FIBER_STATE(fiber, from);
  fiber->state = to;
}

void Scheduler::Worker::setFiberState(Fiber* fiber, Fiber::State to) const {
  DBG_LOG("%d: SET_FIBER_STATE(%d %s -> %s)", (int)id, (int)fiber->id,
          Fiber::toString(fiber->state), Fiber::toString(to));
  fiber->state = to;
}

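// spinForWork() burns a short, bounded amount of CPU (about a millisecond)
// in the hope that new work arrives or can be stolen from another worker,
// before falling back to the more expensive condition-variable wait in
// waitForWork().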
void Scheduler::Worker::spinForWork() {
  TRACE("SPIN");
  Task stolen;

  constexpr auto duration = std::chrono::milliseconds(1);
  auto start = std::chrono::high_resolution_clock::now();
  while (std::chrono::high_resolution_clock::now() - start < duration) {
    for (int i = 0; i < 256; i++)  // Empirically picked magic number!
    {
      // clang-format off
      nop(); nop(); nop(); nop(); nop(); nop(); nop(); nop();
      nop(); nop(); nop(); nop(); nop(); nop(); nop(); nop();
      nop(); nop(); nop(); nop(); nop(); nop(); nop(); nop();
      nop(); nop(); nop(); nop(); nop(); nop(); nop(); nop();
      // clang-format on
      if (work.num > 0) {
        return;
      }
    }

    if (scheduler->stealWork(this, rng(), stolen)) {
      marl::lock lock(work.mutex);
      work.tasks.emplace_back(std::move(stolen));
      work.num++;
      return;
    }

    std::this_thread::yield();
  }
}

void Scheduler::Worker::runUntilIdle() {
  ASSERT_FIBER_STATE(currentFiber, Fiber::State::Running);
  MARL_ASSERT(work.num == work.fibers.size() + work.tasks.size(),
              "work.num out of sync");
  while (!work.fibers.empty() || !work.tasks.empty()) {
    // Note: we cannot take and store more than a single fiber or task at a
    // time on the stack, as the fiber may yield and these items would then
    // be held on a suspended fiber's stack.

    while (!work.fibers.empty()) {
      work.num--;
      auto fiber = containers::take(work.fibers);
      // Sanity checks.
      MARL_ASSERT(idleFibers.count(fiber) == 0, "dequeued fiber is idle");
      MARL_ASSERT(fiber != currentFiber, "dequeued fiber is currently running");
      ASSERT_FIBER_STATE(fiber, Fiber::State::Queued);

      changeFiberState(currentFiber, Fiber::State::Running, Fiber::State::Idle);
      auto added = idleFibers.emplace(currentFiber).second;
      (void)added;
      MARL_ASSERT(added, "fiber already idle");

      switchToFiber(fiber);
      changeFiberState(currentFiber, Fiber::State::Idle, Fiber::State::Running);
    }

    if (!work.tasks.empty()) {
      work.num--;
      auto task = containers::take(work.tasks);
      work.mutex.unlock();

      // Run the task.
      task();

      // std::function<> can carry arguments with complex destructors.
      // Ensure these are destructed outside of the lock.
      task = Task();

      work.mutex.lock();
    }
  }
}

Scheduler::Fiber* Scheduler::Worker::createWorkerFiber() {
  auto fiberId = static_cast<uint32_t>(workerFibers.size() + 1);
  DBG_LOG("%d: CREATE(%d)", (int)id, (int)fiberId);
  auto fiber = Fiber::create(scheduler->cfg.allocator, fiberId,
                             scheduler->cfg.fiberStackSize,
                             [&]() REQUIRES(work.mutex) { run(); });
  auto ptr = fiber.get();
  workerFibers.emplace_back(std::move(fiber));
  return ptr;
}

void Scheduler::Worker::switchToFiber(Fiber* to) {
  DBG_LOG("%d: SWITCH(%d -> %d)", (int)id, (int)currentFiber->id, (int)to->id);
  MARL_ASSERT(to == mainFiber.get() || idleFibers.count(to) == 0,
              "switching to idle fiber");
  auto from = currentFiber;
  currentFiber = to;
  from->switchTo(to);
}

////////////////////////////////////////////////////////////////////////////////
// Scheduler::Worker::Work
////////////////////////////////////////////////////////////////////////////////
Scheduler::Worker::Work::Work(Allocator* allocator)
    : tasks(allocator), fibers(allocator), waiting(allocator) {}

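// wait() blocks on the 'added' condition variable until f() returns true,
// waking early at the next fiber timeout if any fibers are waiting.
// notifyAdded is raised only while a waiter exists, letting enqueuers skip
// the condition-variable notify when nobody is listening.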
template <typename F>
void Scheduler::Worker::Work::wait(F&& f) {
  notifyAdded = true;
  if (waiting) {
    mutex.wait_until_locked(added, waiting.next(), f);
  } else {
    mutex.wait_locked(added, f);
  }
  notifyAdded = false;
}

////////////////////////////////////////////////////////////////////////////////
// Scheduler::SingleThreadedWorkers
////////////////////////////////////////////////////////////////////////////////
Scheduler::SingleThreadedWorkers::SingleThreadedWorkers(Allocator* allocator)
    : byTid(allocator) {}

}  // namespace marl