• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2019 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "base/task/thread_pool/job_task_source.h"
6 
7 #include <bit>
8 #include <limits>
9 #include <type_traits>
10 
11 #include "base/check_op.h"
12 #include "base/functional/bind.h"
13 #include "base/functional/callback_helpers.h"
14 #include "base/notreached.h"
15 #include "base/task/common/checked_lock.h"
16 #include "base/task/thread_pool/pooled_task_runner_delegate.h"
17 #include "base/threading/thread_restrictions.h"
18 #include "base/time/time.h"
19 #include "base/trace_event/base_tracing.h"
20 
21 namespace base::internal {
22 
23 namespace {
24 
25 // Capped to allow assigning task_ids from a bitfield.
26 constexpr size_t kMaxWorkersPerJob = 32;
27 static_assert(
28     kMaxWorkersPerJob <=
29         std::numeric_limits<
30             std::invoke_result<decltype(&JobDelegate::GetTaskId),
31                                JobDelegate>::type>::max(),
32     "AcquireTaskId return type isn't big enough to fit kMaxWorkersPerJob");
33 
34 }  // namespace
35 
// Out-of-line defaulted constructor/destructor for the packed atomic state
// word (`value_`) that holds the worker count and the flag bits manipulated by
// the State methods below.
JobTaskSourceNew::State::State() = default;
JobTaskSourceNew::State::~State() = default;
38 
Cancel()39 JobTaskSourceNew::State::Value JobTaskSourceNew::State::Cancel() {
40   return {value_.fetch_or(kCanceledMask, std::memory_order_relaxed)};
41 }
42 
IncrementWorkerCount()43 JobTaskSourceNew::State::Value JobTaskSourceNew::State::IncrementWorkerCount() {
44   uint32_t prev =
45       value_.fetch_add(kWorkerCountIncrement, std::memory_order_relaxed);
46   // The worker count must not overflow a uint8_t.
47   DCHECK((prev >> kWorkerCountBitOffset) < ((1 << 8) - 1));
48   return {prev};
49 }
50 
DecrementWorkerCount()51 JobTaskSourceNew::State::Value JobTaskSourceNew::State::DecrementWorkerCount() {
52   uint32_t prev =
53       value_.fetch_sub(kWorkerCountIncrement, std::memory_order_relaxed);
54   DCHECK((prev >> kWorkerCountBitOffset) > 0);
55   return {prev};
56 }
57 
RequestSignalJoin()58 JobTaskSourceNew::State::Value JobTaskSourceNew::State::RequestSignalJoin() {
59   uint32_t prev = value_.fetch_or(kSignalJoinMask, std::memory_order_relaxed);
60   return {prev};
61 }
62 
FetchAndResetRequestSignalJoin()63 bool JobTaskSourceNew::State::FetchAndResetRequestSignalJoin() {
64   uint32_t prev = value_.fetch_and(~kSignalJoinMask, std::memory_order_relaxed);
65   return !!(prev & kSignalJoinMask);
66 }
67 
ShouldQueueUponCapacityIncrease()68 bool JobTaskSourceNew::State::ShouldQueueUponCapacityIncrease() {
69   // If `WillRunTask()` is running: setting
70   // `kOutsideWillRunTaskOrMustReenqueueMask` ensures that this capacity
71   // increase is taken into account in the returned `RunStatus`.
72   //
73   // If `WillRunTask()` is not running, setting
74   // `kOutsideWillRunTaskOrMustReenqueueMask` is a no-op (already set).
75   //
76   // Release paired with Acquire in `ExitWillRunTask()`, see comment there.
77   Value prev{
78       value_.fetch_or(kQueuedMask | kOutsideWillRunTaskOrMustReenqueueMask,
79                       std::memory_order_release)};
80   return !prev.queued() && prev.outside_will_run_task_or_must_reenqueue();
81 }
82 
EnterWillRunTask()83 JobTaskSourceNew::State::Value JobTaskSourceNew::State::EnterWillRunTask() {
84   Value prev{
85       value_.fetch_and(~(kQueuedMask | kOutsideWillRunTaskOrMustReenqueueMask),
86                        std::memory_order_relaxed)};
87   CHECK(prev.outside_will_run_task_or_must_reenqueue());
88   return {prev};
89 }
90 
// Marks the end of WillRunTask(). Sets `kOutsideWillRunTaskOrMustReenqueueMask`
// again, and also `kQueuedMask` when `saturated` is false.
//
// Returns true if `ShouldQueueUponCapacityIncrease()` or `WillReenqueue()`
// set the bits since `EnterWillRunTask()` cleared them. That signals that
// concurrency may have increased while WillRunTask() was running.
bool JobTaskSourceNew::State::ExitWillRunTask(bool saturated) {
  uint32_t bits_to_set = kOutsideWillRunTaskOrMustReenqueueMask;
  if (!saturated) {
    // If the task source is not saturated, it will be re-enqueued.
    bits_to_set |= kQueuedMask;
  }

  // Acquire paired with Release in `ShouldQueueUponCapacityIncrease()` or
  // `WillReenqueue()` so that anything that runs after clearing
  // `kOutsideWillRunTaskOrMustReenqueueMask` sees max concurrency changes
  // applied before setting it.
  Value prev{value_.fetch_or(bits_to_set, std::memory_order_acquire)};

  // `kQueuedMask` and `kOutsideWillRunTaskOrMustReenqueueMask` were cleared by
  // `EnterWillRunTask()`. Since then, they may have *both* been set by
  // `ShouldQueueUponCapacityIncrease()` or `WillReenqueue()`, which always set
  // them together.
  CHECK_EQ(prev.queued(), prev.outside_will_run_task_or_must_reenqueue());

  return prev.outside_will_run_task_or_must_reenqueue();
}
111 
WillReenqueue()112 bool JobTaskSourceNew::State::WillReenqueue() {
113   // Release paired with Acquire in `ExitWillRunTask()`, see comment there.
114   Value prev{
115       value_.fetch_or(kQueuedMask | kOutsideWillRunTaskOrMustReenqueueMask,
116                       std::memory_order_release)};
117   return prev.outside_will_run_task_or_must_reenqueue();
118 }
119 
Load() const120 JobTaskSourceNew::State::Value JobTaskSourceNew::State::Load() const {
121   return {value_.load(std::memory_order_relaxed)};
122 }
123 
// Creates a job task source.
//   from_here: posting location, stored in `task_metadata_`.
//   worker_task: invoked by each worker with its own JobDelegate (wrapped in
//     `primary_task_`).
//   max_concurrency_callback: queried through GetMaxConcurrency().
//   delegate: must be non-null; used to enqueue this task source
//     (NotifyConcurrencyIncrease()) and handed to each worker's JobDelegate.
JobTaskSourceNew::JobTaskSourceNew(
    const Location& from_here,
    const TaskTraits& traits,
    RepeatingCallback<void(JobDelegate*)> worker_task,
    MaxConcurrencyCallback max_concurrency_callback,
    PooledTaskRunnerDelegate* delegate)
    : JobTaskSource(traits, nullptr, TaskSourceExecutionMode::kJob),
      max_concurrency_callback_(std::move(max_concurrency_callback)),
      worker_task_(std::move(worker_task)),
      // NOTE(review): copies of `primary_task_` are handed out by TakeTask()
      // with `this` bound via Unretained(); this presumably relies on callers
      // keeping the task source alive while running them — confirm.
      primary_task_(base::BindRepeating(
          [](JobTaskSourceNew* self) {
            CheckedLock::AssertNoLockHeldOnCurrentThread();
            // Each worker task has its own delegate with associated state.
            JobDelegate job_delegate{self, self->delegate_};
            self->worker_task_.Run(&job_delegate);
          },
          base::Unretained(this))),
      task_metadata_(from_here),
      ready_time_(TimeTicks::Now()),
      delegate_(delegate) {
  DCHECK(delegate_);
  // -1 marks "not enqueued yet"; WillEnqueue() assigns the real value once.
  task_metadata_.sequence_num = -1;
  // Prevent wait on `join_event_` from triggering a ScopedBlockingCall as this
  // would acquire `ThreadGroup::lock_` and cause lock inversion.
  join_event_.declare_only_used_while_idle();
}
150 
// Destroying the task source is only valid once every worker, including a
// joining thread, has been removed from the worker count.
JobTaskSourceNew::~JobTaskSourceNew() {
  // Make sure there's no outstanding active run operation left.
  DCHECK_EQ(state_.Load().worker_count(), 0U);
}
155 
GetExecutionEnvironment()156 ExecutionEnvironment JobTaskSourceNew::GetExecutionEnvironment() {
157   return {SequenceToken::Create(), nullptr};
158 }
159 
WillEnqueue(int sequence_num,TaskAnnotator & annotator)160 void JobTaskSourceNew::WillEnqueue(int sequence_num, TaskAnnotator& annotator) {
161   if (task_metadata_.sequence_num != -1) {
162     // WillEnqueue() was already called.
163     return;
164   }
165   task_metadata_.sequence_num = sequence_num;
166   annotator.WillQueueTask("ThreadPool_PostJob", &task_metadata_);
167 }
168 
WillJoin()169 bool JobTaskSourceNew::WillJoin() {
170   // Increment worker count to indicate that this thread participates.
171   State::Value state_before_add;
172   {
173     CheckedAutoLock auto_lock(state_.increment_worker_count_lock());
174     state_before_add = state_.IncrementWorkerCount();
175   }
176 
177   // Return when the job is canceled or the (newly incremented) worker count is
178   // below or equal to max concurrency.
179   if (!state_before_add.canceled() &&
180       state_before_add.worker_count() <
181           GetMaxConcurrency(state_before_add.worker_count())) {
182     return true;
183   }
184   return WaitForParticipationOpportunity();
185 }
186 
RunJoinTask()187 bool JobTaskSourceNew::RunJoinTask() {
188   {
189     TRACE_EVENT0("base", "Job.JoinParticipates");
190     JobDelegate job_delegate{this, nullptr};
191     worker_task_.Run(&job_delegate);
192   }
193 
194   const auto state = state_.Load();
195   // The condition is slightly different from the one in WillJoin() since we're
196   // using |state| that was already incremented to include the joining thread.
197   if (!state.canceled() &&
198       state.worker_count() <= GetMaxConcurrency(state.worker_count() - 1)) {
199     return true;
200   }
201 
202   return WaitForParticipationOpportunity();
203 }
204 
Cancel(TaskSource::Transaction * transaction)205 void JobTaskSourceNew::Cancel(TaskSource::Transaction* transaction) {
206   // Sets the kCanceledMask bit on |state_| so that further calls to
207   // WillRunTask() never succeed. std::memory_order_relaxed without a lock is
208   // safe because this task source never needs to be re-enqueued after Cancel().
209   state_.Cancel();
210 }
211 
// Blocks the joining thread until it may participate in the job again.
// Returns true once the worker count (which includes the joining thread) is at
// most max concurrency and the job is not canceled. Returns false when the
// joining thread is the only remaining worker and the job is canceled or has
// max concurrency 0. In that case, before returning, the job is canceled and
// the joining thread is removed from the worker count.
bool JobTaskSourceNew::WaitForParticipationOpportunity() {
  TRACE_EVENT0("base", "Job.WaitForParticipationOpportunity");

  // Wait until either:
  //  A) `worker_count` <= "max concurrency" and state is not canceled.
  //  B) All other workers returned and `worker_count` is 1.
  for (;;) {
    // Request that `join_event_` be signaled on the next relevant change;
    // `state` is the value before the request bit was set.
    auto state = state_.RequestSignalJoin();

    // `worker_count() - 1` excludes the joining thread itself.
    size_t max_concurrency = GetMaxConcurrency(state.worker_count() - 1);

    // Case A:
    if (state.worker_count() <= max_concurrency && !state.canceled()) {
      state_.FetchAndResetRequestSignalJoin();
      return true;
    }

    // Case B:
    // Only the joining thread remains.
    if (state.worker_count() == 1U) {
      DCHECK(state.canceled() || max_concurrency == 0U);
      // WillRunTask() can run concurrently with this. Synchronize with it via a
      // lock to guarantee that the ordering is one of these 2 options:
      // 1. WillRunTask is first. It increments worker count. The condition
      //    below detects that worker count is no longer 1 and we loop again.
      // 2. This runs first. It cancels the job. WillRunTask returns
      //    RunStatus::kDisallowed and doesn't increment the worker count.
      // We definitely don't want this 3rd option (made impossible by the lock):
      // 3. WillRunTask() observes that the job is not canceled. This observes
      //    that the worker count is 1 and returns. JobHandle::Join returns and
      //    its owner deletes state needed by the worker task. WillRunTask()
      //    increments the worker count and the worker task starts running -->
      //    use-after-free.
      CheckedAutoLock auto_lock(state_.increment_worker_count_lock());

      if (state_.Load().worker_count() != 1U) {
        continue;
      }

      state_.Cancel();
      state_.FetchAndResetRequestSignalJoin();
      state_.DecrementWorkerCount();
      return false;
    }

    // Neither condition holds yet: sleep until a worker finishes or
    // concurrency increases, then re-evaluate.
    join_event_.Wait();
  }
}
260 
// Decides whether a pool worker may run the job now. On success the worker
// count is incremented. kAllowedSaturated means this worker took the last slot
// and no capacity increase was observed during this call. kAllowedNotSaturated
// means room remains. kDisallowed means canceled or no room. Runs entirely
// under `increment_worker_count_lock()`.
TaskSource::RunStatus JobTaskSourceNew::WillRunTask() {
  // The lock below prevents a race described in Case B of
  // `WaitForParticipationOpportunity()`.
  CheckedAutoLock auto_lock(state_.increment_worker_count_lock());

  for (;;) {
    auto prev_state = state_.EnterWillRunTask();

    // Don't allow this worker to run the task if either:
    //   A) Job was cancelled.
    //   B) `worker_count` is already at `max_concurrency`.
    //   C) `max_concurrency` was lowered below or to `worker_count`.

    // Case A:
    if (prev_state.canceled()) {
      state_.ExitWillRunTask(/*saturated=*/true);
      return RunStatus::kDisallowed;
    }

    const size_t worker_count_before_increment = prev_state.worker_count();
    const size_t max_concurrency =
        GetMaxConcurrency(worker_count_before_increment);

    if (worker_count_before_increment < max_concurrency) {
      prev_state = state_.IncrementWorkerCount();
      // Worker count may have been decremented since it was read, but not
      // incremented, due to the lock.
      CHECK_LE(prev_state.worker_count(), worker_count_before_increment);
      // Saturated when this worker fills the last available slot.
      bool saturated = max_concurrency == (worker_count_before_increment + 1);
      bool concurrency_increased_during_will_run_task =
          state_.ExitWillRunTask(saturated);

      if (saturated && !concurrency_increased_during_will_run_task) {
        return RunStatus::kAllowedSaturated;
      }

      return RunStatus::kAllowedNotSaturated;
    }

    // Case B or C:
    bool concurrency_increased_during_will_run_task =
        state_.ExitWillRunTask(/*saturated=*/true);
    if (!concurrency_increased_during_will_run_task) {
      return RunStatus::kDisallowed;
    }

    // If concurrency increased during `WillRunTask()`, loop again to
    // re-evaluate the `RunStatus`.
  }
}
311 
GetRemainingConcurrency() const312 size_t JobTaskSourceNew::GetRemainingConcurrency() const {
313   // It is safe to read |state_| without a lock since this variable is atomic,
314   // and no other state is synchronized with GetRemainingConcurrency().
315   const auto state = state_.Load();
316   if (state.canceled()) {
317     return 0;
318   }
319   const size_t max_concurrency = GetMaxConcurrency(state.worker_count());
320   // Avoid underflows.
321   if (state.worker_count() > max_concurrency)
322     return 0;
323   return max_concurrency - state.worker_count();
324 }
325 
IsActive() const326 bool JobTaskSourceNew::IsActive() const {
327   auto state = state_.Load();
328   return GetMaxConcurrency(state.worker_count()) != 0 ||
329          state.worker_count() != 0;
330 }
331 
GetWorkerCount() const332 size_t JobTaskSourceNew::GetWorkerCount() const {
333   return state_.Load().worker_count();
334 }
335 
// Reacts to a possible increase of max concurrency:
//  - wakes the joining thread if it asked to be signaled and there is room for
//    it to participate;
//  - enqueues this task source via the delegate if there is spare capacity
//    and it isn't already queued.
// Returns the result of `EnqueueJobTaskSource()` when enqueuing, true
// otherwise (including when canceled).
bool JobTaskSourceNew::NotifyConcurrencyIncrease() {
  const auto state = state_.Load();

  // No need to signal the joining thread or re-enqueue if canceled.
  if (state.canceled()) {
    return true;
  }

  const auto worker_count = state.worker_count();
  const auto max_concurrency = GetMaxConcurrency(worker_count);

  // Signal the joining thread if there is a request to do so and there is room
  // for the joining thread to participate.
  if (worker_count <= max_concurrency &&
      state_.FetchAndResetRequestSignalJoin()) {
    join_event_.Signal();
  }

  // The job should be queued if the max concurrency isn't reached and it's not
  // already queued.
  if (worker_count < max_concurrency &&
      state_.ShouldQueueUponCapacityIncrease()) {
    return delegate_->EnqueueJobTaskSource(this);
  }

  return true;
}
363 
GetMaxConcurrency() const364 size_t JobTaskSourceNew::GetMaxConcurrency() const {
365   return GetMaxConcurrency(state_.Load().worker_count());
366 }
367 
GetMaxConcurrency(size_t worker_count) const368 size_t JobTaskSourceNew::GetMaxConcurrency(size_t worker_count) const {
369   return std::min(max_concurrency_callback_.Run(worker_count),
370                   kMaxWorkersPerJob);
371 }
372 
AcquireTaskId()373 uint8_t JobTaskSourceNew::AcquireTaskId() {
374   static_assert(kMaxWorkersPerJob <= sizeof(assigned_task_ids_) * 8,
375                 "TaskId bitfield isn't big enough to fit kMaxWorkersPerJob.");
376   uint32_t assigned_task_ids =
377       assigned_task_ids_.load(std::memory_order_relaxed);
378   uint32_t new_assigned_task_ids = 0;
379   int task_id = 0;
380   // memory_order_acquire on success, matched with memory_order_release in
381   // ReleaseTaskId() so that operations done by previous threads that had
382   // the same task_id become visible to the current thread.
383   do {
384     // Count trailing one bits. This is the id of the right-most 0-bit in
385     // |assigned_task_ids|.
386     task_id = std::countr_one(assigned_task_ids);
387     new_assigned_task_ids = assigned_task_ids | (uint32_t(1) << task_id);
388   } while (!assigned_task_ids_.compare_exchange_weak(
389       assigned_task_ids, new_assigned_task_ids, std::memory_order_acquire,
390       std::memory_order_relaxed));
391   return static_cast<uint8_t>(task_id);
392 }
393 
ReleaseTaskId(uint8_t task_id)394 void JobTaskSourceNew::ReleaseTaskId(uint8_t task_id) {
395   // memory_order_release to match AcquireTaskId().
396   uint32_t previous_task_ids = assigned_task_ids_.fetch_and(
397       ~(uint32_t(1) << task_id), std::memory_order_release);
398   DCHECK(previous_task_ids & (uint32_t(1) << task_id));
399 }
400 
ShouldYield()401 bool JobTaskSourceNew::ShouldYield() {
402   // It's safe to read `state_` without a lock because it's atomic, keeping in
403   // mind that threads may not immediately see the new value when it's updated.
404   return state_.Load().canceled();
405 }
406 
GetDelegate() const407 PooledTaskRunnerDelegate* JobTaskSourceNew::GetDelegate() const {
408   return delegate_;
409 }
410 
TakeTask(TaskSource::Transaction * transaction)411 Task JobTaskSourceNew::TakeTask(TaskSource::Transaction* transaction) {
412   // JobTaskSource members are not lock-protected so no need to acquire a lock
413   // if |transaction| is nullptr.
414   DCHECK_GT(state_.Load().worker_count(), 0U);
415   DCHECK(primary_task_);
416   return {task_metadata_, primary_task_};
417 }
418 
DidProcessTask(TaskSource::Transaction *)419 bool JobTaskSourceNew::DidProcessTask(
420     TaskSource::Transaction* /*transaction*/) {
421   auto state = state_.Load();
422   size_t worker_count_excluding_this = state.worker_count() - 1;
423 
424   // Invoke the max concurrency callback before decrementing the worker count,
425   // because as soon as the worker count is decremented, JobHandle::Join() can
426   // return and state needed the callback may be deleted. Also, as an
427   // optimization, avoid invoking the callback if the job is canceled.
428   size_t max_concurrency =
429       state.canceled() ? 0U : GetMaxConcurrency(worker_count_excluding_this);
430 
431   state = state_.DecrementWorkerCount();
432   if (state.signal_join() && state_.FetchAndResetRequestSignalJoin()) {
433     join_event_.Signal();
434   }
435 
436   // A canceled task source should not be re-enqueued.
437   if (state.canceled()) {
438     return false;
439   }
440 
441   // Re-enqueue if there isn't enough concurrency.
442   if (worker_count_excluding_this < max_concurrency) {
443     return state_.WillReenqueue();
444   }
445 
446   return false;
447 }
448 
449 // This is a no-op and should always return true.
WillReEnqueue(TimeTicks now,TaskSource::Transaction *)450 bool JobTaskSourceNew::WillReEnqueue(TimeTicks now,
451                                      TaskSource::Transaction* /*transaction*/) {
452   return true;
453 }
454 
455 // This is a no-op.
OnBecomeReady()456 bool JobTaskSourceNew::OnBecomeReady() {
457   return false;
458 }
459 
GetSortKey() const460 TaskSourceSortKey JobTaskSourceNew::GetSortKey() const {
461   return TaskSourceSortKey(priority_racy(), ready_time_,
462                            state_.Load().worker_count());
463 }
464 
465 // This function isn't expected to be called since a job is never delayed.
466 // However, the class still needs to provide an override.
GetDelayedSortKey() const467 TimeTicks JobTaskSourceNew::GetDelayedSortKey() const {
468   return TimeTicks();
469 }
470 
// This function isn't expected to be called since a job is never delayed.
// However, the class still needs to provide an override.
// NOTE(review): `return true` is kept as a fallback in case NOTREACHED() does
// not terminate in this build configuration — confirm it is still needed.
bool JobTaskSourceNew::HasReadyTasks(TimeTicks now) const {
  NOTREACHED();
  return true;
}
477 
Clear(TaskSource::Transaction * transaction)478 absl::optional<Task> JobTaskSourceNew::Clear(
479     TaskSource::Transaction* transaction) {
480   Cancel();
481 
482   // Nothing is cleared since other workers might still racily run tasks. For
483   // simplicity, the destructor will take care of it once all references are
484   // released.
485   return absl::nullopt;
486 }
487 
488 }  // namespace base::internal
489