• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2024 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifndef BASE_FUNCTIONAL_CONCURRENT_CALLBACKS_H_
6 #define BASE_FUNCTIONAL_CONCURRENT_CALLBACKS_H_
7 
8 #include <memory>
9 #include <type_traits>
10 #include <vector>
11 
12 #include "base/functional/bind.h"
13 #include "base/functional/callback.h"
14 #include "base/location.h"
15 #include "base/memory/raw_ptr.h"
16 #include "base/sequence_checker.h"
17 #include "base/task/bind_post_task.h"
18 #include "base/task/sequenced_task_runner.h"
19 
20 // OVERVIEW:
21 //
// ConcurrentCallbacks<T> is an alternative to BarrierCallback<T>: it dispenses
// OnceCallbacks via CreateCallback() and invokes the callback passed to Done()
// after all of the previously dispensed callbacks have run.
//
// Prefer ConcurrentCallbacks<T> over BarrierCallback<T> when the number of
// callbacks is not known before the first one must be handed out to start a
// task. Prefer it also when that number would otherwise have to be derived by
// hand from the code and is therefore prone to human error.
30 //
// IMPORTANT NOTES:
//
// - ConcurrentCallbacks<T> is NOT thread safe.
// - The done callback will NOT be run synchronously. It is posted (via
//   PostTask()) to the sequence on which Done() was invoked.
// - ConcurrentCallbacks<T> cannot be used after Done() is called; a CHECK
//   verifies this.
// - Every dispensed callback must eventually be run. If one is destroyed
//   without being run, the done callback never runs.
38 //
39 // TYPICAL USAGE:
40 //
41 // class Example {
42 //   void OnRequestsReceived(std::vector<Request> requests) {
43 //     base::ConcurrentCallbacks<Result> concurrent;
44 //
45 //     for (Request& request : requests) {
46 //       if (IsValidRequest(request)) {
47 //         StartRequest(std::move(request), concurrent.CreateCallback());
48 //       }
49 //     }
50 //
51 //     std::move(concurrent).Done(
52 //         base::BindOnce(&Example::OnRequestsComplete, GetWeakPtr()));
53 //   }
54 //
55 //   void StartRequest(Request request,
56 //                     base::OnceCallback<void(Result)> callback) {
57 //     // Process the request asynchronously and call callback with a Result.
58 //   }
59 //
60 //   void OnRequestsComplete(std::vector<Result> results) {
61 //     // Invoked after all requests are completed and receives the results of
62 //     // all of them.
63 //   }
64 // };
65 
66 namespace base {
67 
68 template <typename T>
69 class ConcurrentCallbacks {
70  public:
71   using Results = std::vector<std::remove_cvref_t<T>>;
72 
ConcurrentCallbacks()73   ConcurrentCallbacks() {
74     auto info_owner = std::make_unique<Info>();
75     info_ = info_owner.get();
76     info_run_callback_ = BindRepeating(&Info::Run, std::move(info_owner));
77   }
78 
79   // Create a callback for the done callback to wait for.
CreateCallback()80   [[nodiscard]] OnceCallback<void(T)> CreateCallback() {
81     CHECK(info_);
82     DCHECK_CALLED_ON_VALID_SEQUENCE(info_->sequence_checker_);
83     ++info_->pending_;
84     return info_run_callback_;
85   }
86 
87   // Finish creating concurrent callbacks and provide done callback to run once
88   // all prior callbacks have executed.
89   // `this` is no longer usable after calling Done(), must be called with
90   // std::move().
91   void Done(OnceCallback<void(Results)> done_callback,
92             const Location& location = FROM_HERE) && {
93     CHECK(info_);
94     DCHECK_CALLED_ON_VALID_SEQUENCE(info_->sequence_checker_);
95     info_->done_callback_ =
96         BindPostTask(SequencedTaskRunner::GetCurrentDefault(),
97                      std::move(done_callback), location);
98     if (info_->pending_ == 0u) {
99       std::move(info_->done_callback_).Run(std::move(info_->results_));
100     }
101     info_ = nullptr;
102   }
103 
104  private:
105   class Info {
106    public:
107     Info() = default;
108 
Run(T value)109     void Run(T value) {
110       DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
111       CHECK_GT(pending_, 0u);
112       --pending_;
113       results_.push_back(std::move(value));
114       if (done_callback_ && pending_ == 0u) {
115         std::move(done_callback_).Run(std::move(results_));
116       }
117     }
118 
119     size_t pending_ GUARDED_BY_CONTEXT(sequence_checker_) = 0u;
120     Results results_ GUARDED_BY_CONTEXT(sequence_checker_);
121     OnceCallback<void(Results)> done_callback_
122         GUARDED_BY_CONTEXT(sequence_checker_);
123     SEQUENCE_CHECKER(sequence_checker_);
124   };
125 
126   RepeatingCallback<void(T)> info_run_callback_;
127   // info_ is owned by info_run_callback_.
128   raw_ptr<Info> info_;
129 };
130 
131 }  // namespace base
132 
133 #endif  // BASE_FUNCTIONAL_CONCURRENT_CALLBACKS_H_
134