/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_
#define TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_

#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/histogram/histogram.h"
#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/protobuf/config.pb.h"

namespace Eigen {
struct ThreadPoolDevice;
}

namespace tensorflow {

class RunHandler;

// RunHandlerPool is a fixed-size pool of pre-allocated RunHandlers that can be
// used for tracking inter-op work for a given Session::Run(). RunHandler(s) in
// the pool are initially 'inactive'. A RunHandler becomes 'active' when its
// unique_ptr is returned by Get() and is being used by a client. It becomes
// 'inactive' once more when its unique_ptr gets destroyed.
//
// Expected usage:
//
// * Create a single RunHandlerPool (say run_handler_pool_).
//
// * When a Session::Run() is invoked, obtain a handler by:
//     auto handler = run_handler_pool_->Get();
//
// * Use the handler for scheduling all inter-op work by:
//     handler->ScheduleInterOpClosure(closure);
//
// This class is thread safe.
class RunHandlerPool {
 public:
  explicit RunHandlerPool(int num_inter_op_threads);

  RunHandlerPool(int num_inter_op_threads, int num_intra_op_threads);
  ~RunHandlerPool();

  // Returns an inactive RunHandler from the pool.
  //
  // RunHandlers in RunHandlerPool are initially 'inactive'. A RunHandler
  // becomes 'active' when its unique_ptr is returned by Get() and is being
  // used by a client. It becomes 'inactive' once more when the unique_ptr is
  // destroyed.
  //
  // Will block unless there is an inactive handler.
  std::unique_ptr<RunHandler> Get(
      int64 step_id = 0, int64 timeout_in_ms = 0,
      const RunOptions::Experimental::RunHandlerPoolOptions& options =
          RunOptions::Experimental::RunHandlerPoolOptions());

  // Gets the priorities of the active handlers, in the same order as the
  // active handler list.
  std::vector<int64> GetActiveHandlerPrioritiesForTesting() const;

 private:
  class Impl;
  friend class RunHandler;

  std::unique_ptr<Impl> impl_;
};

// RunHandler can be used to schedule inter/intra-op closures to run on a
// global pool shared across all Session::Run(s). The closures are enqueued to
// a handler-specific queue, from which the work is stolen in priority order
// (time of the Get() call).
//
// It can only be created via RunHandlerPool::Get().
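//
// A minimal usage sketch follows. The pool construction values and the
// intra-op wiring are illustrative assumptions about how a client typically
// consumes this interface, not part of the interface itself:
//
//   RunHandlerPool pool(/*num_inter_op_threads=*/16,
//                       /*num_intra_op_threads=*/16);
//   std::unique_ptr<RunHandler> handler = pool.Get(/*step_id=*/1);
//   handler->ScheduleInterOpClosure([]() { /* one inter-op closure */ });
//   // Intra-op work can be directed at the handler's threads, e.g. by
//   // building an Eigen::ThreadPoolDevice on top of
//   // handler->AsIntraThreadPoolInterface().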
//
// This class can be used instead of directly scheduling closures on a global
// pool since it maintains a global view across all sessions and optimizes pool
// scheduling to improve (median and tail) latency.
//
// This class is thread safe.
class RunHandler {
 public:
  void ScheduleInterOpClosure(std::function<void()> fn);
  thread::ThreadPoolInterface* AsIntraThreadPoolInterface();

  ~RunHandler();

 private:
  class Impl;
  friend class RunHandlerPool::Impl;

  explicit RunHandler(Impl* impl);

  Impl* impl_;  // NOT OWNED.
};

namespace internal {

// TODO(azaks): Refactor with thread::ThreadPool
class RunHandlerEnvironment {
  typedef Thread EnvThread;
  struct TaskImpl {
    std::function<void()> f;
    Context context;
    uint64 trace_id;
  };
  Env* const env_;
  const ThreadOptions thread_options_;
  const string name_;

 public:
  struct Task {
    std::unique_ptr<TaskImpl> f;
  };

  RunHandlerEnvironment(Env* env, const ThreadOptions& thread_options,
                        const string& name);

  EnvThread* CreateThread(std::function<void()> f);

  Task CreateTask(std::function<void()> f);

  void ExecuteTask(const Task& t);
};

typedef typename RunHandlerEnvironment::Task Task;
typedef Eigen::RunQueue<Task, 1024> Queue;

// To reduce cache misses, we use a doubly-linked list of Waiter structs and
// queue them in LIFO order rather than the FIFO order used by a single
// condition variable. (A usage sketch follows the ThreadWorkSource declaration
// below.)
struct Waiter {
  Waiter() {
    next = this;
    prev = this;
  }
  condition_variable cv;
  mutex mu;
  Waiter* next;
  Waiter* prev;
};

class ThreadWorkSource {
 public:
  ThreadWorkSource();

  ~ThreadWorkSource();

  Task EnqueueTask(Task t, bool is_blocking);

  Task PopBlockingTask();

  Task PopNonBlockingTask(int start_index, bool search_from_all_queue);

  void WaitForWork(int max_sleep_micros);

  int TaskQueueSize(bool is_blocking);

  int64 GetTracemeId();

  void SetTracemeId(int64 value);

  void SetWaiter(uint64 version, Waiter* waiter, mutex* mutex);

  int64 GetInflightTaskCount(bool is_blocking);

  void IncrementInflightTaskCount(bool is_blocking);

  void DecrementInflightTaskCount(bool is_blocking);

  unsigned NonBlockingWorkShardingFactor();

  std::string ToString();

 private:
  struct NonBlockingQueue {
    mutex queue_op_mu;
    char pad[128];
    Queue queue;
  };

  int32 non_blocking_work_sharding_factor_;
  Eigen::MaxSizeVector<NonBlockingQueue*> non_blocking_work_queues_;

  std::atomic<int64> blocking_inflight_;
  std::atomic<int64> non_blocking_inflight_;

  Queue blocking_work_queue_;
  mutex blocking_queue_op_mu_;
  char pad_[128];
  mutex waiters_mu_;
  Waiter queue_waiters_ TF_GUARDED_BY(waiters_mu_);
  std::atomic<int64> traceme_id_;

  mutex run_handler_waiter_mu_;
  uint64 version_ TF_GUARDED_BY(run_handler_waiter_mu_);
  mutex* sub_thread_pool_waiter_mu_ TF_GUARDED_BY(run_handler_waiter_mu_);
  Waiter* sub_thread_pool_waiter_ TF_GUARDED_BY(run_handler_waiter_mu_);
};
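
// The sketch below is not part of this interface; it only illustrates, under
// simplifying assumptions, the LIFO wait/notify pattern that the Waiter list
// above supports. Each waiting thread owns one Waiter, links it right after a
// sentinel head under a shared mutex, and sleeps on its own condition
// variable; a notifier wakes the most recently queued waiter first. The real
// implementation additionally bounds the wait (see WaitForWork's
// max_sleep_micros), which this sketch glosses over. All names below are
// illustrative:
//
//   void WaitOnWaiterList(Waiter* head, Waiter* w, mutex* list_mu) {
//     {
//       mutex_lock l(*list_mu);
//       w->next = head->next;   // Push w right after the sentinel head.
//       w->prev = head;
//       head->next->prev = w;
//       head->next = w;
//     }
//     mutex_lock l(w->mu);
//     w->cv.wait(l);            // Each waiter sleeps on its own cv.
//   }
//
//   void NotifyOneWaiter(Waiter* head, mutex* list_mu) {
//     Waiter* w;
//     {
//       mutex_lock l(*list_mu);
//       w = head->next;         // LIFO: most recently queued waiter.
//       if (w == head) return;  // List is empty.
//       head->next = w->next;   // Unlink w from the list.
//       w->next->prev = head;
//       w->next = w;
//       w->prev = w;
//     }
//     mutex_lock l(w->mu);
//     w->cv.notify_one();
//   }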

class RunHandlerThreadPool {
 public:
  struct PerThread {
    constexpr PerThread() : pool(nullptr), thread_id(-1) {}
    RunHandlerThreadPool* pool;  // Parent pool, or null for normal threads.
    int thread_id;               // Worker thread index in pool.
  };

  RunHandlerThreadPool(int num_blocking_threads, int num_non_blocking_threads,
                       Env* env, const ThreadOptions& thread_options,
                       const string& name,
                       Eigen::MaxSizeVector<mutex>* waiters_mu,
                       Eigen::MaxSizeVector<Waiter>* queue_waiters);

  ~RunHandlerThreadPool();

  void Start();

  void StartOneThreadForTesting();

  void AddWorkToQueue(ThreadWorkSource* tws, bool is_blocking,
                      std::function<void()> fn);

  // Set work queues from which the thread 'tid' can steal its work.
  // The request with start_request_idx will be attempted first. Other requests
  // will be attempted in FIFO order based on their arrival time.
  void SetThreadWorkSources(
      int tid, int start_request_idx, uint64 version,
      const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources);

  PerThread* GetPerThread();

  int CurrentThreadId() const;

  int NumThreads() const;

  int NumBlockingThreads() const;

  int NumNonBlockingThreads() const;

  void WorkerLoop(int thread_id, bool may_steal_blocking_work);

  // Search for tasks from requests in the range searching_range_start to
  // searching_range_end. If there are no tasks in the search range and
  // may_steal_blocking_work is true, then search from all requests.
  Task FindTask(
      int searching_range_start, int searching_range_end, int thread_id,
      int sub_thread_pool_id, int max_blocking_inflight,
      bool may_steal_blocking_work,
      const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources,
      bool* task_from_blocking_queue, ThreadWorkSource** tws);

  void WaitForWork(bool is_blocking, int thread_id,
                   int32 max_blocking_inflight);

  void WaitForWorkInSubThreadPool(bool is_blocking, int sub_thread_pool_id);

 private:
  struct ThreadData {
    ThreadData();
    mutex mu;
    uint64 new_version;
    condition_variable sources_not_empty;
    std::unique_ptr<Thread> thread;
    int current_index;
    std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>>
        new_thread_work_sources TF_GUARDED_BY(mu);

    uint64 current_version;
    // Should only be accessed by one thread.
    std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>>
        current_thread_work_sources;

    int sub_thread_pool_id;
  };

  const int num_threads_;
  const int num_blocking_threads_;
  const int num_non_blocking_threads_;
  Eigen::MaxSizeVector<ThreadData> thread_data_;
  internal::RunHandlerEnvironment env_;
  std::atomic<bool> cancelled_;
  string name_;
  Eigen::MaxSizeVector<mutex>* waiters_mu_;
  Eigen::MaxSizeVector<Waiter>* queue_waiters_;

  bool use_sub_thread_pool_;
  std::vector<int> num_threads_in_sub_thread_pool_;

  // Threads in each sub thread pool will search for tasks from the given
  // start_request_percentage to end_request_percentage of the request list in
  // a round-robin fashion.
  std::vector<double> sub_thread_pool_start_request_percentage_;
  std::vector<double> sub_thread_pool_end_request_percentage_;
};

}  // namespace internal

}  // end namespace tensorflow.

#endif  // TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_