/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/lib/core/threadpool.h"

#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/numa.h"
#include "tensorflow/core/platform/setround.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace thread {

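// EigenEnvironment supplies the hooks (CreateThread, CreateTask, ExecuteTask)
// that Eigen::ThreadPoolTempl expects from its Environment parameter, wiring
// Eigen's scheduler to TensorFlow's Env, Context propagation, and tracing.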
struct EigenEnvironment {
  typedef Thread EnvThread;
  struct TaskImpl {
    std::function<void()> f;
    Context context;
    uint64 trace_id;
  };
  struct Task {
    std::unique_ptr<TaskImpl> f;
  };

  Env* const env_;
  const ThreadOptions thread_options_;
  const string name_;

  EigenEnvironment(Env* env, const ThreadOptions& thread_options,
                   const string& name)
      : env_(env), thread_options_(thread_options), name_(name) {}

  EnvThread* CreateThread(std::function<void()> f) {
    return env_->StartThread(thread_options_, name_, [=]() {
      // Set the processor flag to flush denormals to zero.
      port::ScopedFlushDenormal flush;
      // Set the processor rounding mode to ROUND TO NEAREST.
      port::ScopedSetRound round(FE_TONEAREST);
      if (thread_options_.numa_node != port::kNUMANoAffinity) {
        port::NUMASetThreadNodeAffinity(thread_options_.numa_node);
      }
      f();
    });
  }

  Task CreateTask(std::function<void()> f) {
    uint64 id = 0;
    if (tracing::EventCollector::IsEnabled()) {
      id = tracing::GetUniqueArg();
      tracing::RecordEvent(tracing::EventCategory::kScheduleClosure, id);
    }
    return Task{
        std::unique_ptr<TaskImpl>(new TaskImpl{
            std::move(f),
            Context(ContextKind::kThread),
            id,
        }),
    };
  }

  void ExecuteTask(const Task& t) {
    WithContext wc(t.f->context);
    tracing::ScopedRegion region(tracing::EventCategory::kRunClosure,
                                 t.f->trace_id);
    t.f->f();
  }
};

struct ThreadPool::Impl : Eigen::ThreadPoolTempl<EigenEnvironment> {
  Impl(Env* env, const ThreadOptions& thread_options, const string& name,
       int num_threads, bool low_latency_hint, Eigen::Allocator* allocator)
      : Eigen::ThreadPoolTempl<EigenEnvironment>(
            num_threads, low_latency_hint,
            EigenEnvironment(env, thread_options, name)),
        allocator_(allocator) {}

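  // Forwards to Eigen's parallelFor. The Eigen::TensorOpCost constructed
  // below is Eigen's (bytes loaded, bytes stored, compute cycles) estimate
  // for one unit of work; Eigen uses it to decide how finely to shard the
  // [0, total) range across the pool's threads.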
  void ParallelFor(int64 total, int64 cost_per_unit,
                   std::function<void(int64, int64)> fn) {
    CHECK_GE(total, 0);
    CHECK_EQ(total, (int64)(Eigen::Index)total);
    Eigen::ThreadPoolDevice device(this, this->NumThreads(), allocator_);
    device.parallelFor(
        total, Eigen::TensorOpCost(0, 0, cost_per_unit),
        [&fn](Eigen::Index first, Eigen::Index last) { fn(first, last); });
  }

  Eigen::Allocator* allocator_;
};

ThreadPool::ThreadPool(Env* env, const string& name, int num_threads)
    : ThreadPool(env, ThreadOptions(), name, num_threads, true, nullptr) {}

ThreadPool::ThreadPool(Env* env, const ThreadOptions& thread_options,
                       const string& name, int num_threads)
    : ThreadPool(env, thread_options, name, num_threads, true, nullptr) {}

ThreadPool::ThreadPool(Env* env, const ThreadOptions& thread_options,
                       const string& name, int num_threads,
                       bool low_latency_hint, Eigen::Allocator* allocator) {
  CHECK_GE(num_threads, 1);
  impl_.reset(new ThreadPool::Impl(env, thread_options, "tf_" + name,
                                   num_threads, low_latency_hint, allocator));
}

ThreadPool::~ThreadPool() {}

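// Example usage (a minimal sketch; the pool name and the closure body are
// illustrative, not part of this file):
//
//   ThreadPool pool(Env::Default(), "example_pool", /*num_threads=*/4);
//   pool.Schedule([]() {
//     // Runs asynchronously on one of the pool's worker threads.
//   });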
void ThreadPool::Schedule(std::function<void()> fn) {
  CHECK(fn != nullptr);
  impl_->Schedule(std::move(fn));
}

int ThreadPool::NumShardsUsedByTransformRangeConcurrently(
    const int64 block_size, const int64 total) {
  if (block_size <= 0 || total <= 1 || total <= block_size ||
      NumThreads() == 1) {
    return 1;
  }
  return (total + block_size - 1) / block_size;
}
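
// For example, with block_size = 10 and total = 45 (and a pool with more than
// one thread), the function above returns (45 + 10 - 1) / 10 = 5 shards.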

// This functionality is similar to parallelFor, except that reasoning about
// the number of shards used is significantly easier.
void ThreadPool::TransformRangeConcurrently(
    const int64 block_size, const int64 total,
    const std::function<void(int64, int64)>& fn) {
  const int num_shards_used =
      NumShardsUsedByTransformRangeConcurrently(block_size, total);
  if (num_shards_used == 1) {
    fn(0, total);
    return;
  }

  // Adapted from Eigen's parallelFor implementation.
  BlockingCounter counter(num_shards_used);
  std::function<void(int64, int64)> handle_range =
      [=, &handle_range, &counter, &fn](int64 first, int64 last) {
        while (last - first > block_size) {
          // Find something near the midpoint which is a multiple of block size.
          const int64 mid = first + ((last - first) / 2 + block_size - 1) /
                                        block_size * block_size;
          Schedule([=, &handle_range]() { handle_range(mid, last); });
          last = mid;
        }
        // Single block or less, execute directly.
        fn(first, last);
        counter.DecrementCount();  // The shard is done.
      };
  if (num_shards_used <= NumThreads()) {
    // Avoid a thread hop by running the root of the tree and one block on the
    // main thread.
    handle_range(0, total);
  } else {
    // Execute the root in the thread pool to avoid running work on more than
    // numThreads() threads.
    Schedule([=, &handle_range]() { handle_range(0, total); });
  }
  counter.Wait();
}
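
// Example usage (a minimal sketch; the pool name and the loop body are
// illustrative only):
//
//   ThreadPool pool(Env::Default(), "transform_pool", /*num_threads=*/4);
//   // Process [0, 1000) in half-open sub-ranges of at most 100 elements.
//   pool.TransformRangeConcurrently(
//       /*block_size=*/100, /*total=*/1000, [](int64 first, int64 last) {
//         for (int64 i = first; i < last; ++i) {
//           // ... per-element work ...
//         }
//       });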

void ThreadPool::ParallelFor(int64 total, int64 cost_per_unit,
                             std::function<void(int64, int64)> fn) {
  impl_->ParallelFor(total, cost_per_unit, std::move(fn));
}

void ThreadPool::ParallelForWithWorkerId(
    int64 total, int64 cost_per_unit,
    const std::function<void(int64, int64, int)>& fn) {
  impl_->ParallelFor(total, cost_per_unit,
                     [this, &fn](int64 start, int64 limit) {
                       // ParallelFor may use the current thread to do some
                       // work synchronously. When calling CurrentThreadId()
                       // from outside of the thread pool, we get -1, so we
                       // shift every id up by 1.
                       int id = CurrentThreadId() + 1;
                       fn(start, limit, id);
                     });
}
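
// In the callback above, id comes from CurrentThreadId() + 1: worker threads
// report ids 1..NumThreads(), and the calling thread (which may execute part
// of the range inline) reports id 0, so ids always lie in [0, NumThreads()].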

int ThreadPool::NumThreads() const { return impl_->NumThreads(); }

int ThreadPool::CurrentThreadId() const { return impl_->CurrentThreadId(); }

void ThreadPool::ScheduleWithHint(std::function<void()> fn, int start,
                                  int limit) {
  impl_->ScheduleWithHint(std::move(fn), start, limit);
}

void ThreadPool::SetStealPartitions(
    const std::vector<std::pair<unsigned, unsigned>>& partitions) {
  impl_->SetStealPartitions(partitions);
}
}  // namespace thread
}  // namespace tensorflow