1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #define EIGEN_USE_THREADS
17 
18 #include "tensorflow/core/framework/run_handler.h"
19 
20 #include <algorithm>
21 #include <cmath>
22 #include <list>
23 #include <memory>
24 
25 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
26 #include "tensorflow/core/framework/run_handler_util.h"
27 #include "tensorflow/core/lib/core/threadpool_interface.h"
28 #include "tensorflow/core/lib/strings/strcat.h"
29 #include "tensorflow/core/platform/context.h"
30 #include "tensorflow/core/platform/denormal.h"
31 #include "tensorflow/core/platform/mutex.h"
32 #include "tensorflow/core/platform/numa.h"
33 #include "tensorflow/core/platform/setround.h"
34 #include "tensorflow/core/platform/tracing.h"
35 #include "tensorflow/core/profiler/lib/traceme.h"
36 #include "tensorflow/core/util/ptr_util.h"
37 
38 namespace tensorflow {
39 namespace {
40 // LINT.IfChange
41 static constexpr int32 kMaxConcurrentHandlers = 128;
42 // LINT.ThenChange(//tensorflow/core/framework/run_handler_test.cc)
43 
44 typedef typename internal::RunHandlerEnvironment::Task Task;
45 typedef Eigen::RunQueue<Task, 1024> Queue;
46 
47 }  // namespace
48 
49 namespace internal {
50 RunHandlerEnvironment::RunHandlerEnvironment(
51     Env* env, const ThreadOptions& thread_options, const string& name)
52     : env_(env), thread_options_(thread_options), name_(name) {}
53 
54 RunHandlerEnvironment::EnvThread* RunHandlerEnvironment::CreateThread(
55     std::function<void()> f) {
56   return env_->StartThread(thread_options_, name_, [=]() {
57     // Set the processor flag to flush denormals to zero.
58     port::ScopedFlushDenormal flush;
59     // Set the processor rounding mode to ROUND TO NEAREST.
60     port::ScopedSetRound round(FE_TONEAREST);
61     if (thread_options_.numa_node != port::kNUMANoAffinity) {
62       port::NUMASetThreadNodeAffinity(thread_options_.numa_node);
63     }
64     f();
65   });
66 }
67 
68 RunHandlerEnvironment::Task RunHandlerEnvironment::CreateTask(
69     std::function<void()> f) {
70   uint64 id = 0;
71   if (tracing::EventCollector::IsEnabled()) {
72     id = tracing::GetUniqueArg();
73     tracing::RecordEvent(tracing::EventCategory::kScheduleClosure, id);
74   }
75   return Task{
76       std::unique_ptr<TaskImpl>(new TaskImpl{
77           std::move(f),
78           Context(ContextKind::kThread),
79           id,
80       }),
81   };
82 }
83 
84 void RunHandlerEnvironment::ExecuteTask(const Task& t) {
85   WithContext wc(t.f->context);
86   tracing::ScopedRegion region(tracing::EventCategory::kRunClosure,
87                                t.f->trace_id);
88   t.f->f();
89 }
90 
91 void WaitOnWaiter(Waiter* waiter, Waiter* queue_head, mutex* mutex,
92                   int max_sleep_micros) {
93   {
94     mutex_lock l(*mutex);
95     CHECK_EQ(waiter->next, waiter);  // Crash OK.
96     CHECK_EQ(waiter->prev, waiter);  // Crash OK.
97 
98     // Add waiter to the LIFO queue
99     waiter->prev = queue_head;
100     waiter->next = queue_head->next;
101     waiter->next->prev = waiter;
102     waiter->prev->next = waiter;
103   }
104   {
105     mutex_lock l(waiter->mu);
106     // Wait on the condition variable
107     waiter->cv.wait_for(l, std::chrono::microseconds(max_sleep_micros));
108   }
109 
110   mutex_lock l(*mutex);
111   // Remove waiter from the LIFO queue. Note even when a waiter wakes up due
112   // to a notification we cannot conclude the waiter is not in the queue.
113   // This is due to the fact that a thread preempted right before notifying
114   // may resume after a waiter got re-added.
115   if (waiter->next != waiter) {
116     CHECK(waiter->prev != waiter);  // Crash OK.
117     waiter->next->prev = waiter->prev;
118     waiter->prev->next = waiter->next;
119     waiter->next = waiter;
120     waiter->prev = waiter;
121   } else {
122     CHECK_EQ(waiter->prev, waiter);  // Crash OK.
123   }
124 }
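// A note on the waiter list: WaitOnWaiter relies on an intrusive, circular,
// doubly-linked LIFO list in which `queue_head` acts as a sentinel, and a
// waiter that is not queued points at itself (next == prev == this); that is
// the invariant the CHECKs above enforce. A minimal illustrative sketch of the
// same linking discipline (not part of this file's API):
//
//   void PushFrontLIFO(Waiter* head, Waiter* w) {
//     w->prev = head;          // insert right after the sentinel
//     w->next = head->next;
//     w->next->prev = w;
//     w->prev->next = w;
//   }
//   bool IsQueued(const Waiter* w) { return w->next != w; }
//
// Because both insertion and removal happen under the same `mutex`, a waiter
// can be unlinked either by itself (after the timed wait) or by a notifier.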
125 
126 ThreadWorkSource::ThreadWorkSource()
127     : non_blocking_work_sharding_factor_(
128           static_cast<int32>(ParamFromEnvWithDefault(
129               "TF_RUN_HANDLER_NUM_OF_NON_BLOCKING_QUEUES", 1))),
130       non_blocking_work_queues_(non_blocking_work_sharding_factor_),
131       blocking_inflight_(0),
132       non_blocking_inflight_(0),
133       traceme_id_(0),
134       version_(0),
135       sub_thread_pool_waiter_(nullptr) {
136   queue_waiters_.next = &queue_waiters_;
137   queue_waiters_.prev = &queue_waiters_;
138   for (int i = 0; i < NonBlockingWorkShardingFactor(); ++i) {
139     non_blocking_work_queues_.emplace_back(new NonBlockingQueue());
140   }
141 }
142 
143 ThreadWorkSource::~ThreadWorkSource() {
144   for (int i = 0; i < non_blocking_work_queues_.size(); ++i) {
145     delete non_blocking_work_queues_[i];
146   }
147 }
148 
149 Task ThreadWorkSource::EnqueueTask(Task t, bool is_blocking) {
150   mutex* mu = nullptr;
151   Queue* task_queue = nullptr;
152   thread_local int64 closure_counter = 0;
153 
154   if (!is_blocking) {
155     int queue_index = ++closure_counter % non_blocking_work_sharding_factor_;
156     task_queue = &(non_blocking_work_queues_[queue_index]->queue);
157     mu = &non_blocking_work_queues_[queue_index]->queue_op_mu;
158   } else {
159     task_queue = &blocking_work_queue_;
160     mu = &blocking_queue_op_mu_;
161   }
162 
163   {
164     mutex_lock l(*mu);
165     // For a given queue, only one thread can call PushFront.
166     t = task_queue->PushFront(std::move(t));
167   }
168 
169   Waiter* w = nullptr;
170   static const bool use_sub_thread_pool =
171       ParamFromEnvBoolWithDefault("TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false);
172 
173   Waiter* waiter_queue;
174   mutex* waiter_queue_mu;
175   if (use_sub_thread_pool) {
176     // When we use multiple sub thread pools, free threads wait on sub
177     // thread pool waiting queues. Wake up threads from sub thread waiting
178     // queues.
179     // The waiting queues are defined at RunHandlerPool.
180     // Get the waiter_queue and corresponding mutex. Note, the thread work
181     // source may change afterwards if a new request comes or an old request
182     // finishes.
183     tf_shared_lock lock(run_handler_waiter_mu_);
184     waiter_queue = sub_thread_pool_waiter_;
185     waiter_queue_mu = sub_thread_pool_waiter_mu_;
186   } else {
187     waiter_queue = &queue_waiters_;
188     waiter_queue_mu = &waiters_mu_;
189   }
190   {
191     mutex_lock l(*waiter_queue_mu);
192     if (waiter_queue->next != waiter_queue) {
193       // Remove waiter from the LIFO queue
194       w = waiter_queue->next;
195 
196       CHECK(w->prev != w);  // Crash OK.
197       CHECK(w->next != w);  // Crash OK.
198 
199       w->next->prev = w->prev;
200       w->prev->next = w->next;
201 
202       // Use `w->next == &w` to indicate that the waiter has been removed
203       // from the queue.
204       w->next = w;
205       w->prev = w;
206     }
207   }
208   if (w != nullptr) {
209     // We call notify_one() without any locks, so we can miss notifications.
210     // The wake-up logic is best effort; a waiting thread will wake up after a
211     // short period even if a notification is missed.
212     w->cv.notify_one();
213   }
214   VLOG(3) << "Added " << (is_blocking ? "inter" : "intra") << " work from "
215           << traceme_id_.load(std::memory_order_relaxed);
216   return t;
217 }
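// A note on the queue discipline: EnqueueTask is the single producer for a
// given queue (PushFront under `queue_op_mu`), while worker threads consume
// with PopBack via PopBlockingTask/PopNonBlockingTask. A hedged sketch of how
// RunHandlerThreadPool::AddWorkToQueue and WorkerLoop (defined later in this
// file) drive this, using names from this file only:
//
//   Task t = env_.CreateTask(std::move(fn));
//   t = tws->EnqueueTask(std::move(t), /*is_blocking=*/true);
//   if (t.f) env_.ExecuteTask(t);  // queue was full; run the task inline
//
//   Task popped = tws->PopBlockingTask();
//   if (popped.f) {
//     tws->IncrementInflightTaskCount(true);
//     env_.ExecuteTask(popped);
//     tws->DecrementInflightTaskCount(true);
//   }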
218 
219 Task ThreadWorkSource::PopBlockingTask() {
220   return blocking_work_queue_.PopBack();
221 }
222 
223 Task ThreadWorkSource::PopNonBlockingTask(int start_index,
224                                           bool search_from_all_queue) {
225   Task t;
226   unsigned sharding_factor = NonBlockingWorkShardingFactor();
227   for (unsigned j = 0; j < sharding_factor; ++j) {
228     t = non_blocking_work_queues_[(start_index + j) % sharding_factor]
229             ->queue.PopBack();
230     if (t.f) {
231       return t;
232     }
233     if (!search_from_all_queue) {
234       break;
235     }
236   }
237   return t;
238 }
239 
240 void ThreadWorkSource::WaitForWork(int max_sleep_micros) {
241   thread_local Waiter waiter;
242   WaitOnWaiter(&waiter, &queue_waiters_, &waiters_mu_, max_sleep_micros);
243 }
244 
245 int ThreadWorkSource::TaskQueueSize(bool is_blocking) {
246   if (is_blocking) {
247     return blocking_work_queue_.Size();
248   } else {
249     unsigned total_size = 0;
250     for (int i = 0; i < non_blocking_work_sharding_factor_; ++i) {
251       total_size += non_blocking_work_queues_[i]->queue.Size();
252     }
253     return total_size;
254   }
255 }
256 
257 int64 ThreadWorkSource::GetTracemeId() {
258   return traceme_id_.load(std::memory_order_relaxed);
259 }
260 
261 void ThreadWorkSource::SetTracemeId(int64 value) { traceme_id_ = value; }
262 
263 void ThreadWorkSource::SetWaiter(uint64 version, Waiter* waiter, mutex* mutex) {
264   {
265     tf_shared_lock lock(run_handler_waiter_mu_);
266     // Most requests won't change the sub thread pool on recomputation.
267     // As an optimization, avoid holding the exclusive lock to reduce contention.
268     if (sub_thread_pool_waiter_ == waiter) {
269       return;
270     }
271     // If the current version is a newer version, no need to update.
272     if (version_ > version) {
273       return;
274     }
275   }
276 
277   mutex_lock l(run_handler_waiter_mu_);
278   sub_thread_pool_waiter_ = waiter;
279   sub_thread_pool_waiter_mu_ = mutex;
280   version_ = version;
281 }
282 
283 int64 ThreadWorkSource::GetInflightTaskCount(bool is_blocking) {
284   std::atomic<int64>* counter =
285       is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
286   return counter->load(std::memory_order_relaxed);
287 }
288 
289 void ThreadWorkSource::IncrementInflightTaskCount(bool is_blocking) {
290   std::atomic<int64>* counter =
291       is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
292   counter->fetch_add(1, std::memory_order_relaxed);
293 }
294 
295 void ThreadWorkSource::DecrementInflightTaskCount(bool is_blocking) {
296   std::atomic<int64>* counter =
297       is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
298   counter->fetch_sub(1, std::memory_order_relaxed);
299 }
300 
301 unsigned ThreadWorkSource::NonBlockingWorkShardingFactor() {
302   return non_blocking_work_sharding_factor_;
303 }
304 
305 std::string ThreadWorkSource::ToString() {
306   return strings::StrCat("traceme_id = ", GetTracemeId(),
307                          ", inter queue size = ", TaskQueueSize(true),
308                          ", inter inflight = ", GetInflightTaskCount(true),
309                          ", intra queue size = ", TaskQueueSize(false),
310                          ", intra inflight = ", GetInflightTaskCount(false));
311 }
312 
313 RunHandlerThreadPool::RunHandlerThreadPool(
314     int num_blocking_threads, int num_non_blocking_threads, Env* env,
315     const ThreadOptions& thread_options, const string& name,
316     Eigen::MaxSizeVector<mutex>* waiters_mu,
317     Eigen::MaxSizeVector<Waiter>* queue_waiters)
318     : num_threads_(num_blocking_threads + num_non_blocking_threads),
319       num_blocking_threads_(num_blocking_threads),
320       num_non_blocking_threads_(num_non_blocking_threads),
321       thread_data_(num_threads_),
322       env_(env, thread_options, name),
323       name_(name),
324       waiters_mu_(waiters_mu),
325       queue_waiters_(queue_waiters),
326       use_sub_thread_pool_(ParamFromEnvBoolWithDefault(
327           "TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false)),
328       num_threads_in_sub_thread_pool_(ParamFromEnvWithDefault(
329           "TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL",
330           std::vector<int>({num_blocking_threads / 2,
331                             num_blocking_threads - num_blocking_threads / 2}))),
332       sub_thread_pool_start_request_percentage_(ParamFromEnvWithDefault(
333           "TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE",
334           std::vector<double>({0, 0.4}))),
335       sub_thread_pool_end_request_percentage_(ParamFromEnvWithDefault(
336           "TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE",
337           std::vector<double>({0.4, 1}))) {
338   thread_data_.resize(num_threads_);
339   VLOG(1) << "Creating RunHandlerThreadPool " << name << " with  "
340           << num_blocking_threads_ << " blocking threads and "
341           << num_non_blocking_threads_ << " non-blocking threads.";
342 }
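// A note on configuration: the sub thread pool layout above is read from
// environment variables via ParamFromEnvWithDefault. An illustrative override,
// assuming the comma-separated list format accepted by run_handler_util
// (values are examples only, not recommendations):
//
//   TF_RUN_HANDLER_USE_SUB_THREAD_POOL=true
//   TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL=8,8
//   TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE=0,0.4
//   TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE=0.4,1
//
// With the defaults above, the blocking threads are split roughly in half into
// two sub pools, the first targeting the leading ~40% of active requests (as
// ordered by the pool) and the second targeting the remainder.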
343 
344 RunHandlerThreadPool::~RunHandlerThreadPool() {
345   VLOG(1) << "Exiting RunHandlerThreadPool " << name_;
346 
347   cancelled_ = true;
348   for (size_t i = 0; i < thread_data_.size(); ++i) {
349     {
350       mutex_lock l(thread_data_[i].mu);
351       thread_data_[i].sources_not_empty.notify_all();
352     }
353     thread_data_[i].thread.reset();
354   }
355 }
356 
357 void RunHandlerThreadPool::Start() {
358   cancelled_ = false;
359   int num_blocking_threads = num_blocking_threads_;
360   for (int i = 0; i < num_threads_; i++) {
361     int sub_thread_pool_id = num_threads_in_sub_thread_pool_.size() - 1;
362     for (int j = 0; j < num_threads_in_sub_thread_pool_.size(); ++j) {
363       if (i < num_threads_in_sub_thread_pool_[j]) {
364         sub_thread_pool_id = j;
365         break;
366       }
367     }
368     thread_data_[i].sub_thread_pool_id = sub_thread_pool_id;
369     thread_data_[i].thread.reset(
370         env_.CreateThread([this, i, num_blocking_threads]() {
371           WorkerLoop(i, i < num_blocking_threads);
372         }));
373   }
374 }
375 
376 void RunHandlerThreadPool::StartOneThreadForTesting() {
377   cancelled_ = false;
378   thread_data_[0].sub_thread_pool_id = 0;
379   thread_data_[0].thread.reset(
380       env_.CreateThread([this]() { WorkerLoop(0, true); }));
381 }
382 
383 void RunHandlerThreadPool::AddWorkToQueue(ThreadWorkSource* tws,
384                                           bool is_blocking,
385                                           std::function<void()> fn) {
386   Task t = env_.CreateTask(std::move(fn));
387   t = tws->EnqueueTask(std::move(t), is_blocking);
388   if (t.f) {
389     VLOG(3) << "Running " << (is_blocking ? "inter" : "intra") << " work for "
390             << tws->GetTracemeId();
391     env_.ExecuteTask(t);
392   }
393 }
394 
395 // TODO(donglin) Change the task steal order to be round-robin such that if
396 // an attempt to steal a task from request i fails, attempt to steal a task
397 // from the next request in arrival-time order. This approach may
398 // provide better performance due to less lock contention. The drawback is that
399 // the profiler will be a bit harder to read.
400 void RunHandlerThreadPool::SetThreadWorkSources(
401     int tid, int start_request_idx, uint64 version,
402     const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources) {
403   mutex_lock l(thread_data_[tid].mu);
404   if (version > thread_data_[tid].new_version) {
405     thread_data_[tid].new_version = version;
406   } else {
407     // A newer version is already updated. No need to update.
408     return;
409   }
410   thread_data_[tid].new_thread_work_sources->resize(0);
411   if (use_sub_thread_pool_) {
412     for (int i = 0; i < thread_work_sources.size(); ++i) {
413       thread_data_[tid].new_thread_work_sources->emplace_back(
414           thread_work_sources[i]);
415     }
416   } else {
417     thread_data_[tid].new_thread_work_sources->emplace_back(
418         thread_work_sources[start_request_idx]);
419     // The number of shards for the queue. Threads in each shard will
420     // prioritize different thread_work_sources. Increasing the number of shards
421     // can decrease contention in the queue. For example, when
422     // num_shards == 1: thread_work_sources are ordered as start_request_idx,
423     // 0, 1, 2, 3, 4 ... for all threads. When num_shards == 2:
424     // thread_work_sources are ordered as start_request_idx, 0, 2, 4 ... 1, 3,
425     // 5... for half of the threads and start_request_idx, 1, 3, 5 ... 0, 2,
426     // 4... for the other half of the threads.
427     static const int num_shards =
428         ParamFromEnvWithDefault("TF_RUN_HANDLER_QUEUE_SHARDS", 1);
429     int token = tid % num_shards;
430     for (int i = 0; i < num_shards; ++i) {
431       for (int j = token; j < thread_work_sources.size(); j += num_shards) {
432         if (j != start_request_idx) {
433           thread_data_[tid].new_thread_work_sources->emplace_back(
434               thread_work_sources[j]);
435         }
436       }
437       token = (token + 1) % num_shards;
438     }
439     thread_data_[tid].sources_not_empty.notify_all();
440   }
441 }
442 
443 RunHandlerThreadPool::PerThread* RunHandlerThreadPool::GetPerThread() {
444   thread_local RunHandlerThreadPool::PerThread per_thread_;
445   RunHandlerThreadPool::PerThread* pt = &per_thread_;
446   return pt;
447 }
448 
449 int RunHandlerThreadPool::CurrentThreadId() const {
450   const PerThread* pt = const_cast<RunHandlerThreadPool*>(this)->GetPerThread();
451   if (pt->pool == this) {
452     return pt->thread_id;
453   } else {
454     return -1;
455   }
456 }
457 
458 int RunHandlerThreadPool::NumThreads() const { return num_threads_; }
459 
460 int RunHandlerThreadPool::NumBlockingThreads() const {
461   return num_blocking_threads_;
462 }
463 
464 int RunHandlerThreadPool::NumNonBlockingThreads() const {
465   return num_non_blocking_threads_;
466 }
467 
468 RunHandlerThreadPool::ThreadData::ThreadData()
469     : new_version(0),
470       current_index(0),
471       new_thread_work_sources(
472           new Eigen::MaxSizeVector<ThreadWorkSource*>(static_cast<int32>(
473               ParamFromEnvWithDefault("TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
474                                       kMaxConcurrentHandlers)))),
475       current_version(0),
476       current_thread_work_sources(
477           new Eigen::MaxSizeVector<ThreadWorkSource*>(static_cast<int32>(
478               ParamFromEnvWithDefault("TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
479                                       kMaxConcurrentHandlers)))) {}
480 
481 Task RunHandlerThreadPool::FindTask(
482     int searching_range_start, int searching_range_end, int thread_id,
483     int sub_thread_pool_id, int max_blocking_inflight,
484     bool may_steal_blocking_work,
485     const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources,
486     bool* task_from_blocking_queue, ThreadWorkSource** tws) {
487   Task t;
488   int current_index = thread_data_[thread_id].current_index;
489   *task_from_blocking_queue = false;
490 
491   for (int i = 0; i < searching_range_end - searching_range_start; ++i) {
492     if (current_index >= searching_range_end ||
493         current_index < searching_range_start) {
494       current_index = searching_range_start;
495     }
496     *tws = thread_work_sources[current_index];
497     ++current_index;
498 
499     // For blocking thread, search for blocking tasks first.
500     if (may_steal_blocking_work &&
501         (*tws)->GetInflightTaskCount(true) < max_blocking_inflight) {
502       t = (*tws)->PopBlockingTask();
503       if (t.f) {
504         *task_from_blocking_queue = true;
505         break;
506       }
507     }
508 
509     // Search for non-blocking tasks.
510     t = (*tws)->PopNonBlockingTask(thread_id, true);
511     if (t.f) {
512       break;
513     }
514   }
515   thread_data_[thread_id].current_index = current_index;
516   return t;
517 }
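// A note on the search order: FindTask visits at most (searching_range_end -
// searching_range_start) work sources, resuming from the per-thread
// `current_index` so consecutive calls rotate through requests rather than
// always starting at the front. Illustrative example: with range [2, 6) and
// current_index == 5, the visit order is 5, 2, 3, 4 and the final index (5)
// is stored for the next call. Blocking threads probe a source's blocking
// queue first, bounded by max_blocking_inflight, before its non-blocking
// queues.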
518 
519 // Main worker thread loop.
520 void RunHandlerThreadPool::WorkerLoop(int thread_id,
521                                       bool may_steal_blocking_work) {
522   PerThread* pt = GetPerThread();
523   pt->pool = this;
524   pt->thread_id = thread_id;
525   static constexpr int32 kMaxBlockingInflight = 10;
526 
527   while (!cancelled_) {
528     Task t;
529     ThreadWorkSource* tws = nullptr;
530     bool task_from_blocking_queue = true;
531     int sub_thread_pool_id;
532     // Get the current thread work sources.
533     {
534       mutex_lock l(thread_data_[thread_id].mu);
535       if (thread_data_[thread_id].current_version <
536           thread_data_[thread_id].new_version) {
537         thread_data_[thread_id].current_version =
538             thread_data_[thread_id].new_version;
539         thread_data_[thread_id].current_thread_work_sources.swap(
540             thread_data_[thread_id].new_thread_work_sources);
541       }
542     }
543     Eigen::MaxSizeVector<ThreadWorkSource*>* thread_work_sources =
544         thread_data_[thread_id].current_thread_work_sources.get();
545     if (use_sub_thread_pool_) {
546       sub_thread_pool_id = thread_data_[thread_id].sub_thread_pool_id;
547       int active_requests = thread_work_sources->size();
548       if (may_steal_blocking_work) {
549         // Each thread will first look for tasks from requests that belong to
550         // its sub thread pool.
551         int search_range_start =
552             active_requests *
553             sub_thread_pool_start_request_percentage_[sub_thread_pool_id];
554         int search_range_end =
555             active_requests *
556             sub_thread_pool_end_request_percentage_[sub_thread_pool_id];
557         search_range_end =
558             std::min(active_requests,
559                      std::max(search_range_end, search_range_start + 1));
560 
561         t = FindTask(search_range_start, search_range_end, thread_id,
562                      sub_thread_pool_id, kMaxBlockingInflight,
563                      /*may_steal_blocking_work=*/true, *thread_work_sources,
564                      &task_from_blocking_queue, &tws);
565         if (!t.f) {
566           // Search from all requests if the thread cannot find tasks from
567           // requests that belong to its own sub thread pool.
568           t = FindTask(0, active_requests, thread_id, sub_thread_pool_id,
569                        kMaxBlockingInflight,
570                        /*may_steal_blocking_work=*/true, *thread_work_sources,
571                        &task_from_blocking_queue, &tws);
572         }
573       } else {
574         // Non-blocking threads always search across all pending
575         // requests.
576         t = FindTask(0, active_requests, thread_id, sub_thread_pool_id,
577                      kMaxBlockingInflight,
578                      /*may_steal_blocking_work=*/false, *thread_work_sources,
579                      &task_from_blocking_queue, &tws);
580       }
581     } else {
582       // TODO(chaox): Refactor the following code to share the logic with
583       // FindTask.
584       for (int i = 0; i < thread_work_sources->size(); ++i) {
585         tws = (*thread_work_sources)[i];
586         // We want a smallish number of inter threads since
587         // otherwise there will be contention in PropagateOutputs.
588         // This is a best-effort policy.
589         if (may_steal_blocking_work &&
590             tws->GetInflightTaskCount(true) < kMaxBlockingInflight) {
591           t = tws->PopBlockingTask();
592           if (t.f) {
593             break;
594           }
595         }
596         if (i == 0) {
597           // Always look for any work from the "primary" work source.
598           // This way when we wake up a thread for a new closure we are
599           // guaranteed it can be worked on.
600           t = tws->PopNonBlockingTask(thread_id, true);
601           if (t.f) {
602             task_from_blocking_queue = false;
603             break;
604           }
605           if (t.f) {
606             break;
607           }
608         } else {
609           t = tws->PopNonBlockingTask(thread_id, false);
610           if (t.f) {
611             task_from_blocking_queue = false;
612             break;
613           }
614         }
615       }
616     }
617     if (t.f) {
618       profiler::TraceMe activity(
619           [=] {
620             return strings::StrCat(task_from_blocking_queue ? "inter" : "intra",
621                                    " #id = ", tws->GetTracemeId(), " ",
622                                    thread_id, "#");
623           },
624           profiler::TraceMeLevel::kInfo);
625       VLOG(2) << "Running " << (task_from_blocking_queue ? "inter" : "intra")
626               << " work from " << tws->GetTracemeId();
627       tws->IncrementInflightTaskCount(task_from_blocking_queue);
628       env_.ExecuteTask(t);
629       tws->DecrementInflightTaskCount(task_from_blocking_queue);
630     } else {
631       profiler::TraceMe activity(
632           [=] {
633             return strings::StrCat("Sleeping#thread_id=", thread_id, "#");
634           },
635           profiler::TraceMeLevel::kInfo);
636       if (VLOG_IS_ON(4)) {
637         for (int i = 0; i < thread_work_sources->size(); ++i) {
638           VLOG(4) << "source id " << i << " "
639                   << (*thread_work_sources)[i]->ToString();
640         }
641       }
642       if (use_sub_thread_pool_) {
643         WaitForWorkInSubThreadPool(may_steal_blocking_work, sub_thread_pool_id);
644       } else {
645         WaitForWork(may_steal_blocking_work, thread_id, kMaxBlockingInflight);
646       }
647     }
648   }
649 }
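// A note on the loop structure: each WorkerLoop iteration (1) adopts the
// newest thread_work_sources snapshot under thread_data_[thread_id].mu,
// (2) looks for a task, restricted to the thread's sub pool range for blocking
// threads (falling back to all requests), or across all requests for
// non-blocking threads, and (3) runs the task between the inflight-count
// increment/decrement, or parks in one of the WaitForWork variants below when
// nothing is runnable.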
650 
651 void RunHandlerThreadPool::WaitForWorkInSubThreadPool(bool is_blocking,
652                                                       int sub_thread_pool_id) {
653   const int kMaxSleepMicros = 250;
654 
655   // The non-blocking thread will just sleep.
656   if (!is_blocking) {
657     Env::Default()->SleepForMicroseconds(kMaxSleepMicros);
658     return;
659   }
660 
661   thread_local Waiter waiter;
662   WaitOnWaiter(&waiter, &(*queue_waiters_)[sub_thread_pool_id],
663                &(*waiters_mu_)[sub_thread_pool_id], kMaxSleepMicros);
664 }
665 
666 void RunHandlerThreadPool::WaitForWork(bool is_blocking, int thread_id,
667                                        int32 max_blocking_inflight) {
668   const int kMaxSleepMicros = 250;
669 
670   // The non-blocking thread will just sleep.
671   if (!is_blocking) {
672     Env::Default()->SleepForMicroseconds(kMaxSleepMicros);
673     return;
674   }
675 
676   ThreadWorkSource* tws = nullptr;
677   {
678     mutex_lock l(thread_data_[thread_id].mu);
679     if (thread_data_[thread_id].new_version >
680         thread_data_[thread_id].current_version) {
681       thread_data_[thread_id].current_thread_work_sources.swap(
682           thread_data_[thread_id].new_thread_work_sources);
683       thread_data_[thread_id].current_version =
684           thread_data_[thread_id].new_version;
685     }
686     Eigen::MaxSizeVector<ThreadWorkSource*>* thread_work_sources =
687         thread_data_[thread_id].current_thread_work_sources.get();
688     while (!cancelled_ && thread_work_sources->empty()) {
689       // Wait until there is a new request.
690       thread_data_[thread_id].sources_not_empty.wait(l);
691       if (thread_data_[thread_id].new_version >
692           thread_data_[thread_id].current_version) {
693         thread_data_[thread_id].current_thread_work_sources.swap(
694             thread_data_[thread_id].new_thread_work_sources);
695         thread_data_[thread_id].current_version =
696             thread_data_[thread_id].new_version;
697         thread_work_sources =
698             thread_data_[thread_id].current_thread_work_sources.get();
699       }
700     }
701     if (cancelled_) {
702       return;
703     }
704     tws = (*thread_work_sources)[0];
705   }
706 
707   if (tws->GetInflightTaskCount(true) >= max_blocking_inflight) {
708     // Sleep to reduce contention in PropagateOutputs
709     Env::Default()->SleepForMicroseconds(kMaxSleepMicros);
710   }
711   tws->WaitForWork(kMaxSleepMicros);
712 }
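// A note on the two wait paths: non-blocking (intra-op) threads only ever
// micro-sleep for kMaxSleepMicros. Blocking threads park on a Waiter, either
// the per-sub-thread-pool list (WaitForWorkInSubThreadPool) or the primary
// request's ThreadWorkSource list (WaitForWork), and rely on EnqueueTask's
// best-effort notify_one plus the bounded sleep to recover from missed
// notifications.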
713 
714 }  // namespace internal
715 
716 // Contains the concrete implementation of the RunHandler.
717 // Externally visible RunHandler class simply forwards the work to this one.
718 class RunHandler::Impl {
719  public:
720   explicit Impl(RunHandlerPool::Impl* pool_impl);
721 
722   ~Impl() {}
723 
724   thread::ThreadPoolInterface* thread_pool_interface() {
725     return thread_pool_interface_.get();
726   }
727 
728   // Returns the time (in microseconds since the Unix epoch) at which the
729   // handler was requested via RunHandlerPool::Get().
730   uint64 start_time_us() const { return start_time_us_; }
731   int64 step_id() const { return step_id_; }
732   void ScheduleInterOpClosure(std::function<void()> fn);
733   void ScheduleIntraOpClosure(std::function<void()> fn);
734 
735   void Reset(int64 step_id,
736              const RunOptions::Experimental::RunHandlerPoolOptions& options);
737 
738   RunHandlerPool::Impl* pool_impl() { return pool_impl_; }
739 
740   internal::ThreadWorkSource* tws() { return &tws_; }
741 
742   int64 priority() { return options_.priority(); }
743 
744  private:
745   class ThreadPoolInterfaceWrapper : public thread::ThreadPoolInterface {
746    public:
747     explicit ThreadPoolInterfaceWrapper(Impl* run_handler_impl)
748         : run_handler_impl_(run_handler_impl) {}
749     ~ThreadPoolInterfaceWrapper() override {}
750     void Schedule(std::function<void()> fn) override;
751     int NumThreads() const override;
752     int CurrentThreadId() const override;
753 
754    private:
755     RunHandler::Impl* run_handler_impl_ = nullptr;
756   };
757 
758   RunHandlerPool::Impl* pool_impl_;  // NOT OWNED.
759   uint64 start_time_us_;
760   int64 step_id_;
761   std::unique_ptr<thread::ThreadPoolInterface> thread_pool_interface_;
762   internal::ThreadWorkSource tws_;
763   RunOptions::Experimental::RunHandlerPoolOptions options_;
764 };
765 
766 // Contains shared state across all run handlers present in the pool. Also
767 // responsible for pool management decisions.
768 // This class is thread safe.
769 class RunHandlerPool::Impl {
770  public:
771   explicit Impl(int num_inter_op_threads, int num_intra_op_threads)
772       : max_handlers_(static_cast<int32>(ParamFromEnvWithDefault(
773             "TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS", kMaxConcurrentHandlers))),
774         waiters_mu_(
775             ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2)),
776         queue_waiters_(
777             ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2)),
778         run_handler_thread_pool_(new internal::RunHandlerThreadPool(
779             num_inter_op_threads, num_intra_op_threads, Env::Default(),
780             ThreadOptions(), "tf_run_handler_pool", &waiters_mu_,
781             &queue_waiters_)),
782         iterations_(0),
783         version_(0),
784         sub_thread_pool_end_request_percentage_(ParamFromEnvWithDefault(
785             "TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE",
786             std::vector<double>({1}))) {
787     VLOG(1) << "Creating a RunHandlerPool with max handlers: " << max_handlers_;
788     free_handlers_.reserve(max_handlers_);
789     handlers_.reserve(max_handlers_);
790     for (int i = 0; i < max_handlers_; ++i) {
791       handlers_.emplace_back(new RunHandler::Impl(this));
792       free_handlers_.push_back(handlers_.back().get());
793     }
794     queue_waiters_.resize(
795         ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2));
796     waiters_mu_.resize(
797         ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2));
798     for (auto& queue_waiter : queue_waiters_) {
799       queue_waiter.next = &queue_waiter;
800       queue_waiter.prev = &queue_waiter;
801     }
802     run_handler_thread_pool_->Start();
803   }
804 
805   ~Impl() {
806     // Sanity check that all handlers have been returned back to the pool before
807     // destruction.
808     DCHECK_EQ(handlers_.size(), max_handlers_);
809     DCHECK_EQ(free_handlers_.size(), handlers_.size());
810     DCHECK_EQ(sorted_active_handlers_.size(), 0);
811     // Stop the threads in run_handler_thread_pool_ before freeing other
812     // pointers. Otherwise a thread may try to access a pointer after the
813     // pointer has been freed.
814     run_handler_thread_pool_.reset();
815   }
816 
817   internal::RunHandlerThreadPool* run_handler_thread_pool() {
818     return run_handler_thread_pool_.get();
819   }
820 
821   bool has_free_handler() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
822     return !free_handlers_.empty();
823   }
824 
825   std::unique_ptr<RunHandler> Get(
826       int64 step_id, int64 timeout_in_ms,
827       const RunOptions::Experimental::RunHandlerPoolOptions& options)
828       TF_LOCKS_EXCLUDED(mu_) {
829     thread_local std::unique_ptr<
830         Eigen::MaxSizeVector<internal::ThreadWorkSource*>>
831         thread_work_sources =
832             std::unique_ptr<Eigen::MaxSizeVector<internal::ThreadWorkSource*>>(
833                 new Eigen::MaxSizeVector<internal::ThreadWorkSource*>(
834                     static_cast<int32>(ParamFromEnvWithDefault(
835                         "TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
836                         kMaxConcurrentHandlers))));
837     uint64 version;
838     int num_active_requests;
839     RunHandler::Impl* handler_impl;
840     {
841       mutex_lock l(mu_);
842       if (!has_free_handler()) {
843         profiler::TraceMe activity(
844             [&] {
845               return strings::StrCat("WaitingForHandler#step_id=", step_id,
846                                      "#");
847             },
848             profiler::TraceMeLevel::kInfo);
849         if (timeout_in_ms == 0) {
850           mu_.Await(Condition(this, &Impl::has_free_handler));
851         } else if (!mu_.AwaitWithDeadline(
852                        Condition(this, &Impl::has_free_handler),
853                        EnvTime::NowNanos() + timeout_in_ms * 1000 * 1000)) {
854           return nullptr;
855         }
856       }
857       // Remove the last entry from free_handlers_ and add to the end of
858       // sorted_active_handlers_.
859       handler_impl = free_handlers_.back();
860       handler_impl->Reset(step_id, options);
861       free_handlers_.pop_back();
862 
863       num_active_requests = sorted_active_handlers_.size() + 1;
864       thread_work_sources->resize(num_active_requests);
865       int priority = options.priority();
866       auto it = sorted_active_handlers_.cbegin();
867       bool new_handler_inserted = false;
868       for (int i = 0; i < num_active_requests; ++i) {
869         if (!new_handler_inserted && (it == sorted_active_handlers_.cend() ||
870                                       priority > (*it)->priority())) {
871           sorted_active_handlers_.insert(it, handler_impl);
872           new_handler_inserted = true;
873           // Point to the newly added handler.
874           --it;
875         }
876         (*thread_work_sources)[i] = (*it)->tws();
877         ++it;
878       }
879       version = ++version_;
880     }
881     RecomputePoolStats(num_active_requests, version, *thread_work_sources);
882     return WrapUnique<RunHandler>(new RunHandler(handler_impl));
883   }
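  // A hedged usage sketch of the pool from client code, using only the public
  // RunHandlerPool/RunHandler interface declared in run_handler.h (thread
  // counts, step id and timeout are illustrative):
  //
  //   RunHandlerPool pool(/*num_inter_op_threads=*/16,
  //                       /*num_intra_op_threads=*/16);
  //   auto handler = pool.Get(/*step_id=*/42, /*timeout_in_ms=*/100, options);
  //   if (handler) {
  //     handler->ScheduleInterOpClosure([] { /* inter-op work */ });
  //     // The handler is returned to the pool when it goes out of scope.
  //   } else {
  //     // Timed out waiting for a free handler.
  //   }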
884 
885   void ReleaseHandler(RunHandler::Impl* handler) TF_LOCKS_EXCLUDED(mu_) {
886     mutex_lock l(mu_);
887     DCHECK_GT(sorted_active_handlers_.size(), 0);
888 
889     CHECK_EQ(handler->tws()->TaskQueueSize(true), 0);   // Crash OK.
890     CHECK_EQ(handler->tws()->TaskQueueSize(false), 0);  // Crash OK.
891 
892     uint64 now = tensorflow::EnvTime::NowMicros();
893     double elapsed = (now - handler->start_time_us()) / 1000.0;
894     time_hist_.Add(elapsed);
895 
896     // Erase from and update sorted_active_handlers_. Add it to the end of
897     // free_handlers_.
898     auto iter = std::find(sorted_active_handlers_.begin(),
899                           sorted_active_handlers_.end(), handler);
900     DCHECK(iter != sorted_active_handlers_.end())
901         << "Unexpected handler: " << handler
902         << " is being requested for release";
903 
904     // Remove this handler from this list and add it to the list of free
905     // handlers.
906     sorted_active_handlers_.erase(iter);
907     free_handlers_.push_back(handler);
908     DCHECK_LE(free_handlers_.size(), max_handlers_);
909     LogInfo();
910 
911     // We do not recompute pool stats during release. The side effect is that
912     // there may be empty thread work sources in the queue. However, any new
913     // requests will trigger recomputation.
914   }
915 
916   std::vector<int64> GetActiveHandlerPrioritiesForTesting()
917       TF_LOCKS_EXCLUDED(mu_) {
918     mutex_lock l(mu_);
919     std::vector<int64> ret;
920     for (const auto& handler_impl : sorted_active_handlers_) {
921       ret.push_back(handler_impl->priority());
922     }
923     return ret;
924   }
925 
926  private:
927   void RecomputePoolStats(
928       int num_active_requests, uint64 version,
929       const Eigen::MaxSizeVector<internal::ThreadWorkSource*>&
930           thread_work_sources);
931 
932   void LogInfo() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
933 
934   // Maximum number of handlers pre-created during pool construction time. The
935   // number was chosen expecting that each handler might want at least 1
936   // inter-op thread for execution (during compute-intensive workloads like
937   // inference).
938   const int max_handlers_;
939 
940   Eigen::MaxSizeVector<mutex> waiters_mu_;
941   Eigen::MaxSizeVector<internal::Waiter> queue_waiters_;
942 
943   std::unique_ptr<internal::RunHandlerThreadPool> run_handler_thread_pool_;
944   // Thread compatible part used only by lock under RunHandlerPool.
945   // Handlers are sorted by start time.
946   // TODO(azaks): sort by the remaining latency budget.
947   // TODO(chaox): Consider other data structures for maintaining the sorted
948   // active handlers if the search overhead (currently O(n)) becomes the
949   // bottleneck.
950   std::list<RunHandler::Impl*> sorted_active_handlers_ TF_GUARDED_BY(mu_);
951   std::vector<RunHandler::Impl*> free_handlers_ TF_GUARDED_BY(mu_);
952   std::vector<std::unique_ptr<RunHandler::Impl>> handlers_ TF_GUARDED_BY(mu_);
953 
954   // Histogram of elapsed runtime of every handler (in ms).
955   histogram::Histogram time_hist_ TF_GUARDED_BY(mu_);
956 
957   int64 iterations_ TF_GUARDED_BY(mu_);
958   mutex mu_;
959   int64 version_ TF_GUARDED_BY(mu_);
960   const std::vector<double> sub_thread_pool_end_request_percentage_;
961 };
962 
963 void RunHandlerPool::Impl::RecomputePoolStats(
964     int num_active_requests, uint64 version,
965     const Eigen::MaxSizeVector<internal::ThreadWorkSource*>&
966         thread_work_sources) {
967   if (num_active_requests == 0) return;
968 
969   int sub_thread_pool_id = 0;
970   for (int i = 0; i < num_active_requests; ++i) {
971     while (
972         sub_thread_pool_id <
973             sub_thread_pool_end_request_percentage_.size() - 1 &&
974         i >= num_active_requests *
975                  sub_thread_pool_end_request_percentage_[sub_thread_pool_id]) {
976       sub_thread_pool_id++;
977     }
978     thread_work_sources[i]->SetWaiter(version,
979                                       &queue_waiters_[sub_thread_pool_id],
980                                       &waiters_mu_[sub_thread_pool_id]);
981   }
982 
983   int num_threads = run_handler_thread_pool()->NumThreads();
984   int num_blocking_threads = run_handler_thread_pool()->NumBlockingThreads();
985   int num_non_blocking_threads = num_threads - num_blocking_threads;
986 
987   std::vector<int> request_idx_list = ChooseRequestsWithExponentialDistribution(
988       num_active_requests, num_blocking_threads);
989   for (int i = 0; i < num_blocking_threads; ++i) {
990     VLOG(2) << "Set work for tid=" << i
991             << " with start_request_idx=" << request_idx_list[i];
992     run_handler_thread_pool()->SetThreadWorkSources(
993         i, request_idx_list[i], version, thread_work_sources);
994   }
995 
996   request_idx_list = ChooseRequestsWithExponentialDistribution(
997       num_active_requests, num_non_blocking_threads);
998   for (int i = 0; i < num_non_blocking_threads; ++i) {
999     VLOG(2) << "Set work for tid=" << (i + num_blocking_threads)
1000             << " with start_request_idx=" << request_idx_list[i];
1001     run_handler_thread_pool()->SetThreadWorkSources(
1002         i + num_blocking_threads, request_idx_list[i], version,
1003         thread_work_sources);
1004   }
1005 }
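// A worked example of the waiter assignment above, assuming
// TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE=0.4,1 (this class
// defaults to a single pool, {1}): with num_active_requests == 5, requests 0
// and 1 (i < 5 * 0.4) keep sub_thread_pool_id 0 and requests 2, 3 and 4
// advance to sub pool 1; each request's ThreadWorkSource is then pointed at
// that sub pool's waiter list, so EnqueueTask wakes threads of the matching
// pool.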
1006 
1007 void RunHandlerPool::Impl::LogInfo() {
1008   if (iterations_++ % 50000 == 10 && VLOG_IS_ON(1)) {
1009     int num_active_requests = sorted_active_handlers_.size();
1010     VLOG(1) << "Printing time histogram: " << time_hist_.ToString();
1011     VLOG(1) << "Active session runs: " << num_active_requests;
1012     uint64 now = tensorflow::Env::Default()->NowMicros();
1013     string times_str = "";
1014     string ids_str = "";
1015     auto it = sorted_active_handlers_.cbegin();
1016     for (int i = 0; i < num_active_requests; ++i) {
1017       if (i > 0) {
1018         times_str += " ";
1019         ids_str += " ";
1020       }
1021 
1022       times_str +=
1023           strings::StrCat((now - (*it)->start_time_us()) / 1000.0, " ms.");
1024       ids_str += strings::StrCat((*it)->tws()->GetTracemeId());
1025       ++it;
1026     }
1027     VLOG(1) << "Elapsed times are: " << times_str;
1028     VLOG(1) << "Step ids are: " << ids_str;
1029   }
1030 }
1031 
1032 // It is important to return a value such that:
1033 // CurrentThreadId() in [0, NumThreads)
1034 int RunHandler::Impl::ThreadPoolInterfaceWrapper::NumThreads() const {
1035   return run_handler_impl_->pool_impl_->run_handler_thread_pool()->NumThreads();
1036 }
1037 
1038 int RunHandler::Impl::ThreadPoolInterfaceWrapper::CurrentThreadId() const {
1039   return run_handler_impl_->pool_impl_->run_handler_thread_pool()
1040       ->CurrentThreadId();
1041 }
1042 
1043 void RunHandler::Impl::ThreadPoolInterfaceWrapper::Schedule(
1044     std::function<void()> fn) {
1045   return run_handler_impl_->ScheduleIntraOpClosure(std::move(fn));
1046 }
1047 
1048 RunHandler::Impl::Impl(RunHandlerPool::Impl* pool_impl)
1049     : pool_impl_(pool_impl) {
1050   thread_pool_interface_.reset(new ThreadPoolInterfaceWrapper(this));
1051   Reset(0, RunOptions::Experimental::RunHandlerPoolOptions());
1052 }
1053 
1054 void RunHandler::Impl::ScheduleInterOpClosure(std::function<void()> fn) {
1055   VLOG(3) << "Scheduling inter work for  " << tws()->GetTracemeId();
1056   pool_impl_->run_handler_thread_pool()->AddWorkToQueue(tws(), true,
1057                                                         std::move(fn));
1058 }
1059 
1060 void RunHandler::Impl::ScheduleIntraOpClosure(std::function<void()> fn) {
1061   VLOG(3) << "Scheduling intra work for " << tws()->GetTracemeId();
1062   pool_impl_->run_handler_thread_pool()->AddWorkToQueue(tws(), false,
1063                                                         std::move(fn));
1064 }
1065 
1066 void RunHandler::Impl::Reset(
1067     int64 step_id,
1068     const RunOptions::Experimental::RunHandlerPoolOptions& options) {
1069   start_time_us_ = tensorflow::Env::Default()->NowMicros();
1070   step_id_ = step_id;
1071   options_ = options;
1072   tws_.SetTracemeId(step_id);
1073 }
1074 
1075 RunHandlerPool::RunHandlerPool(int num_inter_op_threads)
1076     : impl_(new Impl(num_inter_op_threads, 0)) {}
1077 
1078 RunHandlerPool::RunHandlerPool(int num_inter_op_threads,
1079                                int num_intra_op_threads)
1080     : impl_(new Impl(num_inter_op_threads, num_intra_op_threads)) {}
1081 
1082 RunHandlerPool::~RunHandlerPool() {}
1083 
1084 std::unique_ptr<RunHandler> RunHandlerPool::Get(
1085     int64 step_id, int64 timeout_in_ms,
1086     const RunOptions::Experimental::RunHandlerPoolOptions& options) {
1087   return impl_->Get(step_id, timeout_in_ms, options);
1088 }
1089 
1090 std::vector<int64> RunHandlerPool::GetActiveHandlerPrioritiesForTesting()
1091     const {
1092   return impl_->GetActiveHandlerPrioritiesForTesting();
1093 }
1094 
1095 RunHandler::RunHandler(Impl* impl) : impl_(impl) {}
1096 
1097 void RunHandler::ScheduleInterOpClosure(std::function<void()> fn) {
1098   impl_->ScheduleInterOpClosure(std::move(fn));
1099 }
1100 
1101 thread::ThreadPoolInterface* RunHandler::AsIntraThreadPoolInterface() {
1102   return impl_->thread_pool_interface();
1103 }
1104 
1105 RunHandler::~RunHandler() { impl_->pool_impl()->ReleaseHandler(impl_); }
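// Note that releasing a handler is purely RAII: the destructor above hands the
// Impl back to the pool, and ReleaseHandler CHECKs that both of the handler's
// task queues are already empty, so callers should not drop a RunHandler while
// closures are still queued on it.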
1106 
1107 }  // namespace tensorflow
1108