/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/platform/threadpool.h"

#define EIGEN_USE_THREADS

#include "absl/types/optional.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/numa.h"
#include "tensorflow/core/platform/setround.h"
#include "tensorflow/core/platform/tracing.h"

namespace tensorflow {
namespace thread {

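// Adapter that lets Eigen::ThreadPoolTempl create and run its threads and
// tasks through tensorflow's platform layer (Env, Thread, Context, tracing).
// Eigen's thread pool template expects an environment type providing
// EnvThread, CreateThread, CreateTask, and ExecuteTask.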
struct EigenEnvironment {
  typedef Thread EnvThread;
  struct TaskImpl {
    std::function<void()> f;
    Context context;
    uint64 trace_id;
  };
  struct Task {
    std::unique_ptr<TaskImpl> f;
  };

  Env* const env_;
  const ThreadOptions thread_options_;
  const string name_;

  EigenEnvironment(Env* env, const ThreadOptions& thread_options,
                   const string& name)
      : env_(env), thread_options_(thread_options), name_(name) {}

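  // Starts a new pool thread that configures floating-point behavior and,
  // if requested, NUMA affinity before running f.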
  EnvThread* CreateThread(std::function<void()> f) {
    return env_->StartThread(thread_options_, name_, [=]() {
      // Set the processor flag to flush denormals to zero.
      port::ScopedFlushDenormal flush;
      // Set the processor rounding mode to ROUND TO NEAREST.
      port::ScopedSetRound round(FE_TONEAREST);
      if (thread_options_.numa_node != port::kNUMANoAffinity) {
        port::NUMASetThreadNodeAffinity(thread_options_.numa_node);
      }
      f();
    });
  }

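  // Wraps the closure with the caller's thread Context and, when tracing is
  // enabled, a unique id that links this schedule event to the later run
  // event in ExecuteTask.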
  Task CreateTask(std::function<void()> f) {
    uint64 id = 0;
    if (tracing::EventCollector::IsEnabled()) {
      id = tracing::GetUniqueArg();
      tracing::RecordEvent(tracing::EventCategory::kScheduleClosure, id);
    }
    return Task{
        std::unique_ptr<TaskImpl>(new TaskImpl{
            std::move(f),
            Context(ContextKind::kThread),
            id,
        }),
    };
  }

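  // Runs the closure under the Context captured at scheduling time and emits
  // the matching kRunClosure trace event.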
  void ExecuteTask(const Task& t) {
    WithContext wc(t.f->context);
    tracing::ScopedRegion region(tracing::EventCategory::kRunClosure,
                                 t.f->trace_id);
    t.f->f();
  }
};

ThreadPool::ThreadPool(Env* env, const string& name, int num_threads)
    : ThreadPool(env, ThreadOptions(), name, num_threads, true, nullptr) {}

ThreadPool::ThreadPool(Env* env, const ThreadOptions& thread_options,
                       const string& name, int num_threads)
    : ThreadPool(env, thread_options, name, num_threads, true, nullptr) {}

ThreadPool::ThreadPool(Env* env, const ThreadOptions& thread_options,
                       const string& name, int num_threads,
                       bool low_latency_hint, Eigen::Allocator* allocator) {
  CHECK_GE(num_threads, 1);
  eigen_threadpool_.reset(new Eigen::ThreadPoolTempl<EigenEnvironment>(
      num_threads, low_latency_hint,
      EigenEnvironment(env, thread_options, "tf_" + name)));
  underlying_threadpool_ = eigen_threadpool_.get();
  threadpool_device_.reset(new Eigen::ThreadPoolDevice(underlying_threadpool_,
                                                       num_threads, allocator));
}

ThreadPool::ThreadPool(thread::ThreadPoolInterface* user_threadpool) {
  underlying_threadpool_ = user_threadpool;
  threadpool_device_.reset(new Eigen::ThreadPoolDevice(
      underlying_threadpool_, underlying_threadpool_->NumThreads(), nullptr));
}

ThreadPool::~ThreadPool() {}

void ThreadPool::Schedule(std::function<void()> fn) {
  CHECK(fn != nullptr);
  underlying_threadpool_->Schedule(std::move(fn));
}

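// Returns the number of shards that fixed-block-size scheduling will use:
// ceil(total / block_size), or 1 when sharding cannot help (a non-positive
// block size, a range that fits in one block, or a single-threaded pool).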
int ThreadPool::NumShardsUsedByFixedBlockSizeScheduling(
    const int64 total, const int64 block_size) {
  if (block_size <= 0 || total <= 1 || total <= block_size ||
      NumThreads() == 1) {
    return 1;
  }
  return (total + block_size - 1) / block_size;
}

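// Forwards to NumShardsUsedByFixedBlockSizeScheduling; note that the
// (block_size, total) argument order here is the reverse of the function it
// forwards to.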
int ThreadPool::NumShardsUsedByTransformRangeConcurrently(
    const int64 block_size, const int64 total) {
  return NumShardsUsedByFixedBlockSizeScheduling(total, block_size);
}

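// Dispatches on the scheduling strategy. If the parameter the chosen strategy
// requires (cost_per_unit or block_size) is absent, the call is a silent
// no-op and fn is never invoked.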
void ThreadPool::ParallelFor(int64 total,
                             const SchedulingParams& scheduling_params,
                             const std::function<void(int64, int64)>& fn) {
  switch (scheduling_params.strategy()) {
    case SchedulingStrategy::kAdaptive: {
      if (scheduling_params.cost_per_unit().has_value()) {
        ParallelFor(total, *scheduling_params.cost_per_unit(), fn);
      }
      break;
    }
    case SchedulingStrategy::kFixedBlockSize: {
      if (scheduling_params.block_size().has_value()) {
        ParallelForFixedBlockSizeScheduling(
            total, *scheduling_params.block_size(), fn);
      }
      break;
    }
  }
}

void ThreadPool::TransformRangeConcurrently(
    const int64 block_size, const int64 total,
    const std::function<void(int64, int64)>& fn) {
  ParallelFor(total,
              SchedulingParams(SchedulingStrategy::kFixedBlockSize,
                               absl::nullopt /* cost_per_unit */, block_size),
              fn);
}

// This functionality is similar to parallelFor, except that reasoning about
// the number of shards used is significantly easier.
void ThreadPool::ParallelForFixedBlockSizeScheduling(
    const int64 total, const int64 block_size,
    const std::function<void(int64, int64)>& fn) {
  const int num_shards_used =
      NumShardsUsedByFixedBlockSizeScheduling(total, block_size);
  if (num_shards_used == 1) {
    fn(0, total);
    return;
  }

  // Adapted from Eigen's parallelFor implementation.
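  // handle_range() halves [first, last) at a block-size-aligned midpoint,
  // schedules the right half on the pool, and keeps the left half for itself,
  // so scheduling overhead is spread across the workers rather than
  // serialized on one thread.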
  BlockingCounter counter(num_shards_used);
  std::function<void(int64, int64)> handle_range =
      [=, &handle_range, &counter, &fn](int64 first, int64 last) {
        while (last - first > block_size) {
          // Find something near the midpoint which is a multiple of block size.
          const int64 mid = first + ((last - first) / 2 + block_size - 1) /
                                        block_size * block_size;
          Schedule([=, &handle_range]() { handle_range(mid, last); });
          last = mid;
        }
        // Single block or less, execute directly.
        fn(first, last);
        counter.DecrementCount();  // The shard is done.
      };
  if (num_shards_used <= NumThreads()) {
    // Avoid a thread hop by running the root of the tree and one block on the
    // main thread.
    handle_range(0, total);
  } else {
    // Execute the root in the thread pool to avoid running work on more than
    // numThreads() threads.
    Schedule([=, &handle_range]() { handle_range(0, total); });
  }
  counter.Wait();
}

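// cost_per_unit estimates the work (roughly, cycles) per unit of the range;
// Eigen's cost model uses it to choose the shard size and how many worker
// threads to recruit.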
void ThreadPool::ParallelFor(int64 total, int64 cost_per_unit,
                             const std::function<void(int64, int64)>& fn) {
  CHECK_GE(total, 0);
  CHECK_EQ(total, (int64)(Eigen::Index)total);
  threadpool_device_->parallelFor(
      total, Eigen::TensorOpCost(0, 0, cost_per_unit),
      [&fn](Eigen::Index first, Eigen::Index last) { fn(first, last); });
}
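
// A minimal caller sketch, for illustration only (the pool name and values
// are hypothetical; Env::Default() is the standard platform Env):
//
//   ThreadPool pool(Env::Default(), "demo", /*num_threads=*/4);
//   pool.ParallelFor(1000, /*cost_per_unit=*/10000,
//                    [](int64 start, int64 limit) {
//                      for (int64 i = start; i < limit; ++i) {
//                        // ... per-element work ...
//                      }
//                    });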

void ThreadPool::ParallelForWithWorkerId(
    int64 total, int64 cost_per_unit,
    const std::function<void(int64, int64, int)>& fn) {
  CHECK_GE(total, 0);
  CHECK_EQ(total, (int64)(Eigen::Index)total);

  threadpool_device_->parallelFor(total,
                                  Eigen::TensorOpCost(0, 0, cost_per_unit),
                                  [this, &fn](int64 start, int64 limit) {
                                    // ParallelFor may use the current thread to
                                    // do some work synchronously. When calling
                                    // CurrentThreadId() from outside of the
                                    // thread pool, we get -1, so we can shift
                                    // every id up by 1.
                                    int id = CurrentThreadId() + 1;
                                    fn(start, limit, id);
                                  });
}

void ThreadPool::ParallelForWithWorkerId(
    int64 total, const SchedulingParams& scheduling_params,
    const std::function<void(int64, int64, int)>& fn) {
  ParallelFor(total, scheduling_params, [this, &fn](int64 start, int64 limit) {
    // We may use the current thread to do some work synchronously.
    // When calling CurrentThreadId() from outside of the thread
    // pool, we get -1, so we can shift every id up by 1.
    int id = CurrentThreadId() + 1;
    fn(start, limit, id);
  });
}

int ThreadPool::NumThreads() const {
  return underlying_threadpool_->NumThreads();
}

int ThreadPool::CurrentThreadId() const {
  return underlying_threadpool_->CurrentThreadId();
}

void ThreadPool::ScheduleWithHint(std::function<void()> fn, int start,
                                  int limit) {
  underlying_threadpool_->ScheduleWithHint(std::move(fn), start, limit);
}

void ThreadPool::SetStealPartitions(
    const std::vector<std::pair<unsigned, unsigned>>& partitions) {
  // ThreadPool::SetStealPartitions is only called in the constructor of
  // RunHandlerPool::Impl, which currently instantiates ThreadPool using a
  // constructor that does not take user_threadpool. Thus we assume
  // eigen_threadpool_ is not null here.
  DCHECK(eigen_threadpool_ != nullptr);
  eigen_threadpool_->SetStealPartitions(partitions);
}

Eigen::ThreadPoolInterface* ThreadPool::AsEigenThreadPool() const {
  DCHECK(underlying_threadpool_ != nullptr);
  return underlying_threadpool_;
}
}  // namespace thread
}  // namespace tensorflow