• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 Google LLC. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "ruy/thread_pool.h"
17 
18 #include <atomic>
19 #include <chrono>              // NOLINT(build/c++11)
20 #include <condition_variable>  // NOLINT(build/c++11)
21 #include <cstdint>
22 #include <cstdlib>
23 #include <memory>
24 #include <mutex>   // NOLINT(build/c++11)
25 #include <thread>  // NOLINT(build/c++11)
26 
27 #include "ruy/check_macros.h"
28 #include "ruy/trace.h"
29 #include "ruy/wait.h"
30 
31 namespace ruy {
32 
33 // A worker thread.
34 class Thread {
35  public:
36   enum class State {
37     Startup,  // The initial state before the thread main loop runs.
38     Ready,    // Is not working, has not yet received new work to do.
39     HasWork,  // Has work to do.
40     ExitAsSoonAsPossible  // Should exit at earliest convenience.
41   };
42 
Thread(BlockingCounter * counter_to_decrement_when_ready,Duration spin_duration)43   explicit Thread(BlockingCounter* counter_to_decrement_when_ready,
44                   Duration spin_duration)
45       : task_(nullptr),
46         state_(State::Startup),
47         counter_to_decrement_when_ready_(counter_to_decrement_when_ready),
48         spin_duration_(spin_duration) {
49     thread_.reset(new std::thread(ThreadFunc, this));
50   }
51 
~Thread()52   ~Thread() {
53     ChangeState(State::ExitAsSoonAsPossible);
54     thread_->join();
55   }
56 
57   // Changes State; may be called from either the worker thread
58   // or the master thread; however, not all state transitions are legal,
59   // which is guarded by assertions.
60   //
61   // The Task argument is to be used only with new_state==HasWork.
62   // It specifies the Task being handed to this Thread.
ChangeState(State new_state,Task * task=nullptr)63   void ChangeState(State new_state, Task* task = nullptr) {
64     state_mutex_.lock();
65     State old_state = state_.load(std::memory_order_relaxed);
66     RUY_DCHECK_NE(old_state, new_state);
67     switch (old_state) {
68       case State::Startup:
69         RUY_DCHECK_EQ(new_state, State::Ready);
70         break;
71       case State::Ready:
72         RUY_DCHECK(new_state == State::HasWork ||
73                    new_state == State::ExitAsSoonAsPossible);
74         break;
75       case State::HasWork:
76         RUY_DCHECK(new_state == State::Ready ||
77                    new_state == State::ExitAsSoonAsPossible);
78         break;
79       default:
80         abort();
81     }
82     switch (new_state) {
83       case State::Ready:
84         if (task_) {
85           // Doing work is part of reverting to 'ready' state.
86           task_->Run();
87           task_ = nullptr;
88         }
89         break;
90       case State::HasWork:
91         RUY_DCHECK(!task_);
92         task_ = task;
93         break;
94       default:
95         break;
96     }
97     state_.store(new_state, std::memory_order_relaxed);
98     state_cond_.notify_all();
99     state_mutex_.unlock();
100     if (new_state == State::Ready) {
101       counter_to_decrement_when_ready_->DecrementCount();
102     }
103   }
104 
ThreadFunc(Thread * arg)105   static void ThreadFunc(Thread* arg) { arg->ThreadFuncImpl(); }
106 
107   // Called by the master thead to give this thread work to do.
StartWork(Task * task)108   void StartWork(Task* task) { ChangeState(State::HasWork, task); }
109 
110  private:
111   // Thread entry point.
ThreadFuncImpl()112   void ThreadFuncImpl() {
113     RUY_TRACE_SCOPE_NAME("Ruy worker thread function");
114     ChangeState(State::Ready);
115 
116     // Thread main loop
117     while (true) {
118       RUY_TRACE_SCOPE_NAME("Ruy worker thread loop iteration");
119       // In the 'Ready' state, we have nothing to do but to wait until
120       // we switch to another state.
121       const auto& condition = [this]() {
122         return state_.load(std::memory_order_acquire) != State::Ready;
123       };
124       RUY_TRACE_INFO(THREAD_FUNC_IMPL_WAITING);
125       Wait(condition, spin_duration_, &state_cond_, &state_mutex_);
126 
127       // Act on new state.
128       switch (state_.load(std::memory_order_acquire)) {
129         case State::HasWork: {
130           RUY_TRACE_SCOPE_NAME("Worker thread task");
131           // Got work to do! So do it, and then revert to 'Ready' state.
132           ChangeState(State::Ready);
133           break;
134         }
135         case State::ExitAsSoonAsPossible:
136           return;
137         default:
138           abort();
139       }
140     }
141   }
142 
143   // The underlying thread.
144   std::unique_ptr<std::thread> thread_;
145 
146   // The task to be worked on.
147   Task* task_;
148 
149   // The condition variable and mutex guarding state changes.
150   std::condition_variable state_cond_;
151   std::mutex state_mutex_;
152 
153   // The state enum tells if we're currently working, waiting for work, etc.
154   // Its concurrent accesses by the thread and main threads are guarded by
155   // state_mutex_, and can thus use memory_order_relaxed. This still needs
156   // to be a std::atomic because we use WaitForVariableChange.
157   std::atomic<State> state_;
158 
159   // pointer to the master's thread BlockingCounter object, to notify the
160   // master thread of when this thread switches to the 'Ready' state.
161   BlockingCounter* const counter_to_decrement_when_ready_;
162 
163   // See ThreadPool::spin_duration_.
164   const Duration spin_duration_;
165 };
166 
ExecuteImpl(int task_count,int stride,Task * tasks)167 void ThreadPool::ExecuteImpl(int task_count, int stride, Task* tasks) {
168   RUY_TRACE_SCOPE_NAME("ThreadPool::Execute");
169   RUY_DCHECK_GE(task_count, 1);
170 
171   // Case of 1 thread: just run the single task on the current thread.
172   if (task_count == 1) {
173     (tasks + 0)->Run();
174     return;
175   }
176 
177   // Task #0 will be run on the current thread.
178   CreateThreads(task_count - 1);
179   counter_to_decrement_when_ready_.Reset(task_count - 1);
180   for (int i = 1; i < task_count; i++) {
181     RUY_TRACE_INFO(THREADPOOL_EXECUTE_STARTING_TASK);
182     auto task_address = reinterpret_cast<std::uintptr_t>(tasks) + i * stride;
183     threads_[i - 1]->StartWork(reinterpret_cast<Task*>(task_address));
184   }
185 
186   RUY_TRACE_INFO(THREADPOOL_EXECUTE_STARTING_TASK_ZERO_ON_CUR_THREAD);
187   // Execute task #0 immediately on the current thread.
188   (tasks + 0)->Run();
189 
190   RUY_TRACE_INFO(THREADPOOL_EXECUTE_WAITING_FOR_THREADS);
191   // Wait for the threads submitted above to finish.
192   counter_to_decrement_when_ready_.Wait(spin_duration_);
193 }
194 
195 // Ensures that the pool has at least the given count of threads.
196 // If any new thread has to be created, this function waits for it to
197 // be ready.
CreateThreads(int threads_count)198 void ThreadPool::CreateThreads(int threads_count) {
199   RUY_DCHECK_GE(threads_count, 0);
200   unsigned int unsigned_threads_count = threads_count;
201   if (threads_.size() >= unsigned_threads_count) {
202     return;
203   }
204   counter_to_decrement_when_ready_.Reset(threads_count - threads_.size());
205   while (threads_.size() < unsigned_threads_count) {
206     threads_.push_back(
207         new Thread(&counter_to_decrement_when_ready_, spin_duration_));
208   }
209   counter_to_decrement_when_ready_.Wait(spin_duration_);
210 }
211 
~ThreadPool()212 ThreadPool::~ThreadPool() {
213   for (auto w : threads_) {
214     delete w;
215   }
216 }
217 
218 }  // end namespace ruy
219