• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2016 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifndef BASE_TASK_THREAD_POOL_TEST_UTILS_H_
6 #define BASE_TASK_THREAD_POOL_TEST_UTILS_H_
7 
8 #include <atomic>
9 #include <memory>
10 
11 #include "base/functional/callback.h"
12 #include "base/memory/raw_ptr.h"
13 #include "base/task/common/checked_lock.h"
14 #include "base/task/post_job.h"
15 #include "base/task/task_features.h"
16 #include "base/task/task_runner.h"
17 #include "base/task/task_traits.h"
18 #include "base/task/thread_pool/delayed_task_manager.h"
19 #include "base/task/thread_pool/pooled_task_runner_delegate.h"
20 #include "base/task/thread_pool/sequence.h"
21 #include "base/task/thread_pool/task_tracker.h"
22 #include "base/task/thread_pool/thread_group.h"
23 #include "base/task/thread_pool/worker_thread_observer.h"
24 #include "build/build_config.h"
25 #include "testing/gmock/include/gmock/gmock.h"
26 #include "third_party/abseil-cpp/absl/types/variant.h"
27 
28 namespace base {
29 namespace internal {
30 
31 struct Task;
32 
33 namespace test {
34 
35 class MockWorkerThreadObserver : public WorkerThreadObserver {
36  public:
37   MockWorkerThreadObserver();
38   MockWorkerThreadObserver(const MockWorkerThreadObserver&) = delete;
39   MockWorkerThreadObserver& operator=(const MockWorkerThreadObserver&) = delete;
40   ~MockWorkerThreadObserver() override;
41 
42   void AllowCallsOnMainExit(int num_calls);
43   void WaitCallsOnMainExit();
44 
45   // WorkerThreadObserver:
46   MOCK_METHOD0(OnWorkerThreadMainEntry, void());
47   // This doesn't use MOCK_METHOD0 because some tests need to wait for all calls
48   // to happen, which isn't possible with gmock.
49   void OnWorkerThreadMainExit() override;
50 
51  private:
52   CheckedLock lock_;
53   std::unique_ptr<ConditionVariable> on_main_exit_cv_ GUARDED_BY(lock_);
54   int allowed_calls_on_main_exit_ GUARDED_BY(lock_) = 0;
55 };
56 
57 class MockPooledTaskRunnerDelegate : public PooledTaskRunnerDelegate {
58  public:
59   MockPooledTaskRunnerDelegate(TrackedRef<TaskTracker> task_tracker,
60                                DelayedTaskManager* delayed_task_manager);
61   ~MockPooledTaskRunnerDelegate() override;
62 
63   // PooledTaskRunnerDelegate:
64   bool PostTaskWithSequence(Task task,
65                             scoped_refptr<Sequence> sequence) override;
66   bool EnqueueJobTaskSource(scoped_refptr<JobTaskSource> task_source) override;
67   void RemoveJobTaskSource(scoped_refptr<JobTaskSource> task_source) override;
68   bool ShouldYield(const TaskSource* task_source) override;
69   void UpdatePriority(scoped_refptr<TaskSource> task_source,
70                       TaskPriority priority) override;
71   void UpdateJobPriority(scoped_refptr<TaskSource> task_source,
72                          TaskPriority priority) override;
73 
74   void SetThreadGroup(ThreadGroup* thread_group);
75 
76   void PostTaskWithSequenceNow(Task task, scoped_refptr<Sequence> sequence);
77 
78  private:
79   const TrackedRef<TaskTracker> task_tracker_;
80   const raw_ptr<DelayedTaskManager> delayed_task_manager_;
81   raw_ptr<ThreadGroup> thread_group_ = nullptr;
82 };
83 
84 // A simple MockJobTask that will give |worker_task| a fixed number of times,
85 // possibly in parallel.
86 class MockJobTask : public base::RefCountedThreadSafe<MockJobTask> {
87  public:
88   // Gives |worker_task| to requesting workers |num_tasks_to_run| times.
89   MockJobTask(RepeatingCallback<void(JobDelegate*)> worker_task,
90               size_t num_tasks_to_run);
91 
92   // Gives |worker_task| to a single requesting worker.
93   explicit MockJobTask(base::OnceClosure worker_task);
94 
95   MockJobTask(const MockJobTask&) = delete;
96   MockJobTask& operator=(const MockJobTask&) = delete;
97 
98   // Updates the remaining number of time |worker_task| runs to
99   // |num_tasks_to_run|.
100   void SetNumTasksToRun(size_t num_tasks_to_run);
101 
102   size_t GetMaxConcurrency(size_t worker_count) const;
103   void Run(JobDelegate* delegate);
104 
105   scoped_refptr<JobTaskSource> GetJobTaskSource(
106       const Location& from_here,
107       const TaskTraits& traits,
108       PooledTaskRunnerDelegate* delegate);
109 
110  private:
111   friend class base::RefCountedThreadSafe<MockJobTask>;
112 
113   ~MockJobTask();
114 
115   absl::variant<OnceClosure, RepeatingCallback<void(JobDelegate*)>> task_;
116   std::atomic_size_t remaining_num_tasks_to_run_;
117 };
118 
119 // Creates a Sequence with given |traits| and pushes |task| to it. If a
120 // TaskRunner is associated with |task|, it should be be passed as |task_runner|
121 // along with its |execution_mode|. Returns the created Sequence.
122 scoped_refptr<Sequence> CreateSequenceWithTask(
123     Task task,
124     const TaskTraits& traits,
125     scoped_refptr<TaskRunner> task_runner = nullptr,
126     TaskSourceExecutionMode execution_mode =
127         TaskSourceExecutionMode::kParallel);
128 
129 // Creates a TaskRunner that posts tasks to the thread group owned by
130 // |pooled_task_runner_delegate| with the |execution_mode|.
131 // Caveat: this does not support TaskSourceExecutionMode::kSingleThread.
132 scoped_refptr<TaskRunner> CreatePooledTaskRunnerWithExecutionMode(
133     TaskSourceExecutionMode execution_mode,
134     MockPooledTaskRunnerDelegate* mock_pooled_task_runner_delegate,
135     const TaskTraits& traits = {});
136 
137 scoped_refptr<TaskRunner> CreatePooledTaskRunner(
138     const TaskTraits& traits,
139     MockPooledTaskRunnerDelegate* mock_pooled_task_runner_delegate);
140 
141 scoped_refptr<SequencedTaskRunner> CreatePooledSequencedTaskRunner(
142     const TaskTraits& traits,
143     MockPooledTaskRunnerDelegate* mock_pooled_task_runner_delegate);
144 
145 RegisteredTaskSource QueueAndRunTaskSource(
146     TaskTracker* task_tracker,
147     scoped_refptr<TaskSource> task_source);
148 
149 // Calls StartShutdown() and CompleteShutdown() on |task_tracker|.
150 void ShutdownTaskTracker(TaskTracker* task_tracker);
151 
152 }  // namespace test
153 }  // namespace internal
154 }  // namespace base
155 
156 #endif  // BASE_TASK_THREAD_POOL_TEST_UTILS_H_
157