1 // Copyright 2019 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
#include "base/task/thread_pool/job_task_source.h"

#include <algorithm>
#include <limits>
#include <type_traits>
#include <utility>

#include "base/bits.h"
#include "base/check_op.h"
#include "base/functional/bind.h"
#include "base/functional/callback_helpers.h"
#include "base/memory/ptr_util.h"
#include "base/notreached.h"
#include "base/task/common/checked_lock.h"
#include "base/task/task_features.h"
#include "base/task/thread_pool/pooled_task_runner_delegate.h"
#include "base/template_util.h"
#include "base/threading/thread_restrictions.h"
#include "base/time/time.h"
#include "base/time/time_override.h"
#include "base/trace_event/base_tracing.h"
24
25 namespace base {
26 namespace internal {
27
28 namespace {
29
30 // Capped to allow assigning task_ids from a bitfield.
31 constexpr size_t kMaxWorkersPerJob = 32;
32 static_assert(
33 kMaxWorkersPerJob <=
34 std::numeric_limits<
35 std::invoke_result<decltype(&JobDelegate::GetTaskId),
36 JobDelegate>::type>::max(),
37 "AcquireTaskId return type isn't big enough to fit kMaxWorkersPerJob");
38
39 } // namespace
40
41 JobTaskSource::State::State() = default;
42 JobTaskSource::State::~State() = default;
43
Cancel()44 JobTaskSource::State::Value JobTaskSource::State::Cancel() {
45 return {value_.fetch_or(kCanceledMask, std::memory_order_relaxed)};
46 }
47
DecrementWorkerCount()48 JobTaskSource::State::Value JobTaskSource::State::DecrementWorkerCount() {
49 const uint32_t value_before_sub =
50 value_.fetch_sub(kWorkerCountIncrement, std::memory_order_relaxed);
51 DCHECK((value_before_sub >> kWorkerCountBitOffset) > 0);
52 return {value_before_sub};
53 }
54
IncrementWorkerCount()55 JobTaskSource::State::Value JobTaskSource::State::IncrementWorkerCount() {
56 uint32_t value_before_add =
57 value_.fetch_add(kWorkerCountIncrement, std::memory_order_relaxed);
58 // The worker count must not overflow a uint8_t.
59 DCHECK((value_before_add >> kWorkerCountBitOffset) < ((1 << 8) - 1));
60 return {value_before_add};
61 }
62
Load() const63 JobTaskSource::State::Value JobTaskSource::State::Load() const {
64 return {value_.load(std::memory_order_relaxed)};
65 }
66
67 JobTaskSource::JoinFlag::JoinFlag() = default;
68 JobTaskSource::JoinFlag::~JoinFlag() = default;
69
Reset()70 void JobTaskSource::JoinFlag::Reset() {
71 value_.store(kNotWaiting, std::memory_order_relaxed);
72 }
73
SetWaiting()74 void JobTaskSource::JoinFlag::SetWaiting() {
75 value_.store(kWaitingForWorkerToYield, std::memory_order_relaxed);
76 }
77
ShouldWorkerYield()78 bool JobTaskSource::JoinFlag::ShouldWorkerYield() {
79 // The fetch_and() sets the state to kWaitingForWorkerToSignal if it was
80 // previously kWaitingForWorkerToYield, otherwise it leaves it unchanged.
81 return value_.fetch_and(kWaitingForWorkerToSignal,
82 std::memory_order_relaxed) ==
83 kWaitingForWorkerToYield;
84 }
85
ShouldWorkerSignal()86 bool JobTaskSource::JoinFlag::ShouldWorkerSignal() {
87 return value_.exchange(kNotWaiting, std::memory_order_relaxed) != kNotWaiting;
88 }
89
JobTaskSource(const Location & from_here,const TaskTraits & traits,RepeatingCallback<void (JobDelegate *)> worker_task,MaxConcurrencyCallback max_concurrency_callback,PooledTaskRunnerDelegate * delegate)90 JobTaskSource::JobTaskSource(const Location& from_here,
91 const TaskTraits& traits,
92 RepeatingCallback<void(JobDelegate*)> worker_task,
93 MaxConcurrencyCallback max_concurrency_callback,
94 PooledTaskRunnerDelegate* delegate)
95 : TaskSource(traits, nullptr, TaskSourceExecutionMode::kJob),
96 from_here_(from_here),
97 max_concurrency_callback_(std::move(max_concurrency_callback)),
98 worker_task_(std::move(worker_task)),
99 primary_task_(base::BindRepeating(
100 [](JobTaskSource* self) {
101 CheckedLock::AssertNoLockHeldOnCurrentThread();
102 // Each worker task has its own delegate with associated state.
103 JobDelegate job_delegate{self, self->delegate_};
104 self->worker_task_.Run(&job_delegate);
105 },
106 base::Unretained(this))),
107 ready_time_(TimeTicks::Now()),
108 delegate_(delegate) {
109 DCHECK(delegate_);
110 }
111
~JobTaskSource()112 JobTaskSource::~JobTaskSource() {
113 // Make sure there's no outstanding active run operation left.
114 DCHECK_EQ(state_.Load().worker_count(), 0U);
115 }
116
GetExecutionEnvironment()117 ExecutionEnvironment JobTaskSource::GetExecutionEnvironment() {
118 return {SequenceToken::Create(), nullptr};
119 }
120
WillJoin()121 bool JobTaskSource::WillJoin() {
122 TRACE_EVENT0("base", "Job.WaitForParticipationOpportunity");
123 CheckedAutoLock auto_lock(worker_lock_);
124 DCHECK(!worker_released_condition_); // This may only be called once.
125 worker_released_condition_ = worker_lock_.CreateConditionVariable();
126 // Prevent wait from triggering a ScopedBlockingCall as this would cause
127 // |ThreadGroup::lock_| to be acquired, causing lock inversion.
128 worker_released_condition_->declare_only_used_while_idle();
129 const auto state_before_add = state_.IncrementWorkerCount();
130
131 if (!state_before_add.is_canceled() &&
132 state_before_add.worker_count() <
133 GetMaxConcurrency(state_before_add.worker_count())) {
134 return true;
135 }
136 return WaitForParticipationOpportunity();
137 }
138
RunJoinTask()139 bool JobTaskSource::RunJoinTask() {
140 JobDelegate job_delegate{this, nullptr};
141 worker_task_.Run(&job_delegate);
142
143 // It is safe to read |state_| without a lock since this variable is atomic
144 // and the call to GetMaxConcurrency() is used for a best effort early exit.
145 // Stale values will only cause WaitForParticipationOpportunity() to be
146 // called.
147 const auto state = TS_UNCHECKED_READ(state_).Load();
148 // The condition is slightly different from the one in WillJoin() since we're
149 // using |state| that was already incremented to include the joining thread.
150 if (!state.is_canceled() &&
151 state.worker_count() <= GetMaxConcurrency(state.worker_count() - 1)) {
152 return true;
153 }
154
155 TRACE_EVENT0("base", "Job.WaitForParticipationOpportunity");
156 CheckedAutoLock auto_lock(worker_lock_);
157 return WaitForParticipationOpportunity();
158 }
159
Cancel(TaskSource::Transaction * transaction)160 void JobTaskSource::Cancel(TaskSource::Transaction* transaction) {
161 // Sets the kCanceledMask bit on |state_| so that further calls to
162 // WillRunTask() never succeed. std::memory_order_relaxed without a lock is
163 // safe because this task source never needs to be re-enqueued after Cancel().
164 TS_UNCHECKED_READ(state_).Cancel();
165 }
166
167 // EXCLUSIVE_LOCK_REQUIRED(worker_lock_)
WaitForParticipationOpportunity()168 bool JobTaskSource::WaitForParticipationOpportunity() {
169 DCHECK(!join_flag_.IsWaiting());
170
171 // std::memory_order_relaxed is sufficient because no other state is
172 // synchronized with |state_| outside of |lock_|.
173 auto state = state_.Load();
174 // |worker_count - 1| to exclude the joining thread which is not active.
175 size_t max_concurrency = GetMaxConcurrency(state.worker_count() - 1);
176
177 // Wait until either:
178 // A) |worker_count| is below or equal to max concurrency and state is not
179 // canceled.
180 // B) All other workers returned and |worker_count| is 1.
181 while (!((state.worker_count() <= max_concurrency && !state.is_canceled()) ||
182 state.worker_count() == 1)) {
183 // std::memory_order_relaxed is sufficient because no other state is
184 // synchronized with |join_flag_| outside of |lock_|.
185 join_flag_.SetWaiting();
186
187 // To avoid unnecessarily waiting, if either condition A) or B) change
188 // |lock_| is taken and |worker_released_condition_| signaled if necessary:
189 // 1- In DidProcessTask(), after worker count is decremented.
190 // 2- In NotifyConcurrencyIncrease(), following a max_concurrency increase.
191 worker_released_condition_->Wait();
192 state = state_.Load();
193 // |worker_count - 1| to exclude the joining thread which is not active.
194 max_concurrency = GetMaxConcurrency(state.worker_count() - 1);
195 }
196 // It's possible though unlikely that the joining thread got a participation
197 // opportunity without a worker signaling.
198 join_flag_.Reset();
199
200 // Case A:
201 if (state.worker_count() <= max_concurrency && !state.is_canceled())
202 return true;
203 // Case B:
204 // Only the joining thread remains.
205 DCHECK_EQ(state.worker_count(), 1U);
206 DCHECK(state.is_canceled() || max_concurrency == 0U);
207 state_.DecrementWorkerCount();
208 // Prevent subsequent accesses to user callbacks.
209 state_.Cancel();
210 return false;
211 }
212
WillRunTask()213 TaskSource::RunStatus JobTaskSource::WillRunTask() {
214 CheckedAutoLock auto_lock(worker_lock_);
215 auto state_before_add = state_.Load();
216
217 // Don't allow this worker to run the task if either:
218 // A) |state_| was canceled.
219 // B) |worker_count| is already at |max_concurrency|.
220 // C) |max_concurrency| was lowered below or to |worker_count|.
221 // Case A:
222 if (state_before_add.is_canceled())
223 return RunStatus::kDisallowed;
224
225 const size_t max_concurrency =
226 GetMaxConcurrency(state_before_add.worker_count());
227 if (state_before_add.worker_count() < max_concurrency)
228 state_before_add = state_.IncrementWorkerCount();
229 const size_t worker_count_before_add = state_before_add.worker_count();
230 // Case B) or C):
231 if (worker_count_before_add >= max_concurrency)
232 return RunStatus::kDisallowed;
233
234 DCHECK_LT(worker_count_before_add, max_concurrency);
235 return max_concurrency == worker_count_before_add + 1
236 ? RunStatus::kAllowedSaturated
237 : RunStatus::kAllowedNotSaturated;
238 }
239
GetRemainingConcurrency() const240 size_t JobTaskSource::GetRemainingConcurrency() const {
241 // It is safe to read |state_| without a lock since this variable is atomic,
242 // and no other state is synchronized with GetRemainingConcurrency().
243 const auto state = TS_UNCHECKED_READ(state_).Load();
244 if (state.is_canceled())
245 return 0;
246 const size_t max_concurrency = GetMaxConcurrency(state.worker_count());
247 // Avoid underflows.
248 if (state.worker_count() > max_concurrency)
249 return 0;
250 return max_concurrency - state.worker_count();
251 }
252
IsActive() const253 bool JobTaskSource::IsActive() const {
254 CheckedAutoLock auto_lock(worker_lock_);
255 auto state = state_.Load();
256 return GetMaxConcurrency(state.worker_count()) != 0 ||
257 state.worker_count() != 0;
258 }
259
GetWorkerCount() const260 size_t JobTaskSource::GetWorkerCount() const {
261 return TS_UNCHECKED_READ(state_).Load().worker_count();
262 }
263
NotifyConcurrencyIncrease()264 void JobTaskSource::NotifyConcurrencyIncrease() {
265 // Avoid unnecessary locks when NotifyConcurrencyIncrease() is spuriously
266 // called.
267 if (GetRemainingConcurrency() == 0)
268 return;
269
270 {
271 // Lock is taken to access |join_flag_| below and signal
272 // |worker_released_condition_|.
273 CheckedAutoLock auto_lock(worker_lock_);
274 if (join_flag_.ShouldWorkerSignal())
275 worker_released_condition_->Signal();
276 }
277
278 // Make sure the task source is in the queue if not already.
279 // Caveat: it's possible but unlikely that the task source has already reached
280 // its intended concurrency and doesn't need to be enqueued if there
281 // previously were too many worker. For simplicity, the task source is always
282 // enqueued and will get discarded if already saturated when it is popped from
283 // the priority queue.
284 delegate_->EnqueueJobTaskSource(this);
285 }
286
GetMaxConcurrency() const287 size_t JobTaskSource::GetMaxConcurrency() const {
288 return GetMaxConcurrency(TS_UNCHECKED_READ(state_).Load().worker_count());
289 }
290
GetMaxConcurrency(size_t worker_count) const291 size_t JobTaskSource::GetMaxConcurrency(size_t worker_count) const {
292 return std::min(max_concurrency_callback_.Run(worker_count),
293 kMaxWorkersPerJob);
294 }
295
AcquireTaskId()296 uint8_t JobTaskSource::AcquireTaskId() {
297 static_assert(kMaxWorkersPerJob <= sizeof(assigned_task_ids_) * 8,
298 "TaskId bitfield isn't big enough to fit kMaxWorkersPerJob.");
299 uint32_t assigned_task_ids =
300 assigned_task_ids_.load(std::memory_order_relaxed);
301 uint32_t new_assigned_task_ids = 0;
302 int task_id = 0;
303 // memory_order_acquire on success, matched with memory_order_release in
304 // ReleaseTaskId() so that operations done by previous threads that had
305 // the same task_id become visible to the current thread.
306 do {
307 // Count trailing one bits. This is the id of the right-most 0-bit in
308 // |assigned_task_ids|.
309 task_id = bits::CountTrailingZeroBits(~assigned_task_ids);
310 new_assigned_task_ids = assigned_task_ids | (uint32_t(1) << task_id);
311 } while (!assigned_task_ids_.compare_exchange_weak(
312 assigned_task_ids, new_assigned_task_ids, std::memory_order_acquire,
313 std::memory_order_relaxed));
314 return static_cast<uint8_t>(task_id);
315 }
316
ReleaseTaskId(uint8_t task_id)317 void JobTaskSource::ReleaseTaskId(uint8_t task_id) {
318 // memory_order_release to match AcquireTaskId().
319 uint32_t previous_task_ids = assigned_task_ids_.fetch_and(
320 ~(uint32_t(1) << task_id), std::memory_order_release);
321 DCHECK(previous_task_ids & (uint32_t(1) << task_id));
322 }
323
ShouldYield()324 bool JobTaskSource::ShouldYield() {
325 // It is safe to read |join_flag_| and |state_| without a lock since these
326 // variables are atomic, keeping in mind that threads may not immediately see
327 // the new value when it is updated.
328 return TS_UNCHECKED_READ(join_flag_).ShouldWorkerYield() ||
329 TS_UNCHECKED_READ(state_).Load().is_canceled();
330 }
331
TakeTask(TaskSource::Transaction * transaction)332 Task JobTaskSource::TakeTask(TaskSource::Transaction* transaction) {
333 // JobTaskSource members are not lock-protected so no need to acquire a lock
334 // if |transaction| is nullptr.
335 DCHECK_GT(TS_UNCHECKED_READ(state_).Load().worker_count(), 0U);
336 DCHECK(primary_task_);
337 return Task(from_here_, primary_task_, TimeTicks(), TimeDelta());
338 }
339
DidProcessTask(TaskSource::Transaction *)340 bool JobTaskSource::DidProcessTask(TaskSource::Transaction* /*transaction*/) {
341 // Lock is needed to access |join_flag_| below and signal
342 // |worker_released_condition_|.
343 CheckedAutoLock auto_lock(worker_lock_);
344 const auto state_before_sub = state_.DecrementWorkerCount();
345
346 if (join_flag_.ShouldWorkerSignal())
347 worker_released_condition_->Signal();
348
349 // A canceled task source should never get re-enqueued.
350 if (state_before_sub.is_canceled())
351 return false;
352
353 DCHECK_GT(state_before_sub.worker_count(), 0U);
354
355 // Re-enqueue the TaskSource if the task ran and the worker count is below the
356 // max concurrency.
357 // |worker_count - 1| to exclude the returning thread.
358 return state_before_sub.worker_count() <=
359 GetMaxConcurrency(state_before_sub.worker_count() - 1);
360 }
361
362 // This is a no-op and should always return true.
WillReEnqueue(TimeTicks now,TaskSource::Transaction *)363 bool JobTaskSource::WillReEnqueue(TimeTicks now,
364 TaskSource::Transaction* /*transaction*/) {
365 return true;
366 }
367
368 // This is a no-op.
OnBecomeReady()369 bool JobTaskSource::OnBecomeReady() {
370 return false;
371 }
372
GetSortKey() const373 TaskSourceSortKey JobTaskSource::GetSortKey() const {
374 return TaskSourceSortKey(priority_racy(), ready_time_,
375 TS_UNCHECKED_READ(state_).Load().worker_count());
376 }
377
378 // This function isn't expected to be called since a job is never delayed.
379 // However, the class still needs to provide an override.
GetDelayedSortKey() const380 TimeTicks JobTaskSource::GetDelayedSortKey() const {
381 return TimeTicks();
382 }
383
384 // This function isn't expected to be called since a job is never delayed.
385 // However, the class still needs to provide an override.
HasReadyTasks(TimeTicks now) const386 bool JobTaskSource::HasReadyTasks(TimeTicks now) const {
387 NOTREACHED();
388 return true;
389 }
390
Clear(TaskSource::Transaction * transaction)391 Task JobTaskSource::Clear(TaskSource::Transaction* transaction) {
392 Cancel();
393 // Nothing is cleared since other workers might still racily run tasks. For
394 // simplicity, the destructor will take care of it once all references are
395 // released.
396 return Task(from_here_, DoNothing(), TimeTicks(), TimeDelta());
397 }
398
399 } // namespace internal
400 } // namespace base
401