/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/framework/run_handler.h"

#include <algorithm>
#include <atomic>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/run_handler_util.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/util/ptr_util.h"

namespace tensorflow {

// Contains the concrete implementation of the RunHandler.
// The externally visible RunHandler class simply forwards the work to this
// one.
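// Each Impl tracks the preferred [start, limit) range of inter-op threads
// for its closures, packed into a single atomic word so that scheduling
// reads never need a lock.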
class RunHandler::Impl {
 public:
  explicit Impl(RunHandlerPool::Impl* pool_impl) : pool_impl_(pool_impl) {
    Reset();
  }

  ~Impl() {}

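  // Sets the preferred [start, limit) range of inter-op threads for this
  // handler; updated by RecomputePoolStatsLocked() whenever a handler is
  // acquired or released, and by Reset().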
  void set_inter_op_scheduling_range(std::uint_fast32_t start,
                                     std::uint_fast32_t limit) {
    inter_op_scheduling_range_.store(EncodePartition(start, limit),
                                     std::memory_order_release);
  }

  std::uint_fast32_t inter_op_scheduling_range() const {
    return inter_op_scheduling_range_.load(std::memory_order_acquire);
  }

  // Returns the time (in microseconds since the Unix epoch) at which the
  // handler was requested via RunHandlerPool::Get().
  uint64 start_time_us() const { return start_time_us_; }

  void ScheduleInterOpClosure(std::function<void()> fn);

  void Reset();

  RunHandlerPool::Impl* pool_impl() { return pool_impl_; }

 private:
  // Encoding/decoding logic for storing [start, limit) in a single
  // uint_fast32_t. We assume that pool_num_threads < (1 << 16).
  const int kMaxPartitionBits = 16;
  const int kMaxThreads = 1 << kMaxPartitionBits;

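  // Example: EncodePartition(2, 10) == (2 << 16) | 10, from which
  // DecodePartition() recovers start == 2 and limit == 10.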
  std::uint_fast32_t EncodePartition(std::uint_fast32_t start,
                                     std::uint_fast32_t limit) {
    return (start << kMaxPartitionBits) | limit;
  }

  void DecodePartition(std::uint_fast32_t val, std::uint_fast32_t* start,
                       std::uint_fast32_t* limit) {
    *limit = val & (kMaxThreads - 1);
    val >>= kMaxPartitionBits;
    *start = val;
  }

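  // Current [start, limit) range, encoded via EncodePartition(); written by
  // the pool and read lock-free by ScheduleInterOpClosure().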
  std::atomic_uint_fast32_t inter_op_scheduling_range_;
  RunHandlerPool::Impl* pool_impl_;  // NOT OWNED.
  uint64 start_time_us_;
};

// Contains shared state across all run handlers present in the pool. Also
// responsible for pool management decisions.
// This class is thread safe.
class RunHandlerPool::Impl {
 public:
  explicit Impl(int num_inter_op_threads)
      : max_handlers_(128),
        inter_op_thread_pool_(new thread::ThreadPool(
            Env::Default(), ThreadOptions(), "inter_op",
            num_inter_op_threads)),
        iterations_(0) {
    VLOG(1) << "Creating a RunHandlerPool with max handlers: "
            << max_handlers_;
    for (int i = 0; i < max_handlers_; ++i) {
      handlers_.emplace_back(new RunHandler::Impl(this));
      free_handlers_.push_back(handlers_.back().get());
    }
    // Set steal partitions to a fixed size steal domain of size 6 = 2 *
    // kMinThreadsPerRequest.
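    // For example, with 16 threads and a domain size of 6, threads 0-5 may
    // steal within [0, 6), threads 6-11 within [6, 12), and threads 12-15
    // within the trailing window [10, 16).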
    std::vector<std::pair<unsigned, unsigned>> steal_partitions(
        num_inter_op_threads);
    int kStealDomainSize = std::min(6, num_inter_op_threads);
    unsigned steal_start = 0, steal_end = kStealDomainSize;
    for (int i = 0; i < num_inter_op_threads; ++i) {
      if (i >= steal_end) {
        if (steal_end + kStealDomainSize < num_inter_op_threads) {
          steal_start = steal_end;
          steal_end += kStealDomainSize;
        } else {
          steal_end = num_inter_op_threads;
          steal_start = steal_end - kStealDomainSize;
        }
      }
      steal_partitions[i] = std::make_pair(steal_start, steal_end);
      VLOG(1) << "Steal partition i: " << i << " steal_start: " << steal_start
              << " steal_end: " << steal_end;
    }
    inter_op_thread_pool_->SetStealPartitions(steal_partitions);
  }

  ~Impl() {
    // Sanity check that all handlers have been returned to the pool before
    // destruction.
    DCHECK_EQ(handlers_.size(), max_handlers_);
    DCHECK_EQ(free_handlers_.size(), handlers_.size());
    DCHECK_EQ(sorted_active_handlers_.size(), 0);
  }

  thread::ThreadPool* inter_op_thread_pool() const {
    return inter_op_thread_pool_.get();
  }

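  // Blocks until a handler is free, marks it active, and rebalances the
  // inter-op scheduling ranges across all active handlers.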
  std::unique_ptr<RunHandler> Get() LOCKS_EXCLUDED(mu_) {
    mutex_lock l(mu_);
    while (free_handlers_.empty()) {
      one_handler_free_.wait(l);
    }
    // Remove the last entry from free_handlers_ and add it to the end of
    // sorted_active_handlers_.
    auto* handler_impl = free_handlers_.back();
    handler_impl->Reset();
    // Sortedness isn't violated if we simply add at the end of the list,
    // since handlers are expected to be obtained in increasing order of time.
    sorted_active_handlers_.push_back(handler_impl);
    DCHECK_LE(sorted_active_handlers_.size(), max_handlers_);
    free_handlers_.pop_back();

    RecomputePoolStatsLocked();
    return WrapUnique<RunHandler>(new RunHandler(handler_impl));
  }

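  // Returns `handler` to the free list, recording its elapsed runtime in the
  // histogram, rebalancing the ranges of the remaining active handlers, and
  // waking one waiter blocked in Get().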
  void ReleaseHandler(RunHandler::Impl* handler) LOCKS_EXCLUDED(mu_) {
    {
      mutex_lock l(mu_);
      DCHECK_GT(sorted_active_handlers_.size(), 0);

      uint64 now = tensorflow::Env::Default()->NowMicros();
      double elapsed = (now - handler->start_time_us()) / 1000.0;
      time_hist_.Add(elapsed);

      // Remove the handler from sorted_active_handlers_ and add it to the
      // end of free_handlers_.
      auto iter = std::find(sorted_active_handlers_.begin(),
                            sorted_active_handlers_.end(), handler);
      DCHECK(iter != sorted_active_handlers_.end())
          << "Unexpected handler: " << handler
          << " is being requested for release";
      sorted_active_handlers_.erase(iter);
      free_handlers_.push_back(handler);
      DCHECK_LE(free_handlers_.size(), max_handlers_);

      RecomputePoolStatsLocked();
    }
    one_handler_free_.notify_one();
  }

 private:
  void RecomputePoolStatsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Maximum number of handlers pre-created at pool construction time. The
  // number was chosen on the assumption that each handler wants at least one
  // inter-op thread for execution during compute-intensive workloads such as
  // inference.
  const int max_handlers_;

  // Thread safe part.
  const std::unique_ptr<thread::ThreadPool> inter_op_thread_pool_;

  // Thread compatible part, accessed only while holding the RunHandlerPool
  // lock.
  // Handlers are sorted by start time.
  std::vector<RunHandler::Impl*> sorted_active_handlers_ GUARDED_BY(mu_);
  std::vector<RunHandler::Impl*> free_handlers_ GUARDED_BY(mu_);
  std::vector<std::unique_ptr<RunHandler::Impl>> handlers_ GUARDED_BY(mu_);
  // Histogram of elapsed runtime of every handler (in ms).
  histogram::Histogram time_hist_ GUARDED_BY(mu_);
  std::vector<std::uint_fast32_t> inter_op_start_ GUARDED_BY(mu_);
  std::vector<std::uint_fast32_t> inter_op_limit_ GUARDED_BY(mu_);
  int64 iterations_ GUARDED_BY(mu_);
  condition_variable one_handler_free_;
  mutex mu_;
};

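// Divides the inter-op threads among the active handlers, pushes the
// resulting [start, limit) range to each one, and periodically logs pool
// statistics.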
void RunHandlerPool::Impl::RecomputePoolStatsLocked() {
  int num_active_requests = sorted_active_handlers_.size();
  if (num_active_requests == 0) return;

  int num_threads = inter_op_thread_pool_->NumThreads();

  inter_op_start_.resize(num_active_requests);
  inter_op_limit_.resize(num_active_requests);

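  // ComputeInterOpSchedulingRanges() (see run_handler_util.h) divides the
  // num_threads inter-op threads among the active requests, assigning each a
  // [start, limit) range of at least kMinThreadsPerRequest threads.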
  const int kMinThreadsPerRequest = 3;
  ComputeInterOpSchedulingRanges(num_active_requests, num_threads,
                                 kMinThreadsPerRequest, &inter_op_start_,
                                 &inter_op_limit_);

  for (int i = 0; i < num_active_requests; ++i) {
    sorted_active_handlers_[i]->set_inter_op_scheduling_range(
        inter_op_start_[i], inter_op_limit_[i]);
  }

  if (iterations_++ % 5000 == 0 && VLOG_IS_ON(1)) {
    VLOG(1) << "Printing time histogram: " << time_hist_.ToString();
    VLOG(1) << "Active session runs: " << num_active_requests;
    uint64 now = tensorflow::Env::Default()->NowMicros();
    string ranges_str = "";
    string times_str = "";
    for (int i = 0; i < num_active_requests; ++i) {
      if (i > 0) {
        times_str += " ";
        ranges_str += " ";
      }

      times_str += strings::StrCat(
          (now - sorted_active_handlers_[i]->start_time_us()) / 1000.0,
          " ms.");
      ranges_str += strings::StrCat("[", inter_op_start_[i], ", ",
                                    inter_op_limit_[i], ")");
    }
    VLOG(1) << "Elapsed times are: " << times_str;
    VLOG(1) << "Ranges are: " << ranges_str;
  }
}

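// Schedules `fn` on the shared inter-op pool with a hint to run it on a
// thread in this handler's current [start, limit) range.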
void RunHandler::Impl::ScheduleInterOpClosure(std::function<void()> fn) {
  std::uint_fast32_t start = 0, limit = 0;
  DecodePartition(inter_op_scheduling_range(), &start, &limit);
  DCHECK_LT(start, limit);
  pool_impl_->inter_op_thread_pool()->ScheduleWithHint(std::move(fn), start,
                                                       limit);
}

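// Prepares the handler for reuse: starts with the full range of inter-op
// threads (narrowed later by RecomputePoolStatsLocked()) and records the
// request start time.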
void RunHandler::Impl::Reset() {
  set_inter_op_scheduling_range(
      0, pool_impl_->inter_op_thread_pool()->NumThreads());
  start_time_us_ = tensorflow::Env::Default()->NowMicros();
}

RunHandlerPool::RunHandlerPool(int num_inter_op_threads)
    : impl_(new Impl(num_inter_op_threads)) {}

RunHandlerPool::~RunHandlerPool() {}

std::unique_ptr<RunHandler> RunHandlerPool::Get() { return impl_->Get(); }

RunHandler::RunHandler(Impl* impl) : impl_(impl) {}

void RunHandler::ScheduleInterOpClosure(std::function<void()> fn) {
  impl_->ScheduleInterOpClosure(std::move(fn));
}

RunHandler::~RunHandler() { impl_->pool_impl()->ReleaseHandler(impl_); }

}  // namespace tensorflow