1 // Copyright 2019 The Marl Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #ifndef marl_condition_variable_h
16 #define marl_condition_variable_h
17
18 #include "debug.h"
19 #include "defer.h"
20 #include "scheduler.h"
21
22 #include <atomic>
23 #include <condition_variable>
24 #include <mutex>
25 #include <unordered_set>
26
27 namespace marl {
28
29 // ConditionVariable is a synchronization primitive that can be used to block
30 // one or more fibers or threads, until another fiber or thread modifies a
31 // shared variable (the condition) and notifies the ConditionVariable.
32 //
33 // If the ConditionVariable is blocked on a thread with a Scheduler bound, the
34 // thread will work on other tasks until the ConditionVariable is unblocked.
35 class ConditionVariable {
36 public:
37 inline ConditionVariable();
38
39 // notify_one() notifies and potentially unblocks one waiting fiber or thread.
40 inline void notify_one();
41
42 // notify_all() notifies and potentially unblocks all waiting fibers and/or
43 // threads.
44 inline void notify_all();
45
46 // wait() blocks the current fiber or thread until the predicate is satisfied
47 // and the ConditionVariable is notified.
48 template <typename Predicate>
49 inline void wait(std::unique_lock<std::mutex>& lock, Predicate&& pred);
50
51 // wait_for() blocks the current fiber or thread until the predicate is
52 // satisfied, and the ConditionVariable is notified, or the timeout has been
53 // reached. Returns false if pred still evaluates to false after the timeout
54 // has been reached, otherwise true.
55 template <typename Rep, typename Period, typename Predicate>
56 bool wait_for(std::unique_lock<std::mutex>& lock,
57 const std::chrono::duration<Rep, Period>& duration,
58 Predicate&& pred);
59
60 // wait_until() blocks the current fiber or thread until the predicate is
61 // satisfied, and the ConditionVariable is notified, or the timeout has been
62 // reached. Returns false if pred still evaluates to false after the timeout
63 // has been reached, otherwise true.
64 template <typename Clock, typename Duration, typename Predicate>
65 bool wait_until(std::unique_lock<std::mutex>& lock,
66 const std::chrono::time_point<Clock, Duration>& timeout,
67 Predicate&& pred);
68
69 private:
70 ConditionVariable(const ConditionVariable&) = delete;
71 ConditionVariable(ConditionVariable&&) = delete;
72 ConditionVariable& operator=(const ConditionVariable&) = delete;
73 ConditionVariable& operator=(ConditionVariable&&) = delete;
74
75 std::mutex mutex;
76 std::unordered_set<Scheduler::Fiber*> waiting;
77 std::condition_variable condition;
78 std::atomic<int> numWaiting = {0};
79 std::atomic<int> numWaitingOnCondition = {0};
80 };
81
ConditionVariable()82 ConditionVariable::ConditionVariable() {}
83
notify_one()84 void ConditionVariable::notify_one() {
85 if (numWaiting == 0) {
86 return;
87 }
88 {
89 std::unique_lock<std::mutex> lock(mutex);
90 for (auto fiber : waiting) {
91 fiber->notify();
92 }
93 }
94 if (numWaitingOnCondition > 0) {
95 condition.notify_one();
96 }
97 }
98
notify_all()99 void ConditionVariable::notify_all() {
100 if (numWaiting == 0) {
101 return;
102 }
103 {
104 std::unique_lock<std::mutex> lock(mutex);
105 for (auto fiber : waiting) {
106 fiber->notify();
107 }
108 }
109 if (numWaitingOnCondition > 0) {
110 condition.notify_all();
111 }
112 }
113
114 template <typename Predicate>
wait(std::unique_lock<std::mutex> & lock,Predicate && pred)115 void ConditionVariable::wait(std::unique_lock<std::mutex>& lock,
116 Predicate&& pred) {
117 if (pred()) {
118 return;
119 }
120 numWaiting++;
121 if (auto fiber = Scheduler::Fiber::current()) {
122 // Currently executing on a scheduler fiber.
123 // Yield to let other tasks run that can unblock this fiber.
124 mutex.lock();
125 waiting.emplace(fiber);
126 mutex.unlock();
127
128 fiber->wait(lock, pred);
129
130 mutex.lock();
131 waiting.erase(fiber);
132 mutex.unlock();
133 } else {
134 // Currently running outside of the scheduler.
135 // Delegate to the std::condition_variable.
136 numWaitingOnCondition++;
137 condition.wait(lock, pred);
138 numWaitingOnCondition--;
139 }
140 numWaiting--;
141 }
142
143 template <typename Rep, typename Period, typename Predicate>
wait_for(std::unique_lock<std::mutex> & lock,const std::chrono::duration<Rep,Period> & duration,Predicate && pred)144 bool ConditionVariable::wait_for(
145 std::unique_lock<std::mutex>& lock,
146 const std::chrono::duration<Rep, Period>& duration,
147 Predicate&& pred) {
148 return wait_until(lock, std::chrono::system_clock::now() + duration, pred);
149 }
150
151 template <typename Clock, typename Duration, typename Predicate>
wait_until(std::unique_lock<std::mutex> & lock,const std::chrono::time_point<Clock,Duration> & timeout,Predicate && pred)152 bool ConditionVariable::wait_until(
153 std::unique_lock<std::mutex>& lock,
154 const std::chrono::time_point<Clock, Duration>& timeout,
155 Predicate&& pred) {
156 if (pred()) {
157 return true;
158 }
159 numWaiting++;
160 defer(numWaiting--);
161
162 if (auto fiber = Scheduler::Fiber::current()) {
163 // Currently executing on a scheduler fiber.
164 // Yield to let other tasks run that can unblock this fiber.
165 mutex.lock();
166 waiting.emplace(fiber);
167 mutex.unlock();
168
169 auto res = fiber->wait(lock, timeout, pred);
170
171 mutex.lock();
172 waiting.erase(fiber);
173 mutex.unlock();
174
175 return res;
176 } else {
177 // Currently running outside of the scheduler.
178 // Delegate to the std::condition_variable.
179 numWaitingOnCondition++;
180 defer(numWaitingOnCondition--);
181 return condition.wait_until(lock, timeout, pred);
182 }
183 }
184
185 } // namespace marl
186
187 #endif // marl_condition_variable_h
188