1 // Copyright 2017 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "base/task/thread_pool/test_utils.h"
6
7 #include <utility>
8
9 #include "base/check.h"
10 #include "base/debug/leak_annotations.h"
11 #include "base/functional/bind.h"
12 #include "base/functional/overloaded.h"
13 #include "base/memory/raw_ptr.h"
14 #include "base/synchronization/condition_variable.h"
15 #include "base/task/thread_pool/pooled_parallel_task_runner.h"
16 #include "base/task/thread_pool/pooled_sequenced_task_runner.h"
17 #include "base/test/bind.h"
18 #include "base/threading/scoped_blocking_call_internal.h"
19 #include "base/threading/thread_restrictions.h"
20 #include "testing/gtest/include/gtest/gtest.h"
21 #include "third_party/abseil-cpp/absl/types/variant.h"
22
23 namespace base {
24 namespace internal {
25 namespace test {
26
27 namespace {
28
29 // A task runner that posts each task as a MockJobTaskSource that runs a single
30 // task. This is used to run ThreadGroupTests which require a TaskRunner with
31 // kJob execution mode. Delayed tasks are not supported.
32 class MockJobTaskRunner : public TaskRunner {
33 public:
MockJobTaskRunner(const TaskTraits & traits,PooledTaskRunnerDelegate * pooled_task_runner_delegate)34 MockJobTaskRunner(const TaskTraits& traits,
35 PooledTaskRunnerDelegate* pooled_task_runner_delegate)
36 : traits_(traits),
37 pooled_task_runner_delegate_(pooled_task_runner_delegate) {}
38
39 MockJobTaskRunner(const MockJobTaskRunner&) = delete;
40 MockJobTaskRunner& operator=(const MockJobTaskRunner&) = delete;
41
42 // TaskRunner:
43 bool PostDelayedTask(const Location& from_here,
44 OnceClosure closure,
45 TimeDelta delay) override;
46
47 private:
48 ~MockJobTaskRunner() override = default;
49
50 const TaskTraits traits_;
51 const raw_ptr<PooledTaskRunnerDelegate> pooled_task_runner_delegate_;
52 };
53
PostDelayedTask(const Location & from_here,OnceClosure closure,TimeDelta delay)54 bool MockJobTaskRunner::PostDelayedTask(const Location& from_here,
55 OnceClosure closure,
56 TimeDelta delay) {
57 DCHECK_EQ(delay, TimeDelta()); // Jobs doesn't support delayed tasks.
58
59 if (!PooledTaskRunnerDelegate::MatchesCurrentDelegate(
60 pooled_task_runner_delegate_)) {
61 return false;
62 }
63
64 auto job_task = base::MakeRefCounted<MockJobTask>(std::move(closure));
65 scoped_refptr<JobTaskSource> task_source = job_task->GetJobTaskSource(
66 from_here, traits_, pooled_task_runner_delegate_);
67 return task_source->NotifyConcurrencyIncrease();
68 }
69
CreateJobTaskRunner(const TaskTraits & traits,MockPooledTaskRunnerDelegate * mock_pooled_task_runner_delegate)70 scoped_refptr<TaskRunner> CreateJobTaskRunner(
71 const TaskTraits& traits,
72 MockPooledTaskRunnerDelegate* mock_pooled_task_runner_delegate) {
73 return MakeRefCounted<MockJobTaskRunner>(traits,
74 mock_pooled_task_runner_delegate);
75 }
76
77 } // namespace
78
MockWorkerThreadObserver()79 MockWorkerThreadObserver::MockWorkerThreadObserver()
80 : on_main_exit_cv_(lock_.CreateConditionVariable()) {}
81
~MockWorkerThreadObserver()82 MockWorkerThreadObserver::~MockWorkerThreadObserver() {
83 WaitCallsOnMainExit();
84 }
85
AllowCallsOnMainExit(int num_calls)86 void MockWorkerThreadObserver::AllowCallsOnMainExit(int num_calls) {
87 CheckedAutoLock auto_lock(lock_);
88 EXPECT_EQ(0, allowed_calls_on_main_exit_);
89 allowed_calls_on_main_exit_ = num_calls;
90 }
91
WaitCallsOnMainExit()92 void MockWorkerThreadObserver::WaitCallsOnMainExit() {
93 CheckedAutoLock auto_lock(lock_);
94 while (allowed_calls_on_main_exit_ != 0)
95 on_main_exit_cv_->Wait();
96 }
97
OnWorkerThreadMainExit()98 void MockWorkerThreadObserver::OnWorkerThreadMainExit() {
99 CheckedAutoLock auto_lock(lock_);
100 EXPECT_GE(allowed_calls_on_main_exit_, 0);
101 --allowed_calls_on_main_exit_;
102 if (allowed_calls_on_main_exit_ == 0)
103 on_main_exit_cv_->Signal();
104 }
105
CreateSequenceWithTask(Task task,const TaskTraits & traits,scoped_refptr<TaskRunner> task_runner,TaskSourceExecutionMode execution_mode)106 scoped_refptr<Sequence> CreateSequenceWithTask(
107 Task task,
108 const TaskTraits& traits,
109 scoped_refptr<TaskRunner> task_runner,
110 TaskSourceExecutionMode execution_mode) {
111 scoped_refptr<Sequence> sequence =
112 MakeRefCounted<Sequence>(traits, task_runner.get(), execution_mode);
113 auto transaction = sequence->BeginTransaction();
114 transaction.WillPushImmediateTask();
115 transaction.PushImmediateTask(std::move(task));
116 return sequence;
117 }
118
CreatePooledTaskRunnerWithExecutionMode(TaskSourceExecutionMode execution_mode,MockPooledTaskRunnerDelegate * mock_pooled_task_runner_delegate,const TaskTraits & traits)119 scoped_refptr<TaskRunner> CreatePooledTaskRunnerWithExecutionMode(
120 TaskSourceExecutionMode execution_mode,
121 MockPooledTaskRunnerDelegate* mock_pooled_task_runner_delegate,
122 const TaskTraits& traits) {
123 switch (execution_mode) {
124 case TaskSourceExecutionMode::kParallel:
125 return CreatePooledTaskRunner(traits, mock_pooled_task_runner_delegate);
126 case TaskSourceExecutionMode::kSequenced:
127 return CreatePooledSequencedTaskRunner(traits,
128 mock_pooled_task_runner_delegate);
129 case TaskSourceExecutionMode::kJob:
130 return CreateJobTaskRunner(traits, mock_pooled_task_runner_delegate);
131 default:
132 // Fall through.
133 break;
134 }
135 ADD_FAILURE() << "Unexpected ExecutionMode";
136 return nullptr;
137 }
138
CreatePooledTaskRunner(const TaskTraits & traits,MockPooledTaskRunnerDelegate * mock_pooled_task_runner_delegate)139 scoped_refptr<TaskRunner> CreatePooledTaskRunner(
140 const TaskTraits& traits,
141 MockPooledTaskRunnerDelegate* mock_pooled_task_runner_delegate) {
142 return MakeRefCounted<PooledParallelTaskRunner>(
143 traits, mock_pooled_task_runner_delegate);
144 }
145
CreatePooledSequencedTaskRunner(const TaskTraits & traits,MockPooledTaskRunnerDelegate * mock_pooled_task_runner_delegate)146 scoped_refptr<SequencedTaskRunner> CreatePooledSequencedTaskRunner(
147 const TaskTraits& traits,
148 MockPooledTaskRunnerDelegate* mock_pooled_task_runner_delegate) {
149 return MakeRefCounted<PooledSequencedTaskRunner>(
150 traits, mock_pooled_task_runner_delegate);
151 }
152
MockPooledTaskRunnerDelegate(TrackedRef<TaskTracker> task_tracker,DelayedTaskManager * delayed_task_manager)153 MockPooledTaskRunnerDelegate::MockPooledTaskRunnerDelegate(
154 TrackedRef<TaskTracker> task_tracker,
155 DelayedTaskManager* delayed_task_manager)
156 : task_tracker_(task_tracker),
157 delayed_task_manager_(delayed_task_manager) {}
158
// Defined out of line; no explicit cleanup beyond member destruction.
MockPooledTaskRunnerDelegate::~MockPooledTaskRunnerDelegate() = default;
160
PostTaskWithSequence(Task task,scoped_refptr<Sequence> sequence)161 bool MockPooledTaskRunnerDelegate::PostTaskWithSequence(
162 Task task,
163 scoped_refptr<Sequence> sequence) {
164 // |thread_group_| must be initialized with SetThreadGroup() before
165 // proceeding.
166 DCHECK(thread_group_);
167 DCHECK(task.task);
168 DCHECK(sequence);
169
170 if (!task_tracker_->WillPostTask(&task, sequence->shutdown_behavior())) {
171 // `task`'s destructor may run sequence-affine code, so it must be leaked
172 // when `WillPostTask` returns false.
173 auto leak = std::make_unique<Task>(std::move(task));
174 ANNOTATE_LEAKING_OBJECT_PTR(leak.get());
175 leak.release();
176 return false;
177 }
178
179 if (task.delayed_run_time.is_null()) {
180 PostTaskWithSequenceNow(std::move(task), std::move(sequence));
181 } else {
182 // It's safe to take a ref on this pointer since the caller must have a ref
183 // to the TaskRunner in order to post.
184 scoped_refptr<TaskRunner> task_runner = sequence->task_runner();
185 delayed_task_manager_->AddDelayedTask(
186 std::move(task),
187 BindOnce(
188 [](scoped_refptr<Sequence> sequence,
189 MockPooledTaskRunnerDelegate* self, Task task) {
190 self->PostTaskWithSequenceNow(std::move(task),
191 std::move(sequence));
192 },
193 std::move(sequence), Unretained(this)),
194 std::move(task_runner));
195 }
196
197 return true;
198 }
199
PostTaskWithSequenceNow(Task task,scoped_refptr<Sequence> sequence)200 void MockPooledTaskRunnerDelegate::PostTaskWithSequenceNow(
201 Task task,
202 scoped_refptr<Sequence> sequence) {
203 auto transaction = sequence->BeginTransaction();
204 const bool sequence_should_be_queued = transaction.WillPushImmediateTask();
205 RegisteredTaskSource task_source;
206 if (sequence_should_be_queued) {
207 task_source = task_tracker_->RegisterTaskSource(std::move(sequence));
208 // We shouldn't push |task| if we're not allowed to queue |task_source|.
209 if (!task_source)
210 return;
211 }
212 transaction.PushImmediateTask(std::move(task));
213 if (task_source) {
214 thread_group_->PushTaskSourceAndWakeUpWorkers(
215 {std::move(task_source), std::move(transaction)});
216 }
217 }
218
ShouldYield(const TaskSource * task_source)219 bool MockPooledTaskRunnerDelegate::ShouldYield(const TaskSource* task_source) {
220 return thread_group_->ShouldYield(task_source->GetSortKey());
221 }
222
EnqueueJobTaskSource(scoped_refptr<JobTaskSource> task_source)223 bool MockPooledTaskRunnerDelegate::EnqueueJobTaskSource(
224 scoped_refptr<JobTaskSource> task_source) {
225 // |thread_group_| must be initialized with SetThreadGroup() before
226 // proceeding.
227 DCHECK(thread_group_);
228 DCHECK(task_source);
229
230 auto registered_task_source =
231 task_tracker_->RegisterTaskSource(std::move(task_source));
232 if (!registered_task_source)
233 return false;
234 auto transaction = registered_task_source->BeginTransaction();
235 thread_group_->PushTaskSourceAndWakeUpWorkers(
236 {std::move(registered_task_source), std::move(transaction)});
237 return true;
238 }
239
RemoveJobTaskSource(scoped_refptr<JobTaskSource> task_source)240 void MockPooledTaskRunnerDelegate::RemoveJobTaskSource(
241 scoped_refptr<JobTaskSource> task_source) {
242 thread_group_->RemoveTaskSource(*task_source);
243 }
244
UpdatePriority(scoped_refptr<TaskSource> task_source,TaskPriority priority)245 void MockPooledTaskRunnerDelegate::UpdatePriority(
246 scoped_refptr<TaskSource> task_source,
247 TaskPriority priority) {
248 auto transaction = task_source->BeginTransaction();
249 transaction.UpdatePriority(priority);
250 thread_group_->UpdateSortKey(std::move(transaction));
251 }
252
UpdateJobPriority(scoped_refptr<TaskSource> task_source,TaskPriority priority)253 void MockPooledTaskRunnerDelegate::UpdateJobPriority(
254 scoped_refptr<TaskSource> task_source,
255 TaskPriority priority) {
256 UpdatePriority(std::move(task_source), priority);
257 }
258
SetThreadGroup(ThreadGroup * thread_group)259 void MockPooledTaskRunnerDelegate::SetThreadGroup(ThreadGroup* thread_group) {
260 thread_group_ = thread_group;
261 }
262
// Defined out of line; no explicit cleanup beyond member destruction.
MockJobTask::~MockJobTask() = default;
264
MockJobTask(base::RepeatingCallback<void (JobDelegate *)> worker_task,size_t num_tasks_to_run)265 MockJobTask::MockJobTask(
266 base::RepeatingCallback<void(JobDelegate*)> worker_task,
267 size_t num_tasks_to_run)
268 : task_(std::move(worker_task)),
269 remaining_num_tasks_to_run_(num_tasks_to_run) {
270 CHECK(!absl::get<decltype(worker_task)>(task_).is_null());
271 }
272
MockJobTask(base::OnceClosure worker_task)273 MockJobTask::MockJobTask(base::OnceClosure worker_task)
274 : task_(std::move(worker_task)), remaining_num_tasks_to_run_(1) {
275 CHECK(!absl::get<decltype(worker_task)>(task_).is_null());
276 }
277
SetNumTasksToRun(size_t num_tasks_to_run)278 void MockJobTask::SetNumTasksToRun(size_t num_tasks_to_run) {
279 if (num_tasks_to_run == 0) {
280 remaining_num_tasks_to_run_ = 0;
281 return;
282 }
283 if (auto* closure = absl::get_if<base::OnceClosure>(&task_); closure) {
284 // 0 is already handled above, so this can only be an attempt to set to
285 // a non-zero value for a OnceClosure. In that case, the only permissible
286 // value is 1, and the closure must not be null.
287 //
288 // Note that there is no need to check `!is_null()` for repeating callbacks,
289 // since `Run(JobDelegate*)` never consumes the repeating callback variant.
290 CHECK(!closure->is_null());
291 CHECK_EQ(1u, num_tasks_to_run);
292 }
293 remaining_num_tasks_to_run_ = num_tasks_to_run;
294 }
295
GetMaxConcurrency(size_t) const296 size_t MockJobTask::GetMaxConcurrency(size_t /* worker_count */) const {
297 return remaining_num_tasks_to_run_.load();
298 }
299
Run(JobDelegate * delegate)300 void MockJobTask::Run(JobDelegate* delegate) {
301 absl::visit(
302 base::Overloaded{
303 [](OnceClosure& closure) { std::move(closure).Run(); },
304 [delegate](const RepeatingCallback<void(JobDelegate*)>& callback) {
305 callback.Run(delegate);
306 }},
307 task_);
308 CHECK_GT(remaining_num_tasks_to_run_.fetch_sub(1), 0u);
309 }
310
GetJobTaskSource(const Location & from_here,const TaskTraits & traits,PooledTaskRunnerDelegate * delegate)311 scoped_refptr<JobTaskSource> MockJobTask::GetJobTaskSource(
312 const Location& from_here,
313 const TaskTraits& traits,
314 PooledTaskRunnerDelegate* delegate) {
315 return CreateJobTaskSource(
316 from_here, traits, base::BindRepeating(&test::MockJobTask::Run, this),
317 base::BindRepeating(&test::MockJobTask::GetMaxConcurrency, this),
318 delegate);
319 }
320
QueueAndRunTaskSource(TaskTracker * task_tracker,scoped_refptr<TaskSource> task_source)321 RegisteredTaskSource QueueAndRunTaskSource(
322 TaskTracker* task_tracker,
323 scoped_refptr<TaskSource> task_source) {
324 auto registered_task_source =
325 task_tracker->RegisterTaskSource(std::move(task_source));
326 EXPECT_TRUE(registered_task_source);
327 EXPECT_NE(registered_task_source.WillRunTask(),
328 TaskSource::RunStatus::kDisallowed);
329 return task_tracker->RunAndPopNextTask(std::move(registered_task_source));
330 }
331
ShutdownTaskTracker(TaskTracker * task_tracker)332 void ShutdownTaskTracker(TaskTracker* task_tracker) {
333 task_tracker->StartShutdown();
334 task_tracker->CompleteShutdown();
335 }
336
337 } // namespace test
338 } // namespace internal
339 } // namespace base
340