/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/lib/core/threadpool.h"

#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/numa.h"
#include "tensorflow/core/platform/setround.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace thread {
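
// EigenEnvironment adapts tensorflow::Env threads and TensorFlow tracing to
// the environment interface expected by Eigen::ThreadPoolTempl: threads are
// created through Env::StartThread, and every scheduled closure carries its
// creation-time Context and a trace id for the event collector.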
struct EigenEnvironment {
  typedef Thread EnvThread;
  struct TaskImpl {
    std::function<void()> f;
    Context context;
    uint64 trace_id;
  };
  struct Task {
    std::unique_ptr<TaskImpl> f;
  };

  Env* const env_;
  const ThreadOptions thread_options_;
  const string name_;

  EigenEnvironment(Env* env, const ThreadOptions& thread_options,
                   const string& name)
      : env_(env), thread_options_(thread_options), name_(name) {}

  EnvThread* CreateThread(std::function<void()> f) {
    return env_->StartThread(thread_options_, name_, [=]() {
      // Set the processor flag to flush denormals to zero.
      port::ScopedFlushDenormal flush;
      // Set the processor rounding mode to ROUND TO NEAREST.
      port::ScopedSetRound round(FE_TONEAREST);
      if (thread_options_.numa_node != port::kNUMANoAffinity) {
        port::NUMASetThreadNodeAffinity(thread_options_.numa_node);
      }
      f();
    });
  }

  Task CreateTask(std::function<void()> f) {
    uint64 id = 0;
    if (tracing::EventCollector::IsEnabled()) {
      id = tracing::GetUniqueArg();
      tracing::RecordEvent(tracing::EventCategory::kScheduleClosure, id);
    }
    return Task{
        std::unique_ptr<TaskImpl>(new TaskImpl{
            std::move(f),
            Context(ContextKind::kThread),
            id,
        }),
    };
  }

  void ExecuteTask(const Task& t) {
    WithContext wc(t.f->context);
    tracing::ScopedRegion region(tracing::EventCategory::kRunClosure,
                                 t.f->trace_id);
    t.f->f();
  }
};
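
// ThreadPool::Impl is a thin wrapper over Eigen's thread pool template,
// specialized with the EigenEnvironment above. It also remembers the optional
// allocator that is handed to the Eigen::ThreadPoolDevice used by ParallelFor.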
struct ThreadPool::Impl : Eigen::ThreadPoolTempl<EigenEnvironment> {
  Impl(Env* env, const ThreadOptions& thread_options, const string& name,
       int num_threads, bool low_latency_hint, Eigen::Allocator* allocator)
      : Eigen::ThreadPoolTempl<EigenEnvironment>(
            num_threads, low_latency_hint,
            EigenEnvironment(env, thread_options, name)),
        allocator_(allocator) {}
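
  // Runs fn on [0, total) via Eigen's parallelFor. cost_per_unit is an
  // estimate of the cost of processing one unit of work (roughly in cycles);
  // Eigen's cost model uses it to decide how finely to shard the range.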
  void ParallelFor(int64 total, int64 cost_per_unit,
                   std::function<void(int64, int64)> fn) {
    CHECK_GE(total, 0);
    CHECK_EQ(total, (int64)(Eigen::Index)total);
    Eigen::ThreadPoolDevice device(this, this->NumThreads(), allocator_);
    device.parallelFor(
        total, Eigen::TensorOpCost(0, 0, cost_per_unit),
        [&fn](Eigen::Index first, Eigen::Index last) { fn(first, last); });
  }

  Eigen::Allocator* allocator_;
};

ThreadPool::ThreadPool(Env* env, const string& name, int num_threads)
    : ThreadPool(env, ThreadOptions(), name, num_threads, true, nullptr) {}

ThreadPool::ThreadPool(Env* env, const ThreadOptions& thread_options,
                       const string& name, int num_threads)
    : ThreadPool(env, thread_options, name, num_threads, true, nullptr) {}

ThreadPool::ThreadPool(Env* env, const ThreadOptions& thread_options,
                       const string& name, int num_threads,
                       bool low_latency_hint, Eigen::Allocator* allocator) {
  CHECK_GE(num_threads, 1);
  impl_.reset(new ThreadPool::Impl(env, thread_options, "tf_" + name,
                                   num_threads, low_latency_hint, allocator));
}

ThreadPool::~ThreadPool() {}

void ThreadPool::Schedule(std::function<void()> fn) {
  CHECK(fn != nullptr);
  impl_->Schedule(std::move(fn));
}
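
// Returns how many shards TransformRangeConcurrently will use: 1 when sharding
// is not worthwhile (non-positive block_size, a trivially small range, or a
// single-threaded pool), otherwise ceil(total / block_size).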
int ThreadPool::NumShardsUsedByTransformRangeConcurrently(
    const int64 block_size, const int64 total) {
  if (block_size <= 0 || total <= 1 || total <= block_size ||
      NumThreads() == 1) {
    return 1;
  }
  return (total + block_size - 1) / block_size;
}

// This functionality is similar to parallelFor, except that reasoning about
// the number of shards used is significantly easier.
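//
// Illustrative call (`pool` and the numbers are hypothetical, not taken from
// this file): processing 1,000 elements in blocks of at most 100 uses at most
// 10 shards:
//
//   pool.TransformRangeConcurrently(
//       /*block_size=*/100, /*total=*/1000,
//       [](int64 first, int64 last) { /* handle [first, last) */ });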
void ThreadPool::TransformRangeConcurrently(
    const int64 block_size, const int64 total,
    const std::function<void(int64, int64)>& fn) {
  const int num_shards_used =
      NumShardsUsedByTransformRangeConcurrently(block_size, total);
  if (num_shards_used == 1) {
    fn(0, total);
    return;
  }

  // Adapted from Eigen's parallelFor implementation.
  BlockingCounter counter(num_shards_used);
  std::function<void(int64, int64)> handle_range =
      [=, &handle_range, &counter, &fn](int64 first, int64 last) {
        while (last - first > block_size) {
          // Find something near the midpoint which is a multiple of block size.
          const int64 mid = first + ((last - first) / 2 + block_size - 1) /
                                        block_size * block_size;
          Schedule([=, &handle_range]() { handle_range(mid, last); });
          last = mid;
        }
        // Single block or less, execute directly.
        fn(first, last);
        counter.DecrementCount();  // The shard is done.
      };
  if (num_shards_used <= NumThreads()) {
    // Avoid a thread hop by running the root of the tree and one block on the
    // main thread.
    handle_range(0, total);
  } else {
    // Execute the root in the thread pool to avoid running work on more than
    // numThreads() threads.
    Schedule([=, &handle_range]() { handle_range(0, total); });
  }
  counter.Wait();
}
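
// Shards [0, total) across the pool according to Eigen's cost model;
// fn(first, last) is invoked on half-open subranges and may run synchronously
// on the calling thread.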
void ThreadPool::ParallelFor(int64 total, int64 cost_per_unit,
                             std::function<void(int64, int64)> fn) {
  impl_->ParallelFor(total, cost_per_unit, std::move(fn));
}
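
// Same as ParallelFor, but fn also receives an id in [0, NumThreads()]:
// pool threads report CurrentThreadId() + 1, and 0 means the shard ran
// synchronously on a thread outside the pool.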
void ThreadPool::ParallelForWithWorkerId(
    int64 total, int64 cost_per_unit,
    const std::function<void(int64, int64, int)>& fn) {
  impl_->ParallelFor(total, cost_per_unit,
                     [this, &fn](int64 start, int64 limit) {
                       // ParallelFor may use the current thread to do some
                       // work synchronously. When calling CurrentThreadId()
                       // from outside of the thread pool, we get -1, so we can
                       // shift every id up by 1.
                       int id = CurrentThreadId() + 1;
                       fn(start, limit, id);
                     });
}

int ThreadPool::NumThreads() const { return impl_->NumThreads(); }

int ThreadPool::CurrentThreadId() const { return impl_->CurrentThreadId(); }
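
// Schedules fn with a hint that it should be placed on one of the worker
// queues in [start, limit); forwarded to Eigen's ScheduleWithHint.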
void ThreadPool::ScheduleWithHint(std::function<void()> fn, int start,
                                  int limit) {
  impl_->ScheduleWithHint(std::move(fn), start, limit);
}

void ThreadPool::SetStealPartitions(
    const std::vector<std::pair<unsigned, unsigned>>& partitions) {
  impl_->SetStealPartitions(partitions);
}

}  // namespace thread
}  // namespace tensorflow