// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef BASE_FUNCTIONAL_CONCURRENT_CALLBACKS_H_
#define BASE_FUNCTIONAL_CONCURRENT_CALLBACKS_H_

#include <memory>
#include <type_traits>
#include <vector>

#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/location.h"
#include "base/memory/raw_ptr.h"
#include "base/sequence_checker.h"
#include "base/task/bind_post_task.h"
#include "base/task/sequenced_task_runner.h"

// OVERVIEW:
//
// ConcurrentCallbacks<T> is an alternative to BarrierCallback<T>. It dispenses
// OnceCallbacks via CreateCallback() and invokes the callback passed to Done()
// after all prior callbacks have been run.
//
// ConcurrentCallbacks<T> is intended to be used over BarrierCallback<T> in
// cases where the count is unknown prior to requiring a callback to start a
// task, and for cases where the count is manually derived from the code and
// subject to human error.
//
// IMPORTANT NOTES:
//
// - ConcurrentCallbacks<T> is NOT thread safe.
// - The done callback will NOT be run synchronously; it is posted via
//   PostTask() to the sequence that Done() was invoked on.
// - ConcurrentCallbacks<T> cannot be used after Done() is called; a CHECK
//   verifies this.
38 // 39 // TYPICAL USAGE: 40 // 41 // class Example { 42 // void OnRequestsReceived(std::vector<Request> requests) { 43 // base::ConcurrentCallbacks<Result> concurrent; 44 // 45 // for (Request& request : requests) { 46 // if (IsValidRequest(request)) { 47 // StartRequest(std::move(request), concurrent.CreateCallback()); 48 // } 49 // } 50 // 51 // std::move(concurrent).Done( 52 // base::BindOnce(&Example::OnRequestsComplete, GetWeakPtr())); 53 // } 54 // 55 // void StartRequest(Request request, 56 // base::OnceCallback<void(Result)> callback) { 57 // // Process the request asynchronously and call callback with a Result. 58 // } 59 // 60 // void OnRequestsComplete(std::vector<Result> results) { 61 // // Invoked after all requests are completed and receives the results of 62 // // all of them. 63 // } 64 // }; 65 66 namespace base { 67 68 template <typename T> 69 class ConcurrentCallbacks { 70 public: 71 using Results = std::vector<std::remove_cvref_t<T>>; 72 ConcurrentCallbacks()73 ConcurrentCallbacks() { 74 auto info_owner = std::make_unique<Info>(); 75 info_ = info_owner.get(); 76 info_run_callback_ = BindRepeating(&Info::Run, std::move(info_owner)); 77 } 78 79 // Create a callback for the done callback to wait for. CreateCallback()80 [[nodiscard]] OnceCallback<void(T)> CreateCallback() { 81 CHECK(info_); 82 DCHECK_CALLED_ON_VALID_SEQUENCE(info_->sequence_checker_); 83 ++info_->pending_; 84 return info_run_callback_; 85 } 86 87 // Finish creating concurrent callbacks and provide done callback to run once 88 // all prior callbacks have executed. 89 // `this` is no longer usable after calling Done(), must be called with 90 // std::move(). 
91 void Done(OnceCallback<void(Results)> done_callback, 92 const Location& location = FROM_HERE) && { 93 CHECK(info_); 94 DCHECK_CALLED_ON_VALID_SEQUENCE(info_->sequence_checker_); 95 info_->done_callback_ = 96 BindPostTask(SequencedTaskRunner::GetCurrentDefault(), 97 std::move(done_callback), location); 98 if (info_->pending_ == 0u) { 99 std::move(info_->done_callback_).Run(std::move(info_->results_)); 100 } 101 info_ = nullptr; 102 } 103 104 private: 105 class Info { 106 public: 107 Info() = default; 108 Run(T value)109 void Run(T value) { 110 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); 111 CHECK_GT(pending_, 0u); 112 --pending_; 113 results_.push_back(std::move(value)); 114 if (done_callback_ && pending_ == 0u) { 115 std::move(done_callback_).Run(std::move(results_)); 116 } 117 } 118 119 size_t pending_ GUARDED_BY_CONTEXT(sequence_checker_) = 0u; 120 Results results_ GUARDED_BY_CONTEXT(sequence_checker_); 121 OnceCallback<void(Results)> done_callback_ 122 GUARDED_BY_CONTEXT(sequence_checker_); 123 SEQUENCE_CHECKER(sequence_checker_); 124 }; 125 126 RepeatingCallback<void(T)> info_run_callback_; 127 // info_ is owned by info_run_callback_. 128 raw_ptr<Info> info_; 129 }; 130 131 } // namespace base 132 133 #endif // BASE_FUNCTIONAL_CONCURRENT_CALLBACKS_H_ 134