• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2019 The Marl Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef marl_condition_variable_h
16 #define marl_condition_variable_h
17 
18 #include "debug.h"
19 #include "defer.h"
20 #include "scheduler.h"
21 
22 #include <atomic>
23 #include <condition_variable>
24 #include <mutex>
25 #include <unordered_set>
26 
27 namespace marl {
28 
29 // ConditionVariable is a synchronization primitive that can be used to block
30 // one or more fibers or threads, until another fiber or thread modifies a
31 // shared variable (the condition) and notifies the ConditionVariable.
32 //
33 // If the ConditionVariable is blocked on a thread with a Scheduler bound, the
34 // thread will work on other tasks until the ConditionVariable is unblocked.
35 class ConditionVariable {
36  public:
37   inline ConditionVariable();
38 
39   // notify_one() notifies and potentially unblocks one waiting fiber or thread.
40   inline void notify_one();
41 
42   // notify_all() notifies and potentially unblocks all waiting fibers and/or
43   // threads.
44   inline void notify_all();
45 
46   // wait() blocks the current fiber or thread until the predicate is satisfied
47   // and the ConditionVariable is notified.
48   template <typename Predicate>
49   inline void wait(std::unique_lock<std::mutex>& lock, Predicate&& pred);
50 
51   // wait_for() blocks the current fiber or thread until the predicate is
52   // satisfied, and the ConditionVariable is notified, or the timeout has been
53   // reached. Returns false if pred still evaluates to false after the timeout
54   // has been reached, otherwise true.
55   template <typename Rep, typename Period, typename Predicate>
56   bool wait_for(std::unique_lock<std::mutex>& lock,
57                 const std::chrono::duration<Rep, Period>& duration,
58                 Predicate&& pred);
59 
60   // wait_until() blocks the current fiber or thread until the predicate is
61   // satisfied, and the ConditionVariable is notified, or the timeout has been
62   // reached. Returns false if pred still evaluates to false after the timeout
63   // has been reached, otherwise true.
64   template <typename Clock, typename Duration, typename Predicate>
65   bool wait_until(std::unique_lock<std::mutex>& lock,
66                   const std::chrono::time_point<Clock, Duration>& timeout,
67                   Predicate&& pred);
68 
69  private:
70   ConditionVariable(const ConditionVariable&) = delete;
71   ConditionVariable(ConditionVariable&&) = delete;
72   ConditionVariable& operator=(const ConditionVariable&) = delete;
73   ConditionVariable& operator=(ConditionVariable&&) = delete;
74 
75   std::mutex mutex;
76   std::unordered_set<Scheduler::Fiber*> waiting;
77   std::condition_variable condition;
78   std::atomic<int> numWaiting = {0};
79   std::atomic<int> numWaitingOnCondition = {0};
80 };
81 
ConditionVariable()82 ConditionVariable::ConditionVariable() {}
83 
notify_one()84 void ConditionVariable::notify_one() {
85   if (numWaiting == 0) {
86     return;
87   }
88   {
89     std::unique_lock<std::mutex> lock(mutex);
90     for (auto fiber : waiting) {
91       fiber->notify();
92     }
93   }
94   if (numWaitingOnCondition > 0) {
95     condition.notify_one();
96   }
97 }
98 
notify_all()99 void ConditionVariable::notify_all() {
100   if (numWaiting == 0) {
101     return;
102   }
103   {
104     std::unique_lock<std::mutex> lock(mutex);
105     for (auto fiber : waiting) {
106       fiber->notify();
107     }
108   }
109   if (numWaitingOnCondition > 0) {
110     condition.notify_all();
111   }
112 }
113 
114 template <typename Predicate>
wait(std::unique_lock<std::mutex> & lock,Predicate && pred)115 void ConditionVariable::wait(std::unique_lock<std::mutex>& lock,
116                              Predicate&& pred) {
117   if (pred()) {
118     return;
119   }
120   numWaiting++;
121   if (auto fiber = Scheduler::Fiber::current()) {
122     // Currently executing on a scheduler fiber.
123     // Yield to let other tasks run that can unblock this fiber.
124     mutex.lock();
125     waiting.emplace(fiber);
126     mutex.unlock();
127 
128     fiber->wait(lock, pred);
129 
130     mutex.lock();
131     waiting.erase(fiber);
132     mutex.unlock();
133   } else {
134     // Currently running outside of the scheduler.
135     // Delegate to the std::condition_variable.
136     numWaitingOnCondition++;
137     condition.wait(lock, pred);
138     numWaitingOnCondition--;
139   }
140   numWaiting--;
141 }
142 
143 template <typename Rep, typename Period, typename Predicate>
wait_for(std::unique_lock<std::mutex> & lock,const std::chrono::duration<Rep,Period> & duration,Predicate && pred)144 bool ConditionVariable::wait_for(
145     std::unique_lock<std::mutex>& lock,
146     const std::chrono::duration<Rep, Period>& duration,
147     Predicate&& pred) {
148   return wait_until(lock, std::chrono::system_clock::now() + duration, pred);
149 }
150 
151 template <typename Clock, typename Duration, typename Predicate>
wait_until(std::unique_lock<std::mutex> & lock,const std::chrono::time_point<Clock,Duration> & timeout,Predicate && pred)152 bool ConditionVariable::wait_until(
153     std::unique_lock<std::mutex>& lock,
154     const std::chrono::time_point<Clock, Duration>& timeout,
155     Predicate&& pred) {
156   if (pred()) {
157     return true;
158   }
159   numWaiting++;
160   defer(numWaiting--);
161 
162   if (auto fiber = Scheduler::Fiber::current()) {
163     // Currently executing on a scheduler fiber.
164     // Yield to let other tasks run that can unblock this fiber.
165     mutex.lock();
166     waiting.emplace(fiber);
167     mutex.unlock();
168 
169     auto res = fiber->wait(lock, timeout, pred);
170 
171     mutex.lock();
172     waiting.erase(fiber);
173     mutex.unlock();
174 
175     return res;
176   } else {
177     // Currently running outside of the scheduler.
178     // Delegate to the std::condition_variable.
179     numWaitingOnCondition++;
180     defer(numWaitingOnCondition--);
181     return condition.wait_until(lock, timeout, pred);
182   }
183 }
184 
185 }  // namespace marl
186 
187 #endif  // marl_condition_variable_h
188