1 // Copyright 2019 The Marl Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #ifndef marl_waitgroup_h
16 #define marl_waitgroup_h
17
#include "conditionvariable.h"
#include "debug.h"

#include <atomic>
#include <memory>
#include <mutex>
23
24 namespace marl {
25
// WaitGroup is a synchronization primitive that holds an internal counter that
// can be incremented, decremented and waited on until it reaches 0.
// WaitGroups can be used as a simple mechanism for waiting for a number of
// concurrently executing tasks to complete.
30 //
31 // Example:
32 //
33 // void runTasksConcurrently(int numConcurrentTasks)
34 // {
35 // // Construct the WaitGroup with an initial count of numConcurrentTasks.
36 // marl::WaitGroup wg(numConcurrentTasks);
37 // for (int i = 0; i < numConcurrentTasks; i++)
38 // {
39 // // Schedule a task to be run asynchronously.
40 // // These may all be run concurrently.
41 // marl::schedule([=] {
42 // // Once the task has finished, decrement the waitgroup counter
43 // // to signal that this has completed.
44 // defer(wg.done());
45 // doSomeWork();
46 // });
47 // }
48 // // Block until all tasks have completed.
49 // wg.wait();
50 // }
51 class WaitGroup {
52 public:
53 // Constructs the WaitGroup with the specified initial count.
54 MARL_NO_EXPORT inline WaitGroup(unsigned int initialCount = 0,
55 Allocator* allocator = Allocator::Default);
56
57 // add() increments the internal counter by count.
58 MARL_NO_EXPORT inline void add(unsigned int count = 1) const;
59
60 // done() decrements the internal counter by one.
61 // Returns true if the internal count has reached zero.
62 MARL_NO_EXPORT inline bool done() const;
63
64 // wait() blocks until the WaitGroup counter reaches zero.
65 MARL_NO_EXPORT inline void wait() const;
66
67 private:
68 struct Data {
69 MARL_NO_EXPORT inline Data(Allocator* allocator);
70
71 std::atomic<unsigned int> count = {0};
72 ConditionVariable cv;
73 marl::mutex mutex;
74 };
75 const std::shared_ptr<Data> data;
76 };
77
Data(Allocator * allocator)78 WaitGroup::Data::Data(Allocator* allocator) : cv(allocator) {}
79
WaitGroup(unsigned int initialCount,Allocator * allocator)80 WaitGroup::WaitGroup(unsigned int initialCount /* = 0 */,
81 Allocator* allocator /* = Allocator::Default */)
82 : data(std::make_shared<Data>(allocator)) {
83 data->count = initialCount;
84 }
85
add(unsigned int count)86 void WaitGroup::add(unsigned int count /* = 1 */) const {
87 data->count += count;
88 }
89
done()90 bool WaitGroup::done() const {
91 MARL_ASSERT(data->count > 0, "marl::WaitGroup::done() called too many times");
92 auto count = --data->count;
93 if (count == 0) {
94 marl::lock lock(data->mutex);
95 data->cv.notify_all();
96 return true;
97 }
98 return false;
99 }
100
wait()101 void WaitGroup::wait() const {
102 marl::lock lock(data->mutex);
103 data->cv.wait(lock, [this] { return data->count == 0; });
104 }
105
106 } // namespace marl
107
108 #endif // marl_waitgroup_h
109