1 // Copyright 2019 The Marl Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef marl_waitgroup_h 16 #define marl_waitgroup_h 17 18 #include "conditionvariable.h" 19 #include "debug.h" 20 21 #include <atomic> 22 #include <mutex> 23 24 namespace marl { 25 26 // WaitGroup is a synchronization primitive that holds an internal counter that 27 // can incremented, decremented and waited on until it reaches 0. 28 // WaitGroups can be used as a simple mechanism for waiting on a number of 29 // concurrently execute a number of tasks to complete. 30 // 31 // Example: 32 // 33 // void runTasksConcurrently(int numConcurrentTasks) 34 // { 35 // // Construct the WaitGroup with an initial count of numConcurrentTasks. 36 // marl::WaitGroup wg(numConcurrentTasks); 37 // for (int i = 0; i < numConcurrentTasks; i++) 38 // { 39 // // Schedule a task to be run asynchronously. 40 // // These may all be run concurrently. 41 // marl::schedule([=] { 42 // // Once the task has finished, decrement the waitgroup counter 43 // // to signal that this has completed. 44 // defer(wg.done()); 45 // doSomeWork(); 46 // }); 47 // } 48 // // Block until all tasks have completed. 49 // wg.wait(); 50 // } 51 class WaitGroup { 52 public: 53 // Constructs the WaitGroup with the specified initial count. 54 inline WaitGroup(unsigned int initialCount = 0); 55 56 // add() increments the internal counter by count. 57 inline void add(unsigned int count = 1) const; 58 59 // done() decrements the internal counter by one. 60 // Returns true if the internal count has reached zero. 61 inline bool done() const; 62 63 // wait() blocks until the WaitGroup counter reaches zero. 64 inline void wait() const; 65 66 private: 67 struct Data { 68 std::atomic<unsigned int> count = {0}; 69 ConditionVariable condition; 70 std::mutex mutex; 71 }; 72 const std::shared_ptr<Data> data = std::make_shared<Data>(); 73 }; 74 WaitGroup(unsigned int initialCount)75inline WaitGroup::WaitGroup(unsigned int initialCount /* = 0 */) { 76 data->count = initialCount; 77 } 78 add(unsigned int count)79void WaitGroup::add(unsigned int count /* = 1 */) const { 80 data->count += count; 81 } 82 done()83bool WaitGroup::done() const { 84 MARL_ASSERT(data->count > 0, "marl::WaitGroup::done() called too many times"); 85 auto count = --data->count; 86 if (count == 0) { 87 std::unique_lock<std::mutex> lock(data->mutex); 88 data->condition.notify_all(); 89 return true; 90 } 91 return false; 92 } 93 wait()94void WaitGroup::wait() const { 95 std::unique_lock<std::mutex> lock(data->mutex); 96 data->condition.wait(lock, [this] { return data->count == 0; }); 97 } 98 99 } // namespace marl 100 101 #endif // marl_waitgroup_h 102