/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_
#define TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_

#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/histogram/histogram.h"
#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/protobuf/config.pb.h"

namespace Eigen {
struct ThreadPoolDevice;
}

namespace tensorflow {

class RunHandler;

// RunHandlerPool is a fixed-size pool of pre-allocated RunHandlers that can
// be used for tracking inter-op work for a given Session::Run().
// RunHandlers in the pool are initially 'inactive'. A RunHandler becomes
// 'active' when its unique_ptr is returned by Get() and is being used by a
// client. It becomes 'inactive' once more when its unique_ptr gets destroyed.
//
// Expected usage:
//
// * Create a single RunHandlerPool (say run_handler_pool_).
//
// * When a Session::Run() is invoked, obtain a handler by:
//     auto handler = run_handler_pool_->Get();
//
// * Use the handler to schedule all inter-op work by:
//     handler->ScheduleInterOpClosure(closure);
//
// This class is thread safe.
class RunHandlerPool {
 public:
  explicit RunHandlerPool(int num_inter_op_threads);

  RunHandlerPool(int num_inter_op_threads, int num_intra_op_threads);
  ~RunHandlerPool();

  // Returns an inactive RunHandler from the pool.
  //
  // RunHandlers in RunHandlerPool are initially 'inactive'.
  // A RunHandler becomes 'active' when its unique_ptr is returned by Get()
  // and is being used by a client.  It becomes 'inactive' once more when the
  // unique_ptr is destroyed.
  //
  // Will block unless there is an inactive handler.
  std::unique_ptr<RunHandler> Get(
      int64 step_id = 0, int64 timeout_in_ms = 0,
      const RunOptions::Experimental::RunHandlerPoolOptions& options =
          RunOptions::Experimental::RunHandlerPoolOptions());

  // Gets the priorities of the active handlers. The returned vector follows
  // the same order as the active handler list.
  std::vector<int64> GetActiveHandlerPrioritiesForTesting() const;

 private:
  class Impl;
  friend class RunHandler;

  std::unique_ptr<Impl> impl_;
};
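
// A minimal usage sketch of the pool (illustrative only; the thread counts
// and RunSomeInterOpWork() below are hypothetical, not part of this header):
//
//   RunHandlerPool pool(/*num_inter_op_threads=*/16,
//                       /*num_intra_op_threads=*/16);
//   {
//     // Blocks until an inactive handler is available.
//     std::unique_ptr<RunHandler> handler = pool.Get(/*step_id=*/1);
//     handler->ScheduleInterOpClosure([] { RunSomeInterOpWork(); });
//   }  // Destroying the unique_ptr returns the handler to the pool.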

// RunHandler can be used to schedule inter/intra-op closures to run on a
// global pool shared across all Session::Run(s). The closures are enqueued to
// a handler-specific queue, from which the work is stolen in priority order
// (the time of the Get() call).
//
// It can only be created via RunHandlerPool::Get().
//
// This class can be used instead of directly scheduling closures on a global
// pool since it maintains a global view across all sessions and optimizes pool
// scheduling to improve (median and tail) latency.
//
// This class is thread safe.
class RunHandler {
 public:
  void ScheduleInterOpClosure(std::function<void()> fn);
  thread::ThreadPoolInterface* AsIntraThreadPoolInterface();

  ~RunHandler();

 private:
  class Impl;
  friend class RunHandlerPool::Impl;

  explicit RunHandler(Impl* impl);

  Impl* impl_;  // NOT OWNED.
};
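
// Sketch of exposing the handler's intra-op view to Eigen (assuming
// thread::ThreadPoolInterface is the Eigen-compatible interface; the
// num_threads value is an arbitrary illustration):
//
//   Eigen::ThreadPoolDevice device(handler->AsIntraThreadPoolInterface(),
//                                  /*num_threads=*/16);
//   // 'device' can now run intra-op parallel work for this step.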

namespace internal {

// TODO(azaks): Refactor with thread::ThreadPool
class RunHandlerEnvironment {
  typedef Thread EnvThread;
  struct TaskImpl {
    std::function<void()> f;
    Context context;
    uint64 trace_id;
  };
  Env* const env_;
  const ThreadOptions thread_options_;
  const string name_;

 public:
  struct Task {
    std::unique_ptr<TaskImpl> f;
  };

  RunHandlerEnvironment(Env* env, const ThreadOptions& thread_options,
                        const string& name);

  EnvThread* CreateThread(std::function<void()> f);

  Task CreateTask(std::function<void()> f);

  void ExecuteTask(const Task& t);
};
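
// Expected call pattern for the environment (a sketch; the constructor
// arguments are placeholders for illustration):
//
//   RunHandlerEnvironment env(Env::Default(), ThreadOptions(), "run_handler");
//   RunHandlerEnvironment::Task task = env.CreateTask([] { /* work */ });
//   env.ExecuteTask(task);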

typedef typename RunHandlerEnvironment::Task Task;
typedef Eigen::RunQueue<Task, 1024> Queue;

// To reduce cache misses, we use a doubly-linked list of Waiter structs and
// queue them in LIFO order rather than the FIFO order used by a single
// condition variable.
struct Waiter {
  Waiter() {
    next = this;
    prev = this;
  }
  condition_variable cv;
  mutex mu;
  Waiter* next;
  Waiter* prev;
};
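
// Illustrative sketch of the LIFO splice described above (EnqueueLifo is a
// hypothetical helper, not part of this header). The default-constructed
// Waiter acts as a circular sentinel, so a fresh waiter is linked in right
// after the head and the most recently parked thread is woken first:
//
//   void EnqueueLifo(Waiter* head, Waiter* w) {
//     w->prev = head;
//     w->next = head->next;
//     head->next->prev = w;
//     head->next = w;
//   }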

class ThreadWorkSource {
 public:
  ThreadWorkSource();

  ~ThreadWorkSource();

  Task EnqueueTask(Task t, bool is_blocking);

  Task PopBlockingTask();

  Task PopNonBlockingTask(int start_index, bool search_from_all_queue);

  void WaitForWork(int max_sleep_micros);

  int TaskQueueSize(bool is_blocking);

  int64 GetTracemeId();

  void SetTracemeId(int64 value);

  void SetWaiter(uint64 version, Waiter* waiter, mutex* mutex);

  int64 GetInflightTaskCount(bool is_blocking);

  void IncrementInflightTaskCount(bool is_blocking);

  void DecrementInflightTaskCount(bool is_blocking);

  unsigned NonBlockingWorkShardingFactor();

  std::string ToString();

 private:
  struct NonBlockingQueue {
    mutex queue_op_mu;
    char pad[128];
    Queue queue;
  };

  int32 non_blocking_work_sharding_factor_;
  Eigen::MaxSizeVector<NonBlockingQueue*> non_blocking_work_queues_;

  std::atomic<int64> blocking_inflight_;
  std::atomic<int64> non_blocking_inflight_;

  Queue blocking_work_queue_;
  mutex blocking_queue_op_mu_;
  char pad_[128];
  mutex waiters_mu_;
  Waiter queue_waiters_ TF_GUARDED_BY(waiters_mu_);
  std::atomic<int64> traceme_id_;

  mutex run_handler_waiter_mu_;
  uint64 version_ TF_GUARDED_BY(run_handler_waiter_mu_);
  mutex* sub_thread_pool_waiter_mu_ TF_GUARDED_BY(run_handler_waiter_mu_);
  Waiter* sub_thread_pool_waiter_ TF_GUARDED_BY(run_handler_waiter_mu_);
};
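
// Sketch of the inflight accounting a worker is expected to perform around a
// dequeued task (illustrative pairing only; 'tws', 'env', and 'task' are
// placeholders):
//
//   tws->IncrementInflightTaskCount(/*is_blocking=*/true);
//   env.ExecuteTask(task);
//   tws->DecrementInflightTaskCount(/*is_blocking=*/true);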

class RunHandlerThreadPool {
 public:
  struct PerThread {
    constexpr PerThread() : pool(nullptr), thread_id(-1) {}
    RunHandlerThreadPool* pool;  // Parent pool, or null for normal threads.
    int thread_id;               // Worker thread index in pool.
  };

  RunHandlerThreadPool(int num_blocking_threads, int num_non_blocking_threads,
                       Env* env, const ThreadOptions& thread_options,
                       const string& name,
                       Eigen::MaxSizeVector<mutex>* waiters_mu,
                       Eigen::MaxSizeVector<Waiter>* queue_waiters);

  ~RunHandlerThreadPool();

  void Start();

  void StartOneThreadForTesting();

  void AddWorkToQueue(ThreadWorkSource* tws, bool is_blocking,
                      std::function<void()> fn);

  // Sets the work queues from which the thread 'tid' can steal its work.
  // The request with start_request_idx will be attempted first; other
  // requests will be attempted in FIFO order based on their arrival time.
  void SetThreadWorkSources(
      int tid, int start_request_idx, uint64 version,
      const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources);

  PerThread* GetPerThread();

  int CurrentThreadId() const;

  int NumThreads() const;

  int NumBlockingThreads() const;

  int NumNonBlockingThreads() const;

  void WorkerLoop(int thread_id, bool may_steal_blocking_work);

  // Searches for tasks in the request range searching_range_start to
  // searching_range_end. If there are no tasks in the search range and
  // may_steal_blocking_work is true, then searches from all requests.
  Task FindTask(
      int searching_range_start, int searching_range_end, int thread_id,
      int sub_thread_pool_id, int max_blocking_inflight,
      bool may_steal_blocking_work,
      const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources,
      bool* task_from_blocking_queue, ThreadWorkSource** tws);

  void WaitForWork(bool is_blocking, int thread_id,
                   int32 max_blocking_inflight);

  void WaitForWorkInSubThreadPool(bool is_blocking, int sub_thread_pool_id);

 private:
  struct ThreadData {
    ThreadData();
    mutex mu;
    uint64 new_version;
    condition_variable sources_not_empty;
    std::unique_ptr<Thread> thread;
    int current_index;
    std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>>
        new_thread_work_sources TF_GUARDED_BY(mu);

    uint64 current_version;
    // Should only be accessed by one thread.
    std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>>
        current_thread_work_sources;

    int sub_thread_pool_id;
  };

  const int num_threads_;
  const int num_blocking_threads_;
  const int num_non_blocking_threads_;
  Eigen::MaxSizeVector<ThreadData> thread_data_;
  internal::RunHandlerEnvironment env_;
  std::atomic<bool> cancelled_;
  string name_;
  Eigen::MaxSizeVector<mutex>* waiters_mu_;
  Eigen::MaxSizeVector<Waiter>* queue_waiters_;

  bool use_sub_thread_pool_;
  std::vector<int> num_threads_in_sub_thread_pool_;

  // Threads in each sub thread pool will search for tasks from the given
  // start_request_percentage to end_request_percentage in a round-robin
  // fashion.
  std::vector<double> sub_thread_pool_start_request_percentage_;
  std::vector<double> sub_thread_pool_end_request_percentage_;
};
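
// Illustrative arithmetic for the percentage bounds above (an assumption
// about intent, not the private implementation): with num_requests active
// requests, a sub thread pool would scan roughly
//
//   int start = static_cast<int>(start_request_percentage * num_requests);
//   int end = static_cast<int>(end_request_percentage * num_requests);
//
// i.e. request indices [start, end), visited in a round-robin fashion.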

}  // namespace internal

}  // end namespace tensorflow.

#endif  // TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_