/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_INCREMENTAL_BARRIER_H_
#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_INCREMENTAL_BARRIER_H_

#include <atomic>
#include <functional>

namespace tensorflow {

class InternalIncrementalBarrier;

// BarrierClosure (see
// https://github.com/chromium/chromium/blob/master/base/barrier_closure.h)
// executes a callback after it has been invoked |num_closures| times.
// In addition, `BarrierClosure` is a continuation-passing style abstraction
// and is self-deleting.

// IncrementalBarrier is a convenience class to be used in place of a barrier
// closure. It is particularly helpful (e.g. it simplifies code) because callers
// don't need to calculate |num_closures| beforehand.
//
// Example Usage:
//   void MakeCalls() {
//     typedef std::function<void()> Callback;
//     typedef std::function<void(Callback)> OtherCallback;
//     Callback done_callback = ...
//     OtherCallback cb1 = ...
//     OtherCallback cb2 = ...
//     std::thread threads[2];
//     {
//       IncrementalBarrier barrier(done_callback);
//       // Each thread runs `cbN` with its own BarrierCallback.
//       threads[0] = std::thread(cb1, barrier.Inc());
//       threads[1] = std::thread(cb2, barrier.Inc());
//       // ... at this moment, `barrier` has been incremented twice, and it
//       // is destructed when this scope ends.
//     }
//     threads[0].join();
//     threads[1].join();
//   }
//
// `done_callback` will be called when both conditions are true:
//   1) `barrier` has been destructed.
//   2) Each `BarrierCallback` returned by `Inc` has been called.
// This class is thread-safe.
class IncrementalBarrier {
 public:
  // Invoked once the barrier is complete (see conditions above).
  using DoneCallback = std::function<void()>;
  // Returned by `Inc()`; a task calls it to signal that it has finished.
  using BarrierCallback = std::function<void()>;

  // `callback` is the `done_callback` run once the barrier completes.
  explicit IncrementalBarrier(DoneCallback callback);

  // Non-copyable: a copy would alias the same `internal_barrier_`, so every
  // copy's destructor would act on the one shared internal barrier.
  // NOTE(review): assuming the destructor signals the internal barrier once
  // (implementation not visible here), copying would signal it twice and could
  // run `done_callback` before all tasks finish.
  IncrementalBarrier(const IncrementalBarrier&) = delete;
  IncrementalBarrier& operator=(const IncrementalBarrier&) = delete;

  // Destruction is one of the two preconditions for `done_callback` to run;
  // `done_callback` never runs while this object is alive.
  ~IncrementalBarrier();

  // Returns a BarrierCallback (std::function) that an individual task calls to
  // signal its completion.
  // The returned BarrierCallback may outlive this `IncrementalBarrier`
  // instance. Each task must eventually call the returned function; otherwise
  // `done_callback` will never be called.
  BarrierCallback Inc();

 private:
  // Self-deleting, and therefore not owned by `IncrementalBarrier`.
  InternalIncrementalBarrier* internal_barrier_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_INCREMENTAL_BARRIER_H_