• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #define EIGEN_USE_THREADS
17 
18 #include "tensorflow/core/framework/run_handler.h"
19 
20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
21 #include "tensorflow/core/framework/run_handler_util.h"
22 #include "tensorflow/core/platform/mutex.h"
23 #include "tensorflow/core/util/ptr_util.h"
24 
25 namespace tensorflow {
26 
27 // Contains the concrete implementation of the RunHandler.
28 // Externally visible RunHandler class simply forwards the work to this one.
29 class RunHandler::Impl {
30  public:
Impl(RunHandlerPool::Impl * pool_impl)31   explicit Impl(RunHandlerPool::Impl* pool_impl) : pool_impl_(pool_impl) {
32     Reset();
33   }
34 
~Impl()35   ~Impl() {}
36 
set_inter_op_scheduling_range(std::uint_fast32_t start,std::uint_fast32_t limit)37   void set_inter_op_scheduling_range(std::uint_fast32_t start,
38                                      std::uint_fast32_t limit) {
39     inter_op_scheduling_range_.store(EncodePartition(start, limit),
40                                      std::memory_order_release);
41   }
42 
inter_op_scheduling_range() const43   std::uint_fast32_t inter_op_scheduling_range() const {
44     return inter_op_scheduling_range_.load(std::memory_order_acquire);
45   }
46 
47   // Stores now time (in microseconds) since unix epoch when the handler is
48   // requested via RunHandlerPool::Get().
start_time_us() const49   uint64 start_time_us() const { return start_time_us_; }
50 
51   void ScheduleInterOpClosure(std::function<void()> fn);
52 
53   void Reset();
54 
pool_impl()55   RunHandlerPool::Impl* pool_impl() { return pool_impl_; }
56 
57  private:
58   // Encoding/decoding logic for storing [start, limit) into a single
59   // uint_fast32_t int. We assume that pool_num_threads < (1 << 16).
60   const int kMaxPartitionBits = 16;
61   const int kMaxThreads = 1 << kMaxPartitionBits;
62 
EncodePartition(std::uint_fast32_t start,std::uint_fast32_t limit)63   std::uint_fast32_t EncodePartition(std::uint_fast32_t start,
64                                      std::uint_fast32_t limit) {
65     return (start << kMaxPartitionBits) | limit;
66   }
67 
DecodePartition(std::uint_fast32_t val,std::uint_fast32_t * start,std::uint_fast32_t * limit)68   void DecodePartition(std::uint_fast32_t val, std::uint_fast32_t* start,
69                        std::uint_fast32_t* limit) {
70     *limit = val & (kMaxThreads - 1);
71     val >>= kMaxPartitionBits;
72     *start = val;
73   }
74 
75   std::atomic_uint_fast32_t inter_op_scheduling_range_;
76   RunHandlerPool::Impl* pool_impl_;  // NOT OWNED.
77   uint64 start_time_us_;
78 };
79 
80 // Contains shared state across all run handlers present in the pool. Also
81 // responsible for pool management decisions.
82 // This class is thread safe.
83 class RunHandlerPool::Impl {
84  public:
Impl(int num_inter_op_threads)85   explicit Impl(int num_inter_op_threads)
86       : max_handlers_(128),
87         inter_op_thread_pool_(new thread::ThreadPool(
88             Env::Default(), ThreadOptions(), "inter_op", num_inter_op_threads)),
89         iterations_(0) {
90     VLOG(1) << "Creating a RunHandlerPool with max handlers: " << max_handlers_;
91     for (int i = 0; i < max_handlers_; ++i) {
92       handlers_.emplace_back(new RunHandler::Impl(this));
93       free_handlers_.push_back(handlers_.back().get());
94     }
95     // Set steal partitions to a fixed size steal domain of size 6 = 2 *
96     // kMinThreadsPerRequest.
97     std::vector<std::pair<unsigned, unsigned>> steal_partitions(
98         num_inter_op_threads);
99     int kStealDomainSize = std::min(6, num_inter_op_threads);
100     unsigned steal_start = 0, steal_end = kStealDomainSize;
101     for (int i = 0; i < num_inter_op_threads; ++i) {
102       if (i > steal_start) {
103         if (steal_end + kStealDomainSize < num_inter_op_threads) {
104           steal_start = steal_end;
105           steal_end += kStealDomainSize;
106         } else {
107           steal_end = num_inter_op_threads;
108           steal_start = steal_end - kStealDomainSize;
109         }
110       }
111       steal_partitions[i] = std::make_pair(steal_start, steal_end);
112       VLOG(1) << "Steal partition i: " << i << " steal_start: " << steal_start
113               << " steal_end: " << steal_end;
114     }
115     inter_op_thread_pool_->SetStealPartitions(steal_partitions);
116   }
117 
~Impl()118   ~Impl() {
119     // Sanity check that all handlers have been returned back to the pool before
120     // destruction.
121     DCHECK_EQ(handlers_.size(), max_handlers_);
122     DCHECK_EQ(free_handlers_.size(), handlers_.size());
123     DCHECK_EQ(sorted_active_handlers_.size(), 0);
124   }
125 
inter_op_thread_pool() const126   thread::ThreadPool* inter_op_thread_pool() const {
127     return inter_op_thread_pool_.get();
128   }
129 
Get()130   std::unique_ptr<RunHandler> Get() LOCKS_EXCLUDED(mu_) {
131     mutex_lock l(mu_);
132     while (free_handlers_.empty()) {
133       one_handler_free_.wait(l);
134     }
135     // Remove the last entry from free_handlers_ and add to the end of
136     // sorted_active_handlers_.
137     auto* handler_impl = free_handlers_.back();
138     handler_impl->Reset();
139     // Sortedness isn't violated if we simply add at the end of the list, since
140     // handlers are expected to be obtained in increasing order of time.
141     sorted_active_handlers_.push_back(handler_impl);
142     DCHECK_LE(sorted_active_handlers_.size(), max_handlers_);
143     free_handlers_.pop_back();
144 
145     RecomputePoolStatsLocked();
146     return WrapUnique<RunHandler>(new RunHandler(handler_impl));
147   }
148 
ReleaseHandler(RunHandler::Impl * handler)149   void ReleaseHandler(RunHandler::Impl* handler) LOCKS_EXCLUDED(mu_) {
150     {
151       mutex_lock l(mu_);
152       DCHECK_GT(sorted_active_handlers_.size(), 0);
153 
154       uint64 now = tensorflow::Env::Default()->NowMicros();
155       double elapsed = (now - handler->start_time_us()) / 1000.0;
156       time_hist_.Add(elapsed);
157 
158       // Erase from and update sorted_active_handlers_. Add it to the end of
159       // free_handlers_.
160       auto iter = std::find(sorted_active_handlers_.begin(),
161                             sorted_active_handlers_.end(), handler);
162       DCHECK(iter != sorted_active_handlers_.end())
163           << "Unexpected handler: " << handler
164           << " is being requested for release";
165 
166       // Remove this handler from this list and add it to the list of free
167       // handlers.
168       sorted_active_handlers_.erase(iter);
169       free_handlers_.push_back(handler);
170       DCHECK_LE(free_handlers_.size(), max_handlers_);
171 
172       RecomputePoolStatsLocked();
173     }
174     one_handler_free_.notify_one();
175   }
176 
177  private:
178   void RecomputePoolStatsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_);
179 
180   // Maximum number of handlers pre-created during pool construction time. The
181   // number has been chosen expecting each handler might at least want 1
182   // inter-op thread for execution (during compute intensive workloads like
183   // inference).
184   const int max_handlers_;
185 
186   // Thread safe part.
187   const std::unique_ptr<thread::ThreadPool> inter_op_thread_pool_;
188 
189   // Thread compatible part used only by lock under RunHandlerPool.
190   // Handlers are sorted by start time.
191   std::vector<RunHandler::Impl*> sorted_active_handlers_ GUARDED_BY(mu_);
192   std::vector<RunHandler::Impl*> free_handlers_ GUARDED_BY(mu_);
193   std::vector<std::unique_ptr<RunHandler::Impl>> handlers_ GUARDED_BY(mu_);
194   // Histogram of elapsed runtime of every handler (in ms).
195   histogram::Histogram time_hist_ GUARDED_BY(mu_);
196   std::vector<std::uint_fast32_t> inter_op_start_ GUARDED_BY(mu_);
197   std::vector<std::uint_fast32_t> inter_op_limit_ GUARDED_BY(mu_);
198   int64 iterations_ GUARDED_BY(mu_);
199   condition_variable one_handler_free_;
200   mutex mu_;
201 };
202 
RecomputePoolStatsLocked()203 void RunHandlerPool::Impl::RecomputePoolStatsLocked() {
204   int num_active_requests = sorted_active_handlers_.size();
205   if (num_active_requests == 0) return;
206 
207   int num_threads = inter_op_thread_pool_->NumThreads();
208 
209   inter_op_start_.resize(num_active_requests);
210   inter_op_limit_.resize(num_active_requests);
211 
212   const int kMinThreadsPerRequest = 3;
213   ComputeInterOpSchedulingRanges(num_active_requests, num_threads,
214                                  kMinThreadsPerRequest, &inter_op_start_,
215                                  &inter_op_limit_);
216 
217   for (int i = 0; i < num_active_requests; ++i) {
218     sorted_active_handlers_[i]->set_inter_op_scheduling_range(
219         inter_op_start_[i], inter_op_limit_[i]);
220   }
221 
222   if (iterations_++ % 5000 == 0 && VLOG_IS_ON(1)) {
223     VLOG(1) << "Printing time histogram: " << time_hist_.ToString();
224     VLOG(1) << "Active session runs: " << num_active_requests;
225     uint64 now = tensorflow::Env::Default()->NowMicros();
226     string ranges_str = "";
227     string times_str = "";
228     for (int i = 0; i < num_active_requests; ++i) {
229       if (i > 0) {
230         times_str += " ";
231         ranges_str += " ";
232       }
233 
234       times_str += strings::StrCat(
235           (now - sorted_active_handlers_[i]->start_time_us()) / 1000.0, " ms.");
236       ranges_str += strings::StrCat("[", inter_op_start_[i], ", ",
237                                     inter_op_limit_[i], ")");
238     }
239     VLOG(1) << "Elapsed times are: " << times_str;
240     VLOG(1) << "Ranges are: " << ranges_str;
241   }
242 }
243 
ScheduleInterOpClosure(std::function<void ()> fn)244 void RunHandler::Impl::ScheduleInterOpClosure(std::function<void()> fn) {
245   std::uint_fast32_t start = 0, limit = 0;
246   DecodePartition(inter_op_scheduling_range(), &start, &limit);
247   DCHECK_LT(start, limit);
248   pool_impl_->inter_op_thread_pool()->ScheduleWithHint(std::move(fn), start,
249                                                        limit);
250 }
251 
Reset()252 void RunHandler::Impl::Reset() {
253   set_inter_op_scheduling_range(
254       0, pool_impl_->inter_op_thread_pool()->NumThreads());
255   start_time_us_ = tensorflow::Env::Default()->NowMicros();
256 }
257 
RunHandlerPool(int num_inter_op_threads)258 RunHandlerPool::RunHandlerPool(int num_inter_op_threads)
259     : impl_(new Impl(num_inter_op_threads)) {}
260 
~RunHandlerPool()261 RunHandlerPool::~RunHandlerPool() {}
262 
Get()263 std::unique_ptr<RunHandler> RunHandlerPool::Get() { return impl_->Get(); }
264 
RunHandler(Impl * impl)265 RunHandler::RunHandler(Impl* impl) : impl_(impl) {}
266 
ScheduleInterOpClosure(std::function<void ()> fn)267 void RunHandler::ScheduleInterOpClosure(std::function<void()> fn) {
268   impl_->ScheduleInterOpClosure(std::move(fn));
269 }
270 
~RunHandler()271 RunHandler::~RunHandler() { impl_->pool_impl()->ReleaseHandler(impl_); }
272 }  // namespace tensorflow
273