1 /* Copyright 2019 Google LLC. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "ruy/thread_pool.h"
17
18 #include <atomic>
19 #include <chrono> // NOLINT(build/c++11)
20 #include <condition_variable> // NOLINT(build/c++11)
21 #include <cstdint>
22 #include <cstdlib>
23 #include <memory>
24 #include <mutex> // NOLINT(build/c++11)
25 #include <thread> // NOLINT(build/c++11)
26
27 #include "ruy/check_macros.h"
28 #include "ruy/trace.h"
29 #include "ruy/wait.h"
30
31 namespace ruy {
32
33 // A worker thread.
34 class Thread {
35 public:
36 enum class State {
37 Startup, // The initial state before the thread main loop runs.
38 Ready, // Is not working, has not yet received new work to do.
39 HasWork, // Has work to do.
40 ExitAsSoonAsPossible // Should exit at earliest convenience.
41 };
42
Thread(BlockingCounter * counter_to_decrement_when_ready,Duration spin_duration)43 explicit Thread(BlockingCounter* counter_to_decrement_when_ready,
44 Duration spin_duration)
45 : task_(nullptr),
46 state_(State::Startup),
47 counter_to_decrement_when_ready_(counter_to_decrement_when_ready),
48 spin_duration_(spin_duration) {
49 thread_.reset(new std::thread(ThreadFunc, this));
50 }
51
~Thread()52 ~Thread() {
53 ChangeState(State::ExitAsSoonAsPossible);
54 thread_->join();
55 }
56
57 // Changes State; may be called from either the worker thread
58 // or the master thread; however, not all state transitions are legal,
59 // which is guarded by assertions.
60 //
61 // The Task argument is to be used only with new_state==HasWork.
62 // It specifies the Task being handed to this Thread.
ChangeState(State new_state,Task * task=nullptr)63 void ChangeState(State new_state, Task* task = nullptr) {
64 state_mutex_.lock();
65 State old_state = state_.load(std::memory_order_relaxed);
66 RUY_DCHECK_NE(old_state, new_state);
67 switch (old_state) {
68 case State::Startup:
69 RUY_DCHECK_EQ(new_state, State::Ready);
70 break;
71 case State::Ready:
72 RUY_DCHECK(new_state == State::HasWork ||
73 new_state == State::ExitAsSoonAsPossible);
74 break;
75 case State::HasWork:
76 RUY_DCHECK(new_state == State::Ready ||
77 new_state == State::ExitAsSoonAsPossible);
78 break;
79 default:
80 abort();
81 }
82 switch (new_state) {
83 case State::Ready:
84 if (task_) {
85 // Doing work is part of reverting to 'ready' state.
86 task_->Run();
87 task_ = nullptr;
88 }
89 break;
90 case State::HasWork:
91 RUY_DCHECK(!task_);
92 task_ = task;
93 break;
94 default:
95 break;
96 }
97 state_.store(new_state, std::memory_order_relaxed);
98 state_cond_.notify_all();
99 state_mutex_.unlock();
100 if (new_state == State::Ready) {
101 counter_to_decrement_when_ready_->DecrementCount();
102 }
103 }
104
ThreadFunc(Thread * arg)105 static void ThreadFunc(Thread* arg) { arg->ThreadFuncImpl(); }
106
107 // Called by the master thead to give this thread work to do.
StartWork(Task * task)108 void StartWork(Task* task) { ChangeState(State::HasWork, task); }
109
110 private:
111 // Thread entry point.
ThreadFuncImpl()112 void ThreadFuncImpl() {
113 RUY_TRACE_SCOPE_NAME("Ruy worker thread function");
114 ChangeState(State::Ready);
115
116 // Thread main loop
117 while (true) {
118 RUY_TRACE_SCOPE_NAME("Ruy worker thread loop iteration");
119 // In the 'Ready' state, we have nothing to do but to wait until
120 // we switch to another state.
121 const auto& condition = [this]() {
122 return state_.load(std::memory_order_acquire) != State::Ready;
123 };
124 RUY_TRACE_INFO(THREAD_FUNC_IMPL_WAITING);
125 Wait(condition, spin_duration_, &state_cond_, &state_mutex_);
126
127 // Act on new state.
128 switch (state_.load(std::memory_order_acquire)) {
129 case State::HasWork: {
130 RUY_TRACE_SCOPE_NAME("Worker thread task");
131 // Got work to do! So do it, and then revert to 'Ready' state.
132 ChangeState(State::Ready);
133 break;
134 }
135 case State::ExitAsSoonAsPossible:
136 return;
137 default:
138 abort();
139 }
140 }
141 }
142
143 // The underlying thread.
144 std::unique_ptr<std::thread> thread_;
145
146 // The task to be worked on.
147 Task* task_;
148
149 // The condition variable and mutex guarding state changes.
150 std::condition_variable state_cond_;
151 std::mutex state_mutex_;
152
153 // The state enum tells if we're currently working, waiting for work, etc.
154 // Its concurrent accesses by the thread and main threads are guarded by
155 // state_mutex_, and can thus use memory_order_relaxed. This still needs
156 // to be a std::atomic because we use WaitForVariableChange.
157 std::atomic<State> state_;
158
159 // pointer to the master's thread BlockingCounter object, to notify the
160 // master thread of when this thread switches to the 'Ready' state.
161 BlockingCounter* const counter_to_decrement_when_ready_;
162
163 // See ThreadPool::spin_duration_.
164 const Duration spin_duration_;
165 };
166
ExecuteImpl(int task_count,int stride,Task * tasks)167 void ThreadPool::ExecuteImpl(int task_count, int stride, Task* tasks) {
168 RUY_TRACE_SCOPE_NAME("ThreadPool::Execute");
169 RUY_DCHECK_GE(task_count, 1);
170
171 // Case of 1 thread: just run the single task on the current thread.
172 if (task_count == 1) {
173 (tasks + 0)->Run();
174 return;
175 }
176
177 // Task #0 will be run on the current thread.
178 CreateThreads(task_count - 1);
179 counter_to_decrement_when_ready_.Reset(task_count - 1);
180 for (int i = 1; i < task_count; i++) {
181 RUY_TRACE_INFO(THREADPOOL_EXECUTE_STARTING_TASK);
182 auto task_address = reinterpret_cast<std::uintptr_t>(tasks) + i * stride;
183 threads_[i - 1]->StartWork(reinterpret_cast<Task*>(task_address));
184 }
185
186 RUY_TRACE_INFO(THREADPOOL_EXECUTE_STARTING_TASK_ZERO_ON_CUR_THREAD);
187 // Execute task #0 immediately on the current thread.
188 (tasks + 0)->Run();
189
190 RUY_TRACE_INFO(THREADPOOL_EXECUTE_WAITING_FOR_THREADS);
191 // Wait for the threads submitted above to finish.
192 counter_to_decrement_when_ready_.Wait(spin_duration_);
193 }
194
195 // Ensures that the pool has at least the given count of threads.
196 // If any new thread has to be created, this function waits for it to
197 // be ready.
CreateThreads(int threads_count)198 void ThreadPool::CreateThreads(int threads_count) {
199 RUY_DCHECK_GE(threads_count, 0);
200 unsigned int unsigned_threads_count = threads_count;
201 if (threads_.size() >= unsigned_threads_count) {
202 return;
203 }
204 counter_to_decrement_when_ready_.Reset(threads_count - threads_.size());
205 while (threads_.size() < unsigned_threads_count) {
206 threads_.push_back(
207 new Thread(&counter_to_decrement_when_ready_, spin_duration_));
208 }
209 counter_to_decrement_when_ready_.Wait(spin_duration_);
210 }
211
~ThreadPool()212 ThreadPool::~ThreadPool() {
213 for (auto w : threads_) {
214 delete w;
215 }
216 }
217
218 } // end namespace ruy
219