1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/platform/threadpool.h"

#define EIGEN_USE_THREADS

#include <limits>

#include "absl/types/optional.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/numa.h"
#include "tensorflow/core/platform/setround.h"
#include "tensorflow/core/platform/tracing.h"
30
31 namespace tensorflow {
32 namespace thread {
33
34 struct EigenEnvironment {
35 typedef Thread EnvThread;
36 struct TaskImpl {
37 std::function<void()> f;
38 Context context;
39 uint64 trace_id;
40 };
41 struct Task {
42 std::unique_ptr<TaskImpl> f;
43 };
44
45 Env* const env_;
46 const ThreadOptions thread_options_;
47 const string name_;
48
EigenEnvironmenttensorflow::thread::EigenEnvironment49 EigenEnvironment(Env* env, const ThreadOptions& thread_options,
50 const string& name)
51 : env_(env), thread_options_(thread_options), name_(name) {}
52
CreateThreadtensorflow::thread::EigenEnvironment53 EnvThread* CreateThread(std::function<void()> f) {
54 return env_->StartThread(thread_options_, name_, [=]() {
55 // Set the processor flag to flush denormals to zero.
56 port::ScopedFlushDenormal flush;
57 // Set the processor rounding mode to ROUND TO NEAREST.
58 port::ScopedSetRound round(FE_TONEAREST);
59 if (thread_options_.numa_node != port::kNUMANoAffinity) {
60 port::NUMASetThreadNodeAffinity(thread_options_.numa_node);
61 }
62 f();
63 });
64 }
65
CreateTasktensorflow::thread::EigenEnvironment66 Task CreateTask(std::function<void()> f) {
67 uint64 id = 0;
68 if (tracing::EventCollector::IsEnabled()) {
69 id = tracing::GetUniqueArg();
70 tracing::RecordEvent(tracing::EventCategory::kScheduleClosure, id);
71 }
72 return Task{
73 std::unique_ptr<TaskImpl>(new TaskImpl{
74 std::move(f),
75 Context(ContextKind::kThread),
76 id,
77 }),
78 };
79 }
80
ExecuteTasktensorflow::thread::EigenEnvironment81 void ExecuteTask(const Task& t) {
82 WithContext wc(t.f->context);
83 tracing::ScopedRegion region(tracing::EventCategory::kRunClosure,
84 t.f->trace_id);
85 t.f->f();
86 }
87 };
88
ThreadPool(Env * env,const string & name,int num_threads)89 ThreadPool::ThreadPool(Env* env, const string& name, int num_threads)
90 : ThreadPool(env, ThreadOptions(), name, num_threads, true, nullptr) {}
91
ThreadPool(Env * env,const ThreadOptions & thread_options,const string & name,int num_threads)92 ThreadPool::ThreadPool(Env* env, const ThreadOptions& thread_options,
93 const string& name, int num_threads)
94 : ThreadPool(env, thread_options, name, num_threads, true, nullptr) {}
95
ThreadPool(Env * env,const ThreadOptions & thread_options,const string & name,int num_threads,bool low_latency_hint,Eigen::Allocator * allocator)96 ThreadPool::ThreadPool(Env* env, const ThreadOptions& thread_options,
97 const string& name, int num_threads,
98 bool low_latency_hint, Eigen::Allocator* allocator) {
99 CHECK_GE(num_threads, 1);
100 eigen_threadpool_.reset(new Eigen::ThreadPoolTempl<EigenEnvironment>(
101 num_threads, low_latency_hint,
102 EigenEnvironment(env, thread_options, "tf_" + name)));
103 underlying_threadpool_ = eigen_threadpool_.get();
104 threadpool_device_.reset(new Eigen::ThreadPoolDevice(underlying_threadpool_,
105 num_threads, allocator));
106 }
107
ThreadPool(thread::ThreadPoolInterface * user_threadpool)108 ThreadPool::ThreadPool(thread::ThreadPoolInterface* user_threadpool) {
109 underlying_threadpool_ = user_threadpool;
110 threadpool_device_.reset(new Eigen::ThreadPoolDevice(
111 underlying_threadpool_, underlying_threadpool_->NumThreads(), nullptr));
112 }
113
~ThreadPool()114 ThreadPool::~ThreadPool() {}
115
Schedule(std::function<void ()> fn)116 void ThreadPool::Schedule(std::function<void()> fn) {
117 CHECK(fn != nullptr);
118 underlying_threadpool_->Schedule(std::move(fn));
119 }
120
NumShardsUsedByFixedBlockSizeScheduling(const int64 total,const int64 block_size)121 int ThreadPool::NumShardsUsedByFixedBlockSizeScheduling(
122 const int64 total, const int64 block_size) {
123 if (block_size <= 0 || total <= 1 || total <= block_size ||
124 NumThreads() == 1) {
125 return 1;
126 }
127 return (total + block_size - 1) / block_size;
128 }
129
NumShardsUsedByTransformRangeConcurrently(const int64 block_size,const int64 total)130 int ThreadPool::NumShardsUsedByTransformRangeConcurrently(
131 const int64 block_size, const int64 total) {
132 return NumShardsUsedByFixedBlockSizeScheduling(total, block_size);
133 }
134
ParallelFor(int64 total,const SchedulingParams & scheduling_params,const std::function<void (int64,int64)> & fn)135 void ThreadPool::ParallelFor(int64 total,
136 const SchedulingParams& scheduling_params,
137 const std::function<void(int64, int64)>& fn) {
138 switch (scheduling_params.strategy()) {
139 case SchedulingStrategy::kAdaptive: {
140 if (scheduling_params.cost_per_unit().has_value()) {
141 ParallelFor(total, *scheduling_params.cost_per_unit(), fn);
142 }
143 break;
144 }
145 case SchedulingStrategy::kFixedBlockSize: {
146 if (scheduling_params.block_size().has_value()) {
147 ParallelForFixedBlockSizeScheduling(
148 total, *scheduling_params.block_size(), fn);
149 }
150 break;
151 }
152 }
153 }
154
TransformRangeConcurrently(const int64 block_size,const int64 total,const std::function<void (int64,int64)> & fn)155 void ThreadPool::TransformRangeConcurrently(
156 const int64 block_size, const int64 total,
157 const std::function<void(int64, int64)>& fn) {
158 ParallelFor(total,
159 SchedulingParams(SchedulingStrategy::kFixedBlockSize,
160 absl::nullopt /* cost_per_unit */, block_size),
161 fn);
162 }
163
// This functionality is similar to parallelFor, except that reasoning about
// the number of shards used is significantly easier.
//
// Splits [0, total) into ranges of at most `block_size` elements; every split
// point is `block_size`-aligned. Calls `fn(first, last)` once per range, then
// blocks until all ranges have run. The number of ranges equals
// NumShardsUsedByFixedBlockSizeScheduling(total, block_size). When that number
// is 1, `fn(0, total)` runs inline on the calling thread.
void ThreadPool::ParallelForFixedBlockSizeScheduling(
    const int64 total, const int64 block_size,
    const std::function<void(int64, int64)>& fn) {
  const int num_shards_used =
      NumShardsUsedByFixedBlockSizeScheduling(total, block_size);
  if (num_shards_used == 1) {
    fn(0, total);
    return;
  }

  // Adapted from Eigen's parallelFor implementation.
  // One DecrementCount per leaf range; Wait() below returns once all have run.
  // This wait is what keeps `handle_range`, `counter` and `fn` (captured by
  // reference) alive while scheduled closures still use them.
  BlockingCounter counter(num_shards_used);
  // Recursive splitter. It keeps the lower half and schedules the upper half
  // on the pool, until at most one block remains; then it runs `fn` on that
  // block.
  std::function<void(int64, int64)> handle_range =
      [=, &handle_range, &counter, &fn](int64 first, int64 last) {
        while (last - first > block_size) {
          // Find something near the midpoint which is a multiple of block size.
          const int64 mid = first + ((last - first) / 2 + block_size - 1) /
                                        block_size * block_size;
          Schedule([=, &handle_range]() { handle_range(mid, last); });
          last = mid;
        }
        // Single block or less, execute directly.
        fn(first, last);
        counter.DecrementCount();  // The shard is done.
      };
  if (num_shards_used <= NumThreads()) {
    // Avoid a thread hop by running the root of the tree and one block on the
    // main thread.
    handle_range(0, total);
  } else {
    // Execute the root in the thread pool to avoid running work on more than
    // numThreads() threads.
    Schedule([=, &handle_range]() { handle_range(0, total); });
  }
  counter.Wait();
}
202
ParallelFor(int64 total,int64 cost_per_unit,const std::function<void (int64,int64)> & fn)203 void ThreadPool::ParallelFor(int64 total, int64 cost_per_unit,
204 const std::function<void(int64, int64)>& fn) {
205 CHECK_GE(total, 0);
206 CHECK_EQ(total, (int64)(Eigen::Index)total);
207 threadpool_device_->parallelFor(
208 total, Eigen::TensorOpCost(0, 0, cost_per_unit),
209 [&fn](Eigen::Index first, Eigen::Index last) { fn(first, last); });
210 }
211
ParallelForWithWorkerId(int64 total,int64 cost_per_unit,const std::function<void (int64,int64,int)> & fn)212 void ThreadPool::ParallelForWithWorkerId(
213 int64 total, int64 cost_per_unit,
214 const std::function<void(int64, int64, int)>& fn) {
215 CHECK_GE(total, 0);
216 CHECK_EQ(total, (int64)(Eigen::Index)total);
217
218 threadpool_device_->parallelFor(total,
219 Eigen::TensorOpCost(0, 0, cost_per_unit),
220 [this, &fn](int64 start, int64 limit) {
221 // ParallelFor may use the current thread to
222 // do some work synchronously. When calling
223 // CurrentThreadId() from outside of the
224 // thread pool, we get -1, so we can shift
225 // every id up by 1.
226 int id = CurrentThreadId() + 1;
227 fn(start, limit, id);
228 });
229 }
230
ParallelForWithWorkerId(int64 total,const SchedulingParams & scheduling_params,const std::function<void (int64,int64,int)> & fn)231 void ThreadPool::ParallelForWithWorkerId(
232 int64 total, const SchedulingParams& scheduling_params,
233 const std::function<void(int64, int64, int)>& fn) {
234 ParallelFor(total, scheduling_params, [this, &fn](int64 start, int64 limit) {
235 // We may use the current thread to do some work synchronously.
236 // When calling CurrentThreadId() from outside of the thread
237 // pool, we get -1, so we can shift every id up by 1.
238 int id = CurrentThreadId() + 1;
239 fn(start, limit, id);
240 });
241 }
242
NumThreads() const243 int ThreadPool::NumThreads() const {
244 return underlying_threadpool_->NumThreads();
245 }
246
CurrentThreadId() const247 int ThreadPool::CurrentThreadId() const {
248 return underlying_threadpool_->CurrentThreadId();
249 }
250
ScheduleWithHint(std::function<void ()> fn,int start,int limit)251 void ThreadPool::ScheduleWithHint(std::function<void()> fn, int start,
252 int limit) {
253 underlying_threadpool_->ScheduleWithHint(std::move(fn), start, limit);
254 }
255
SetStealPartitions(const std::vector<std::pair<unsigned,unsigned>> & partitions)256 void ThreadPool::SetStealPartitions(
257 const std::vector<std::pair<unsigned, unsigned>>& partitions) {
258 // ThreadPool::SetStealPartitions is only called in the constructor of
259 // RunHandlerPool::Impl, which currently instantiates ThreadPool using a
260 // constructor that does not take user_threadpool. Thus we assume
261 // eigen_threadpool_ is not null here.
262 DCHECK(eigen_threadpool_ != nullptr);
263 eigen_threadpool_->SetStealPartitions(partitions);
264 }
265
AsEigenThreadPool() const266 Eigen::ThreadPoolInterface* ThreadPool::AsEigenThreadPool() const {
267 DCHECK(underlying_threadpool_ != nullptr);
268 return underlying_threadpool_;
269 }
270 } // namespace thread
271 } // namespace tensorflow
272