• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2019 The Marl Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef marl_waitgroup_h
16 #define marl_waitgroup_h
17 
#include "conditionvariable.h"
#include "debug.h"

#include <atomic>
#include <memory>
#include <mutex>
23 
24 namespace marl {
25 
26 // WaitGroup is a synchronization primitive that holds an internal counter that
27 // can incremented, decremented and waited on until it reaches 0.
28 // WaitGroups can be used as a simple mechanism for waiting on a number of
29 // concurrently execute a number of tasks to complete.
30 //
31 // Example:
32 //
33 //  void runTasksConcurrently(int numConcurrentTasks)
34 //  {
35 //      // Construct the WaitGroup with an initial count of numConcurrentTasks.
36 //      marl::WaitGroup wg(numConcurrentTasks);
37 //      for (int i = 0; i < numConcurrentTasks; i++)
38 //      {
39 //          // Schedule a task to be run asynchronously.
40 //          // These may all be run concurrently.
41 //          marl::schedule([=] {
42 //              // Once the task has finished, decrement the waitgroup counter
43 //              // to signal that this has completed.
44 //              defer(wg.done());
45 //              doSomeWork();
46 //          });
47 //      }
48 //      // Block until all tasks have completed.
49 //      wg.wait();
50 //  }
51 class WaitGroup {
52  public:
53   // Constructs the WaitGroup with the specified initial count.
54   inline WaitGroup(unsigned int initialCount = 0);
55 
56   // add() increments the internal counter by count.
57   inline void add(unsigned int count = 1) const;
58 
59   // done() decrements the internal counter by one.
60   // Returns true if the internal count has reached zero.
61   inline bool done() const;
62 
63   // wait() blocks until the WaitGroup counter reaches zero.
64   inline void wait() const;
65 
66  private:
67   struct Data {
68     std::atomic<unsigned int> count = {0};
69     ConditionVariable condition;
70     std::mutex mutex;
71   };
72   const std::shared_ptr<Data> data = std::make_shared<Data>();
73 };
74 
WaitGroup(unsigned int initialCount)75 inline WaitGroup::WaitGroup(unsigned int initialCount /* = 0 */) {
76   data->count = initialCount;
77 }
78 
add(unsigned int count)79 void WaitGroup::add(unsigned int count /* = 1 */) const {
80   data->count += count;
81 }
82 
done()83 bool WaitGroup::done() const {
84   MARL_ASSERT(data->count > 0, "marl::WaitGroup::done() called too many times");
85   auto count = --data->count;
86   if (count == 0) {
87     std::unique_lock<std::mutex> lock(data->mutex);
88     data->condition.notify_all();
89     return true;
90   }
91   return false;
92 }
93 
wait()94 void WaitGroup::wait() const {
95   std::unique_lock<std::mutex> lock(data->mutex);
96   data->condition.wait(lock, [this] { return data->count == 0; });
97 }
98 
99 }  // namespace marl
100 
101 #endif  // marl_waitgroup_h
102