/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/framework/run_handler.h"

#include <algorithm>
#include <cmath>
#include <list>
#include <memory>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/run_handler_util.h"
#include "tensorflow/core/lib/core/threadpool_interface.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/numa.h"
#include "tensorflow/core/platform/setround.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/ptr_util.h"

namespace tensorflow {
namespace {
// LINT.IfChange
static constexpr int32 kMaxConcurrentHandlers = 128;
// LINT.ThenChange(//tensorflow/core/framework/run_handler_test.cc)

typedef typename internal::RunHandlerEnvironment::Task Task;
typedef Eigen::RunQueue<Task, 1024> Queue;

}  // namespace

namespace internal {
RunHandlerEnvironment::RunHandlerEnvironment(
    Env* env, const ThreadOptions& thread_options, const string& name)
    : env_(env), thread_options_(thread_options), name_(name) {}

RunHandlerEnvironment::EnvThread* RunHandlerEnvironment::CreateThread(
    std::function<void()> f) {
  return env_->StartThread(thread_options_, name_, [=]() {
    // Set the processor flag to flush denormals to zero.
    port::ScopedFlushDenormal flush;
    // Set the processor rounding mode to ROUND TO NEAREST.
    port::ScopedSetRound round(FE_TONEAREST);
    if (thread_options_.numa_node != port::kNUMANoAffinity) {
      port::NUMASetThreadNodeAffinity(thread_options_.numa_node);
    }
    f();
  });
}

RunHandlerEnvironment::Task RunHandlerEnvironment::CreateTask(
    std::function<void()> f) {
  uint64 id = 0;
  if (tracing::EventCollector::IsEnabled()) {
    id = tracing::GetUniqueArg();
    tracing::RecordEvent(tracing::EventCategory::kScheduleClosure, id);
  }
  return Task{
      std::unique_ptr<TaskImpl>(new TaskImpl{
          std::move(f),
          Context(ContextKind::kThread),
          id,
      }),
  };
}

void RunHandlerEnvironment::ExecuteTask(const Task& t) {
  WithContext wc(t.f->context);
  tracing::ScopedRegion region(tracing::EventCategory::kRunClosure,
                               t.f->trace_id);
  t.f->f();
}

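// Blocks the calling thread on `waiter`: the waiter is pushed onto the LIFO
// queue rooted at `queue_head`, the thread sleeps on the waiter's condition
// variable for at most `max_sleep_micros`, and the waiter is unlinked from
// the queue again before returning.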
void WaitOnWaiter(Waiter* waiter, Waiter* queue_head, mutex* mutex,
                  int max_sleep_micros) {
  {
    mutex_lock l(*mutex);
    CHECK_EQ(waiter->next, waiter);  // Crash OK.
    CHECK_EQ(waiter->prev, waiter);  // Crash OK.

    // Add waiter to the LIFO queue.
    waiter->prev = queue_head;
    waiter->next = queue_head->next;
    waiter->next->prev = waiter;
    waiter->prev->next = waiter;
  }
  {
    mutex_lock l(waiter->mu);
    // Wait on the condition variable.
    waiter->cv.wait_for(l, std::chrono::microseconds(max_sleep_micros));
  }

  mutex_lock l(*mutex);
  // Remove waiter from the LIFO queue. Note that even when a waiter wakes up
  // due to a notification, we cannot conclude that it is no longer in the
  // queue: a thread preempted right before notifying may resume only after
  // the waiter has been re-added.
  if (waiter->next != waiter) {
    CHECK(waiter->prev != waiter);  // Crash OK.
    waiter->next->prev = waiter->prev;
    waiter->prev->next = waiter->next;
    waiter->next = waiter;
    waiter->prev = waiter;
  } else {
    CHECK_EQ(waiter->prev, waiter);  // Crash OK.
  }
}

ThreadWorkSource::ThreadWorkSource()
    : non_blocking_work_sharding_factor_(
          static_cast<int32>(ParamFromEnvWithDefault(
              "TF_RUN_HANDLER_NUM_OF_NON_BLOCKING_QUEUES", 1))),
      non_blocking_work_queues_(non_blocking_work_sharding_factor_),
      blocking_inflight_(0),
      non_blocking_inflight_(0),
      traceme_id_(0),
      version_(0),
      sub_thread_pool_waiter_(nullptr) {
  queue_waiters_.next = &queue_waiters_;
  queue_waiters_.prev = &queue_waiters_;
  for (int i = 0; i < NonBlockingWorkShardingFactor(); ++i) {
    non_blocking_work_queues_.emplace_back(new NonBlockingQueue());
  }
}

ThreadWorkSource::~ThreadWorkSource() {
  for (int i = 0; i < non_blocking_work_queues_.size(); ++i) {
    delete non_blocking_work_queues_[i];
  }
}

Task ThreadWorkSource::EnqueueTask(Task t, bool is_blocking) {
  mutex* mu = nullptr;
  Queue* task_queue = nullptr;
  thread_local int64 closure_counter = 0;

  if (!is_blocking) {
    int queue_index = ++closure_counter % non_blocking_work_sharding_factor_;
    task_queue = &(non_blocking_work_queues_[queue_index]->queue);
    mu = &non_blocking_work_queues_[queue_index]->queue_op_mu;
  } else {
    task_queue = &blocking_work_queue_;
    mu = &blocking_queue_op_mu_;
  }

  {
    mutex_lock l(*mu);
    // For a given queue, only one thread can call PushFront.
    t = task_queue->PushFront(std::move(t));
  }

  Waiter* w = nullptr;
  static const bool use_sub_thread_pool =
      ParamFromEnvBoolWithDefault("TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false);

  Waiter* waiter_queue;
  mutex* waiter_queue_mu;
  if (use_sub_thread_pool) {
    // When sub thread pools are used, free threads wait on the sub thread
    // pool waiting queues, so wake up threads from those queues. The waiting
    // queues are defined in RunHandlerPool.
    // Get the waiter_queue and the corresponding mutex. Note that the thread
    // work source may change afterwards if a new request arrives or an old
    // request finishes.
    tf_shared_lock lock(run_handler_waiter_mu_);
    waiter_queue = sub_thread_pool_waiter_;
    waiter_queue_mu = sub_thread_pool_waiter_mu_;
  } else {
    waiter_queue = &queue_waiters_;
    waiter_queue_mu = &waiters_mu_;
  }
  {
    mutex_lock l(*waiter_queue_mu);
    if (waiter_queue->next != waiter_queue) {
      // Remove waiter from the LIFO queue.
      w = waiter_queue->next;

      CHECK(w->prev != w);  // Crash OK.
      CHECK(w->next != w);  // Crash OK.

      w->next->prev = w->prev;
      w->prev->next = w->next;

      // Use `w->next == w` to indicate that the waiter has been removed
      // from the queue.
      w->next = w;
      w->prev = w;
    }
  }
  if (w != nullptr) {
    // We call notify_one() without holding any lock, so we can miss
    // notifications. The wake-up logic is best effort, and a thread will wake
    // up within a short period of time if a notification is missed.
    w->cv.notify_one();
  }
  VLOG(3) << "Added " << (is_blocking ? "inter" : "intra") << " work from "
          << traceme_id_.load(std::memory_order_relaxed);
  return t;
}

Task ThreadWorkSource::PopBlockingTask() {
  return blocking_work_queue_.PopBack();
}

Task ThreadWorkSource::PopNonBlockingTask(int start_index,
                                          bool search_from_all_queue) {
  Task t;
  unsigned sharding_factor = NonBlockingWorkShardingFactor();
  for (unsigned j = 0; j < sharding_factor; ++j) {
    t = non_blocking_work_queues_[(start_index + j) % sharding_factor]
            ->queue.PopBack();
    if (t.f) {
      return t;
    }
    if (!search_from_all_queue) {
      break;
    }
  }
  return t;
}

void ThreadWorkSource::WaitForWork(int max_sleep_micros) {
  thread_local Waiter waiter;
  WaitOnWaiter(&waiter, &queue_waiters_, &waiters_mu_, max_sleep_micros);
}

int ThreadWorkSource::TaskQueueSize(bool is_blocking) {
  if (is_blocking) {
    return blocking_work_queue_.Size();
  } else {
    unsigned total_size = 0;
    for (int i = 0; i < non_blocking_work_sharding_factor_; ++i) {
      total_size += non_blocking_work_queues_[i]->queue.Size();
    }
    return total_size;
  }
}

int64 ThreadWorkSource::GetTracemeId() {
  return traceme_id_.load(std::memory_order_relaxed);
}

void ThreadWorkSource::SetTracemeId(int64 value) { traceme_id_ = value; }

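// Records the sub thread pool waiter queue (and its mutex) that this work
// source should notify, tagged with `version`; stale updates carrying an
// older version are ignored.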
void ThreadWorkSource::SetWaiter(uint64 version, Waiter* waiter, mutex* mutex) {
  {
    tf_shared_lock lock(run_handler_waiter_mu_);
    // Most requests will not change their sub thread pool on recomputation.
    // As an optimization, avoid taking the exclusive lock to reduce
    // contention.
    if (sub_thread_pool_waiter_ == waiter) {
      return;
    }
    // If the current version is newer, there is no need to update.
    if (version_ > version) {
      return;
    }
  }

  mutex_lock l(run_handler_waiter_mu_);
  sub_thread_pool_waiter_ = waiter;
  sub_thread_pool_waiter_mu_ = mutex;
  version_ = version;
}

int64 ThreadWorkSource::GetInflightTaskCount(bool is_blocking) {
  std::atomic<int64>* counter =
      is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
  return counter->load(std::memory_order_relaxed);
}

void ThreadWorkSource::IncrementInflightTaskCount(bool is_blocking) {
  std::atomic<int64>* counter =
      is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
  counter->fetch_add(1, std::memory_order_relaxed);
}

void ThreadWorkSource::DecrementInflightTaskCount(bool is_blocking) {
  std::atomic<int64>* counter =
      is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
  counter->fetch_sub(1, std::memory_order_relaxed);
}

unsigned ThreadWorkSource::NonBlockingWorkShardingFactor() {
  return non_blocking_work_sharding_factor_;
}

std::string ThreadWorkSource::ToString() {
  return strings::StrCat("traceme_id = ", GetTracemeId(),
                         ", inter queue size = ", TaskQueueSize(true),
                         ", inter inflight = ", GetInflightTaskCount(true),
                         ", intra queue size = ", TaskQueueSize(false),
                         ", intra inflight = ", GetInflightTaskCount(false));
}

RunHandlerThreadPool::RunHandlerThreadPool(
    int num_blocking_threads, int num_non_blocking_threads, Env* env,
    const ThreadOptions& thread_options, const string& name,
    Eigen::MaxSizeVector<mutex>* waiters_mu,
    Eigen::MaxSizeVector<Waiter>* queue_waiters)
    : num_threads_(num_blocking_threads + num_non_blocking_threads),
      num_blocking_threads_(num_blocking_threads),
      num_non_blocking_threads_(num_non_blocking_threads),
      thread_data_(num_threads_),
      env_(env, thread_options, name),
      name_(name),
      waiters_mu_(waiters_mu),
      queue_waiters_(queue_waiters),
      use_sub_thread_pool_(ParamFromEnvBoolWithDefault(
          "TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false)),
      num_threads_in_sub_thread_pool_(ParamFromEnvWithDefault(
          "TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL",
          std::vector<int>({num_blocking_threads / 2,
                            num_blocking_threads - num_blocking_threads / 2}))),
      sub_thread_pool_start_request_percentage_(ParamFromEnvWithDefault(
          "TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE",
          std::vector<double>({0, 0.4}))),
      sub_thread_pool_end_request_percentage_(ParamFromEnvWithDefault(
          "TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE",
          std::vector<double>({0.4, 1}))) {
  thread_data_.resize(num_threads_);
  VLOG(1) << "Creating RunHandlerThreadPool " << name << " with "
          << num_blocking_threads_ << " blocking threads and "
          << num_non_blocking_threads_ << " non-blocking threads.";
}

RunHandlerThreadPool::~RunHandlerThreadPool() {
  VLOG(1) << "Exiting RunHandlerThreadPool " << name_;

  cancelled_ = true;
  for (size_t i = 0; i < thread_data_.size(); ++i) {
    {
      mutex_lock l(thread_data_[i].mu);
      thread_data_[i].sources_not_empty.notify_all();
    }
    thread_data_[i].thread.reset();
  }
}

void RunHandlerThreadPool::Start() {
  cancelled_ = false;
  int num_blocking_threads = num_blocking_threads_;
  for (int i = 0; i < num_threads_; i++) {
    int sub_thread_pool_id = num_threads_in_sub_thread_pool_.size() - 1;
    for (int j = 0; j < num_threads_in_sub_thread_pool_.size(); ++j) {
      if (i < num_threads_in_sub_thread_pool_[j]) {
        sub_thread_pool_id = j;
        break;
      }
    }
    thread_data_[i].sub_thread_pool_id = sub_thread_pool_id;
    thread_data_[i].thread.reset(
        env_.CreateThread([this, i, num_blocking_threads]() {
          WorkerLoop(i, i < num_blocking_threads);
        }));
  }
}

void RunHandlerThreadPool::StartOneThreadForTesting() {
  cancelled_ = false;
  thread_data_[0].sub_thread_pool_id = 0;
  thread_data_[0].thread.reset(
      env_.CreateThread([this]() { WorkerLoop(0, true); }));
}

void RunHandlerThreadPool::AddWorkToQueue(ThreadWorkSource* tws,
                                          bool is_blocking,
                                          std::function<void()> fn) {
  Task t = env_.CreateTask(std::move(fn));
  t = tws->EnqueueTask(std::move(t), is_blocking);
  if (t.f) {
    VLOG(3) << "Running " << (is_blocking ? "inter" : "intra") << " work for "
            << tws->GetTracemeId();
    env_.ExecuteTask(t);
  }
}

// TODO(donglin): Change the task stealing order to round-robin, so that if an
// attempt to steal a task from request i fails, the next attempt steals from
// the next request by arrival time. This approach may provide better
// performance due to less lock contention. The drawback is that the profiler
// output will be a bit harder to read.
void RunHandlerThreadPool::SetThreadWorkSources(
    int tid, int start_request_idx, uint64 version,
    const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources) {
  mutex_lock l(thread_data_[tid].mu);
  if (version > thread_data_[tid].new_version) {
    thread_data_[tid].new_version = version;
  } else {
    // A newer version has already been set. No need to update.
    return;
  }
  thread_data_[tid].new_thread_work_sources->resize(0);
  if (use_sub_thread_pool_) {
    for (int i = 0; i < thread_work_sources.size(); ++i) {
      thread_data_[tid].new_thread_work_sources->emplace_back(
          thread_work_sources[i]);
    }
  } else {
    thread_data_[tid].new_thread_work_sources->emplace_back(
        thread_work_sources[start_request_idx]);
    // The number of shards for the queue. Threads in each shard will
    // prioritize different thread_work_sources. Increasing the number of
    // shards can decrease the contention in the queue. For example, when
    // num_shards == 1, thread_work_sources are ordered as start_request_idx,
    // 0, 1, 2, 3, 4 ... for all threads. When num_shards == 2,
    // thread_work_sources are ordered as start_request_idx, 0, 2, 4 ... 1, 3,
    // 5 ... for half of the threads and start_request_idx, 1, 3, 5 ... 0, 2,
    // 4 ... for the other half of the threads.
    static const int num_shards =
        ParamFromEnvWithDefault("TF_RUN_HANDLER_QUEUE_SHARDS", 1);
    int token = tid % num_shards;
    for (int i = 0; i < num_shards; ++i) {
      for (int j = token; j < thread_work_sources.size(); j += num_shards) {
        if (j != start_request_idx) {
          thread_data_[tid].new_thread_work_sources->emplace_back(
              thread_work_sources[j]);
        }
      }
      token = (token + 1) % num_shards;
    }
    thread_data_[tid].sources_not_empty.notify_all();
  }
}

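// Returns the thread-local PerThread state for the calling thread.
// WorkerLoop() fills in `pool` and `thread_id` for threads owned by this pool.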
RunHandlerThreadPool::PerThread* RunHandlerThreadPool::GetPerThread() {
  thread_local RunHandlerThreadPool::PerThread per_thread_;
  RunHandlerThreadPool::PerThread* pt = &per_thread_;
  return pt;
}

int RunHandlerThreadPool::CurrentThreadId() const {
  const PerThread* pt = const_cast<RunHandlerThreadPool*>(this)->GetPerThread();
  if (pt->pool == this) {
    return pt->thread_id;
  } else {
    return -1;
  }
}

int RunHandlerThreadPool::NumThreads() const { return num_threads_; }

int RunHandlerThreadPool::NumBlockingThreads() const {
  return num_blocking_threads_;
}

int RunHandlerThreadPool::NumNonBlockingThreads() const {
  return num_non_blocking_threads_;
}

RunHandlerThreadPool::ThreadData::ThreadData()
    : new_version(0),
      current_index(0),
      new_thread_work_sources(
          new Eigen::MaxSizeVector<ThreadWorkSource*>(static_cast<int32>(
              ParamFromEnvWithDefault("TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
                                      kMaxConcurrentHandlers)))),
      current_version(0),
      current_thread_work_sources(
          new Eigen::MaxSizeVector<ThreadWorkSource*>(static_cast<int32>(
              ParamFromEnvWithDefault("TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
                                      kMaxConcurrentHandlers)))) {}

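// Scans `thread_work_sources` in [searching_range_start, searching_range_end),
// resuming from the thread's saved round-robin position. Blocking tasks are
// considered first when `may_steal_blocking_work` is set and the source's
// blocking inflight count is below `max_blocking_inflight`; otherwise a
// non-blocking task is popped. The first task found is returned, with its
// source and queue type reported through `tws` and `task_from_blocking_queue`.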
Task RunHandlerThreadPool::FindTask(
    int searching_range_start, int searching_range_end, int thread_id,
    int sub_thread_pool_id, int max_blocking_inflight,
    bool may_steal_blocking_work,
    const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources,
    bool* task_from_blocking_queue, ThreadWorkSource** tws) {
  Task t;
  int current_index = thread_data_[thread_id].current_index;
  *task_from_blocking_queue = false;

  for (int i = 0; i < searching_range_end - searching_range_start; ++i) {
    if (current_index >= searching_range_end ||
        current_index < searching_range_start) {
      current_index = searching_range_start;
    }
    *tws = thread_work_sources[current_index];
    ++current_index;

    // For blocking threads, search for blocking tasks first.
    if (may_steal_blocking_work &&
        (*tws)->GetInflightTaskCount(true) < max_blocking_inflight) {
      t = (*tws)->PopBlockingTask();
      if (t.f) {
        *task_from_blocking_queue = true;
        break;
      }
    }

    // Search for non-blocking tasks.
    t = (*tws)->PopNonBlockingTask(thread_id, true);
    if (t.f) {
      break;
    }
  }
  thread_data_[thread_id].current_index = current_index;
  return t;
}

// Main worker thread loop.
void RunHandlerThreadPool::WorkerLoop(int thread_id,
                                      bool may_steal_blocking_work) {
  PerThread* pt = GetPerThread();
  pt->pool = this;
  pt->thread_id = thread_id;
  static constexpr int32 kMaxBlockingInflight = 10;

  while (!cancelled_) {
    Task t;
    ThreadWorkSource* tws = nullptr;
    bool task_from_blocking_queue = true;
    int sub_thread_pool_id;
    // Get the current thread work sources.
    {
      mutex_lock l(thread_data_[thread_id].mu);
      if (thread_data_[thread_id].current_version <
          thread_data_[thread_id].new_version) {
        thread_data_[thread_id].current_version =
            thread_data_[thread_id].new_version;
        thread_data_[thread_id].current_thread_work_sources.swap(
            thread_data_[thread_id].new_thread_work_sources);
      }
    }
    Eigen::MaxSizeVector<ThreadWorkSource*>* thread_work_sources =
        thread_data_[thread_id].current_thread_work_sources.get();
    if (use_sub_thread_pool_) {
      sub_thread_pool_id = thread_data_[thread_id].sub_thread_pool_id;
      int active_requests = thread_work_sources->size();
      if (may_steal_blocking_work) {
        // Each thread first looks for tasks from requests that belong to its
        // own sub thread pool.
        int search_range_start =
            active_requests *
            sub_thread_pool_start_request_percentage_[sub_thread_pool_id];
        int search_range_end =
            active_requests *
            sub_thread_pool_end_request_percentage_[sub_thread_pool_id];
        search_range_end =
            std::min(active_requests,
                     std::max(search_range_end, search_range_start + 1));

        t = FindTask(search_range_start, search_range_end, thread_id,
                     sub_thread_pool_id, kMaxBlockingInflight,
                     /*may_steal_blocking_work=*/true, *thread_work_sources,
                     &task_from_blocking_queue, &tws);
        if (!t.f) {
          // Search across all requests if the thread cannot find tasks from
          // requests that belong to its own sub thread pool.
          t = FindTask(0, active_requests, thread_id, sub_thread_pool_id,
                       kMaxBlockingInflight,
                       /*may_steal_blocking_work=*/true, *thread_work_sources,
                       &task_from_blocking_queue, &tws);
        }
      } else {
        // Non-blocking threads always search across all pending requests.
        t = FindTask(0, active_requests, thread_id, sub_thread_pool_id,
                     kMaxBlockingInflight,
                     /*may_steal_blocking_work=*/false, *thread_work_sources,
                     &task_from_blocking_queue, &tws);
      }
    } else {
      // TODO(chaox): Refactor the following code to share the logic with
      // FindTask.
      for (int i = 0; i < thread_work_sources->size(); ++i) {
        tws = (*thread_work_sources)[i];
        // We want a smallish number of inter threads since otherwise there
        // will be contention in PropagateOutputs. This is a best-effort
        // policy.
        if (may_steal_blocking_work &&
            tws->GetInflightTaskCount(true) < kMaxBlockingInflight) {
          t = tws->PopBlockingTask();
          if (t.f) {
            break;
          }
        }
        if (i == 0) {
          // Always look for any work from the "primary" work source.
          // This way when we wake up a thread for a new closure we are
          // guaranteed it can be worked on.
          t = tws->PopNonBlockingTask(thread_id, true);
          if (t.f) {
            task_from_blocking_queue = false;
            break;
          }
        } else {
          t = tws->PopNonBlockingTask(thread_id, false);
          if (t.f) {
            task_from_blocking_queue = false;
            break;
          }
        }
      }
    }
    if (t.f) {
      profiler::TraceMe activity(
          [=] {
            return strings::StrCat(task_from_blocking_queue ? "inter" : "intra",
                                   " #id = ", tws->GetTracemeId(), " ",
                                   thread_id, "#");
          },
          profiler::TraceMeLevel::kInfo);
      VLOG(2) << "Running " << (task_from_blocking_queue ? "inter" : "intra")
              << " work from " << tws->GetTracemeId();
      tws->IncrementInflightTaskCount(task_from_blocking_queue);
      env_.ExecuteTask(t);
      tws->DecrementInflightTaskCount(task_from_blocking_queue);
    } else {
      profiler::TraceMe activity(
          [=] {
            return strings::StrCat("Sleeping#thread_id=", thread_id, "#");
          },
          profiler::TraceMeLevel::kInfo);
      if (VLOG_IS_ON(4)) {
        for (int i = 0; i < thread_work_sources->size(); ++i) {
          VLOG(4) << "source id " << i << " "
                  << (*thread_work_sources)[i]->ToString();
        }
      }
      if (use_sub_thread_pool_) {
        WaitForWorkInSubThreadPool(may_steal_blocking_work, sub_thread_pool_id);
      } else {
        WaitForWork(may_steal_blocking_work, thread_id, kMaxBlockingInflight);
      }
    }
  }
}

void RunHandlerThreadPool::WaitForWorkInSubThreadPool(bool is_blocking,
                                                      int sub_thread_pool_id) {
  const int kMaxSleepMicros = 250;

  // Non-blocking threads just sleep.
  if (!is_blocking) {
    Env::Default()->SleepForMicroseconds(kMaxSleepMicros);
    return;
  }

  thread_local Waiter waiter;
  WaitOnWaiter(&waiter, &(*queue_waiters_)[sub_thread_pool_id],
               &(*waiters_mu_)[sub_thread_pool_id], kMaxSleepMicros);
}

void RunHandlerThreadPool::WaitForWork(bool is_blocking, int thread_id,
                                       int32 max_blocking_inflight) {
  const int kMaxSleepMicros = 250;

  // Non-blocking threads just sleep.
  if (!is_blocking) {
    Env::Default()->SleepForMicroseconds(kMaxSleepMicros);
    return;
  }

  ThreadWorkSource* tws = nullptr;
  {
    mutex_lock l(thread_data_[thread_id].mu);
    if (thread_data_[thread_id].new_version >
        thread_data_[thread_id].current_version) {
      thread_data_[thread_id].current_thread_work_sources.swap(
          thread_data_[thread_id].new_thread_work_sources);
      thread_data_[thread_id].current_version =
          thread_data_[thread_id].new_version;
    }
    Eigen::MaxSizeVector<ThreadWorkSource*>* thread_work_sources =
        thread_data_[thread_id].current_thread_work_sources.get();
    while (!cancelled_ && thread_work_sources->empty()) {
      // Wait until there is a new request.
      thread_data_[thread_id].sources_not_empty.wait(l);
      if (thread_data_[thread_id].new_version >
          thread_data_[thread_id].current_version) {
        thread_data_[thread_id].current_thread_work_sources.swap(
            thread_data_[thread_id].new_thread_work_sources);
        thread_data_[thread_id].current_version =
            thread_data_[thread_id].new_version;
        thread_work_sources =
            thread_data_[thread_id].current_thread_work_sources.get();
      }
    }
    if (cancelled_) {
      return;
    }
    tws = (*thread_work_sources)[0];
  }

  if (tws->GetInflightTaskCount(true) >= max_blocking_inflight) {
    // Sleep to reduce contention in PropagateOutputs.
    Env::Default()->SleepForMicroseconds(kMaxSleepMicros);
  }
  tws->WaitForWork(kMaxSleepMicros);
}

}  // namespace internal

// Contains the concrete implementation of the RunHandler.
// The externally visible RunHandler class simply forwards the work to this
// one.
class RunHandler::Impl {
 public:
  explicit Impl(RunHandlerPool::Impl* pool_impl);

  ~Impl() {}

  thread::ThreadPoolInterface* thread_pool_interface() {
    return thread_pool_interface_.get();
  }

  // Stores the time (in microseconds since the Unix epoch) at which the
  // handler was requested via RunHandlerPool::Get().
  uint64 start_time_us() const { return start_time_us_; }
  int64 step_id() const { return step_id_; }
  void ScheduleInterOpClosure(std::function<void()> fn);
  void ScheduleIntraOpClosure(std::function<void()> fn);

  void Reset(int64 step_id,
             const RunOptions::Experimental::RunHandlerPoolOptions& options);

  RunHandlerPool::Impl* pool_impl() { return pool_impl_; }

  internal::ThreadWorkSource* tws() { return &tws_; }

  int64 priority() { return options_.priority(); }

 private:
  class ThreadPoolInterfaceWrapper : public thread::ThreadPoolInterface {
   public:
    explicit ThreadPoolInterfaceWrapper(Impl* run_handler_impl)
        : run_handler_impl_(run_handler_impl) {}
    ~ThreadPoolInterfaceWrapper() override {}
    void Schedule(std::function<void()> fn) override;
    int NumThreads() const override;
    int CurrentThreadId() const override;

   private:
    RunHandler::Impl* run_handler_impl_ = nullptr;
  };

  RunHandlerPool::Impl* pool_impl_;  // NOT OWNED.
  uint64 start_time_us_;
  int64 step_id_;
  std::unique_ptr<thread::ThreadPoolInterface> thread_pool_interface_;
  internal::ThreadWorkSource tws_;
  RunOptions::Experimental::RunHandlerPoolOptions options_;
};

// Contains shared state across all run handlers present in the pool. Also
// responsible for pool management decisions.
// This class is thread-safe.
class RunHandlerPool::Impl {
 public:
  explicit Impl(int num_inter_op_threads, int num_intra_op_threads)
      : max_handlers_(static_cast<int32>(ParamFromEnvWithDefault(
            "TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS", kMaxConcurrentHandlers))),
        waiters_mu_(
            ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2)),
        queue_waiters_(
            ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2)),
        run_handler_thread_pool_(new internal::RunHandlerThreadPool(
            num_inter_op_threads, num_intra_op_threads, Env::Default(),
            ThreadOptions(), "tf_run_handler_pool", &waiters_mu_,
            &queue_waiters_)),
        iterations_(0),
        version_(0),
        sub_thread_pool_end_request_percentage_(ParamFromEnvWithDefault(
            "TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE",
            std::vector<double>({1}))) {
    VLOG(1) << "Creating a RunHandlerPool with max handlers: " << max_handlers_;
    free_handlers_.reserve(max_handlers_);
    handlers_.reserve(max_handlers_);
    for (int i = 0; i < max_handlers_; ++i) {
      handlers_.emplace_back(new RunHandler::Impl(this));
      free_handlers_.push_back(handlers_.back().get());
    }
    queue_waiters_.resize(
        ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2));
    waiters_mu_.resize(
        ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2));
    for (auto& queue_waiter : queue_waiters_) {
      queue_waiter.next = &queue_waiter;
      queue_waiter.prev = &queue_waiter;
    }
    run_handler_thread_pool_->Start();
  }

  ~Impl() {
    // Sanity check that all handlers have been returned to the pool before
    // destruction.
    DCHECK_EQ(handlers_.size(), max_handlers_);
    DCHECK_EQ(free_handlers_.size(), handlers_.size());
    DCHECK_EQ(sorted_active_handlers_.size(), 0);
    // Stop the threads in run_handler_thread_pool_ before freeing other
    // pointers. Otherwise a thread may try to access a pointer after the
    // pointer has been freed.
    run_handler_thread_pool_.reset();
  }

  internal::RunHandlerThreadPool* run_handler_thread_pool() {
    return run_handler_thread_pool_.get();
  }

  bool has_free_handler() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return !free_handlers_.empty();
  }

  std::unique_ptr<RunHandler> Get(
      int64 step_id, int64 timeout_in_ms,
      const RunOptions::Experimental::RunHandlerPoolOptions& options)
      TF_LOCKS_EXCLUDED(mu_) {
    thread_local std::unique_ptr<
        Eigen::MaxSizeVector<internal::ThreadWorkSource*>>
        thread_work_sources =
            std::unique_ptr<Eigen::MaxSizeVector<internal::ThreadWorkSource*>>(
                new Eigen::MaxSizeVector<internal::ThreadWorkSource*>(
                    static_cast<int32>(ParamFromEnvWithDefault(
                        "TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
                        kMaxConcurrentHandlers))));
    uint64 version;
    int num_active_requests;
    RunHandler::Impl* handler_impl;
    {
      mutex_lock l(mu_);
      if (!has_free_handler()) {
        profiler::TraceMe activity(
            [&] {
              return strings::StrCat("WaitingForHandler#step_id=", step_id,
                                     "#");
            },
            profiler::TraceMeLevel::kInfo);
        if (timeout_in_ms == 0) {
          mu_.Await(Condition(this, &Impl::has_free_handler));
        } else if (!mu_.AwaitWithDeadline(
                       Condition(this, &Impl::has_free_handler),
                       EnvTime::NowNanos() + timeout_in_ms * 1000 * 1000)) {
          return nullptr;
        }
      }
      // Remove the last entry from free_handlers_ and insert it into
      // sorted_active_handlers_ according to its priority.
      handler_impl = free_handlers_.back();
      handler_impl->Reset(step_id, options);
      free_handlers_.pop_back();

      num_active_requests = sorted_active_handlers_.size() + 1;
      thread_work_sources->resize(num_active_requests);
      int priority = options.priority();
      auto it = sorted_active_handlers_.cbegin();
      bool new_handler_inserted = false;
      for (int i = 0; i < num_active_requests; ++i) {
        if (!new_handler_inserted && (it == sorted_active_handlers_.cend() ||
                                      priority > (*it)->priority())) {
          sorted_active_handlers_.insert(it, handler_impl);
          new_handler_inserted = true;
          // Point to the newly added handler.
          --it;
        }
        (*thread_work_sources)[i] = (*it)->tws();
        ++it;
      }
      version = ++version_;
    }
    RecomputePoolStats(num_active_requests, version, *thread_work_sources);
    return WrapUnique<RunHandler>(new RunHandler(handler_impl));
  }

  void ReleaseHandler(RunHandler::Impl* handler) TF_LOCKS_EXCLUDED(mu_) {
    mutex_lock l(mu_);
    DCHECK_GT(sorted_active_handlers_.size(), 0);

    CHECK_EQ(handler->tws()->TaskQueueSize(true), 0);   // Crash OK.
    CHECK_EQ(handler->tws()->TaskQueueSize(false), 0);  // Crash OK.

    uint64 now = tensorflow::EnvTime::NowMicros();
    double elapsed = (now - handler->start_time_us()) / 1000.0;
    time_hist_.Add(elapsed);

    // Erase from and update sorted_active_handlers_. Add it to the end of
    // free_handlers_.
    auto iter = std::find(sorted_active_handlers_.begin(),
                          sorted_active_handlers_.end(), handler);
    DCHECK(iter != sorted_active_handlers_.end())
        << "Unexpected handler: " << handler
        << " is being requested for release";

    // Remove this handler from this list and add it to the list of free
    // handlers.
    sorted_active_handlers_.erase(iter);
    free_handlers_.push_back(handler);
    DCHECK_LE(free_handlers_.size(), max_handlers_);
    LogInfo();

    // We do not recompute pool stats during release. The side effect is that
    // there may be empty thread work sources in the queue. However, any new
    // requests will trigger recomputation.
  }

  std::vector<int64> GetActiveHandlerPrioritiesForTesting()
      TF_LOCKS_EXCLUDED(mu_) {
    mutex_lock l(mu_);
    std::vector<int64> ret;
    for (const auto& handler_impl : sorted_active_handlers_) {
      ret.push_back(handler_impl->priority());
    }
    return ret;
  }

 private:
  void RecomputePoolStats(
      int num_active_requests, uint64 version,
      const Eigen::MaxSizeVector<internal::ThreadWorkSource*>&
          thread_work_sources);

  void LogInfo() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Maximum number of handlers pre-created during pool construction time. The
  // number has been chosen expecting each handler might at least want one
  // inter-op thread for execution (during compute-intensive workloads like
  // inference).
  const int max_handlers_;

  Eigen::MaxSizeVector<mutex> waiters_mu_;
  Eigen::MaxSizeVector<internal::Waiter> queue_waiters_;

  std::unique_ptr<internal::RunHandlerThreadPool> run_handler_thread_pool_;
  // Thread-compatible part, used only under the lock in RunHandlerPool.
  // Handlers are sorted by start time.
  // TODO(azaks): sort by the remaining latency budget.
  // TODO(chaox): Consider another data structure for maintaining the sorted
  // active handlers if the searching overhead (currently O(n)) becomes the
  // bottleneck.
  std::list<RunHandler::Impl*> sorted_active_handlers_ TF_GUARDED_BY(mu_);
  std::vector<RunHandler::Impl*> free_handlers_ TF_GUARDED_BY(mu_);
  std::vector<std::unique_ptr<RunHandler::Impl>> handlers_ TF_GUARDED_BY(mu_);

  // Histogram of elapsed runtime of every handler (in ms).
  histogram::Histogram time_hist_ TF_GUARDED_BY(mu_);

  int64 iterations_ TF_GUARDED_BY(mu_);
  mutex mu_;
  int64 version_ TF_GUARDED_BY(mu_);
  const std::vector<double> sub_thread_pool_end_request_percentage_;
};

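// Assigns each active request's work source to a sub thread pool waiter based
// on the request's position in the sorted handler list, then picks a start
// request index for every blocking and non-blocking thread and pushes the
// updated work source list to the thread pool.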
void RunHandlerPool::Impl::RecomputePoolStats(
    int num_active_requests, uint64 version,
    const Eigen::MaxSizeVector<internal::ThreadWorkSource*>&
        thread_work_sources) {
  if (num_active_requests == 0) return;

  int sub_thread_pool_id = 0;
  for (int i = 0; i < num_active_requests; ++i) {
    while (
        sub_thread_pool_id <
            sub_thread_pool_end_request_percentage_.size() - 1 &&
        i >= num_active_requests *
                 sub_thread_pool_end_request_percentage_[sub_thread_pool_id]) {
      sub_thread_pool_id++;
    }
    thread_work_sources[i]->SetWaiter(version,
                                      &queue_waiters_[sub_thread_pool_id],
                                      &waiters_mu_[sub_thread_pool_id]);
  }

  int num_threads = run_handler_thread_pool()->NumThreads();
  int num_blocking_threads = run_handler_thread_pool()->NumBlockingThreads();
  int num_non_blocking_threads = num_threads - num_blocking_threads;

  std::vector<int> request_idx_list = ChooseRequestsWithExponentialDistribution(
      num_active_requests, num_blocking_threads);
  for (int i = 0; i < num_blocking_threads; ++i) {
    VLOG(2) << "Set work for tid=" << i
            << " with start_request_idx=" << request_idx_list[i];
    run_handler_thread_pool()->SetThreadWorkSources(
        i, request_idx_list[i], version, thread_work_sources);
  }

  request_idx_list = ChooseRequestsWithExponentialDistribution(
      num_active_requests, num_non_blocking_threads);
  for (int i = 0; i < num_non_blocking_threads; ++i) {
    VLOG(2) << "Set work for tid=" << (i + num_blocking_threads)
            << " with start_request_idx=" << request_idx_list[i];
    run_handler_thread_pool()->SetThreadWorkSources(
        i + num_blocking_threads, request_idx_list[i], version,
        thread_work_sources);
  }
}

void RunHandlerPool::Impl::LogInfo() {
  if (iterations_++ % 50000 == 10 && VLOG_IS_ON(1)) {
    int num_active_requests = sorted_active_handlers_.size();
    VLOG(1) << "Printing time histogram: " << time_hist_.ToString();
    VLOG(1) << "Active session runs: " << num_active_requests;
    uint64 now = tensorflow::Env::Default()->NowMicros();
    string times_str = "";
    string ids_str = "";
    auto it = sorted_active_handlers_.cbegin();
    for (int i = 0; i < num_active_requests; ++i) {
      if (i > 0) {
        times_str += " ";
        ids_str += " ";
      }

      times_str +=
          strings::StrCat((now - (*it)->start_time_us()) / 1000.0, " ms.");
      ids_str += strings::StrCat((*it)->tws()->GetTracemeId());
      ++it;
    }
    VLOG(1) << "Elapsed times are: " << times_str;
    VLOG(1) << "Step ids are: " << ids_str;
  }
}

// It is important to return a value such that:
// CurrentThreadId() is in [0, NumThreads).
int RunHandler::Impl::ThreadPoolInterfaceWrapper::NumThreads() const {
  return run_handler_impl_->pool_impl_->run_handler_thread_pool()->NumThreads();
}

int RunHandler::Impl::ThreadPoolInterfaceWrapper::CurrentThreadId() const {
  return run_handler_impl_->pool_impl_->run_handler_thread_pool()
      ->CurrentThreadId();
}

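// The wrapper is handed out as the handler's intra-op thread pool (see
// AsIntraThreadPoolInterface() below), so scheduled closures are routed to the
// intra-op queue.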
void RunHandler::Impl::ThreadPoolInterfaceWrapper::Schedule(
    std::function<void()> fn) {
  return run_handler_impl_->ScheduleIntraOpClosure(std::move(fn));
}

RunHandler::Impl::Impl(RunHandlerPool::Impl* pool_impl)
    : pool_impl_(pool_impl) {
  thread_pool_interface_.reset(new ThreadPoolInterfaceWrapper(this));
  Reset(0, RunOptions::Experimental::RunHandlerPoolOptions());
}

void RunHandler::Impl::ScheduleInterOpClosure(std::function<void()> fn) {
  VLOG(3) << "Scheduling inter work for " << tws()->GetTracemeId();
  pool_impl_->run_handler_thread_pool()->AddWorkToQueue(tws(), true,
                                                        std::move(fn));
}

void RunHandler::Impl::ScheduleIntraOpClosure(std::function<void()> fn) {
  VLOG(3) << "Scheduling intra work for " << tws()->GetTracemeId();
  pool_impl_->run_handler_thread_pool()->AddWorkToQueue(tws(), false,
                                                        std::move(fn));
}

void RunHandler::Impl::Reset(
    int64 step_id,
    const RunOptions::Experimental::RunHandlerPoolOptions& options) {
  start_time_us_ = tensorflow::Env::Default()->NowMicros();
  step_id_ = step_id;
  options_ = options;
  tws_.SetTracemeId(step_id);
}

RunHandlerPool::RunHandlerPool(int num_inter_op_threads)
    : impl_(new Impl(num_inter_op_threads, 0)) {}

RunHandlerPool::RunHandlerPool(int num_inter_op_threads,
                               int num_intra_op_threads)
    : impl_(new Impl(num_inter_op_threads, num_intra_op_threads)) {}

RunHandlerPool::~RunHandlerPool() {}

std::unique_ptr<RunHandler> RunHandlerPool::Get(
    int64 step_id, int64 timeout_in_ms,
    const RunOptions::Experimental::RunHandlerPoolOptions& options) {
  return impl_->Get(step_id, timeout_in_ms, options);
}

std::vector<int64> RunHandlerPool::GetActiveHandlerPrioritiesForTesting()
    const {
  return impl_->GetActiveHandlerPrioritiesForTesting();
}

RunHandler::RunHandler(Impl* impl) : impl_(impl) {}

void RunHandler::ScheduleInterOpClosure(std::function<void()> fn) {
  impl_->ScheduleInterOpClosure(std::move(fn));
}

thread::ThreadPoolInterface* RunHandler::AsIntraThreadPoolInterface() {
  return impl_->thread_pool_interface();
}

RunHandler::~RunHandler() { impl_->pool_impl()->ReleaseHandler(impl_); }

}  // namespace tensorflow