/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_

#include <algorithm>
#include <cstddef>
#include <map>
#include <memory>
#include <queue>
#include <string>
#include <vector>

#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/platform/env.h"
#ifndef __ANDROID__
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#endif
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

// Note: there is a copy of this enum in eager/c_api.h; the two must be kept
// in sync.
enum ContextDevicePlacementPolicy {
  // Running operations with input tensors on the wrong device will fail.
  DEVICE_PLACEMENT_EXPLICIT = 0,
  // Copy the tensor to the right device but log a warning.
  DEVICE_PLACEMENT_WARN = 1,
  // Silently copy the tensor, which has a performance cost since the operation
  // will be blocked until the copy completes. This is the default policy.
  DEVICE_PLACEMENT_SILENT = 2,
  // Placement policy which silently copies int32 tensors but not other dtypes.
  DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
};

class RunMetadataListener {
 public:
  virtual ~RunMetadataListener() {}
  virtual void BeforeClearRunMetadata() = 0;
};

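// Example (sketch): constructing a context for single-process eager
// execution, roughly mirroring what the TFE C API does when it creates a
// context. The DeviceFactory/DeviceMgr/IntraProcessRendezvous calls below are
// illustrative; their exact signatures vary across TensorFlow versions.
//
//   SessionOptions opts;
//   std::vector<Device*> devices;
//   TF_CHECK_OK(DeviceFactory::AddDevices(
//       opts, "/job:localhost/replica:0/task:0", &devices));
//   std::unique_ptr<DeviceMgr> device_mgr(new DeviceMgr(devices));
//   Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr.get());
//   EagerContext ctx(opts, DEVICE_PLACEMENT_SILENT, /*async=*/false,
//                    std::move(device_mgr), rendezvous);
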
class EagerContext {
 public:
  // TODO: remove this constructor once we migrate all callers to the next one.
  EagerContext(const SessionOptions& opts,
               ContextDevicePlacementPolicy default_policy, bool async,
               std::unique_ptr<const DeviceMgr> device_mgr,
               Rendezvous* rendezvous);

  EagerContext(const SessionOptions& opts,
               ContextDevicePlacementPolicy default_policy, bool async,
               const DeviceMgr* device_mgr, bool device_mgr_owned,
               Rendezvous* rendezvous);

  ~EagerContext();

  // Returns the function library runtime for the given device.
  FunctionLibraryRuntime* func_lib(Device* d) const {
    return pflr_->GetFLR(d->name());
  }

  ProcessFunctionLibraryRuntime* pflr() const { return pflr_.get(); }

  // True if running in asynchronous mode.
  bool Async() const;

  EagerExecutor* Executor() { return &executor_; }

  std::function<void(std::function<void()>)>* runner() { return &runner_; }

  // Sets whether this thread should run in synchronous or asynchronous mode.
  Status SetAsyncForThread(bool async);

  // TODO(apassos) make this return a constant reference
  gtl::FlatMap<string, Device*, StringPieceHasher>* device_map() {
    return &devices_map_;
  }

  // TODO(apassos) make this return a constant reference
  std::vector<Device*>* devices() { return &devices_; }
  const std::vector<DeviceType>& prioritized_device_type_list() {
    return prioritized_device_type_list_;
  }

  // Clears the kernel caches.
  Status ClearCaches();

  // Sets the device placement policy for the current thread.
  void SetThreadLocalDevicePlacementPolicy(ContextDevicePlacementPolicy policy);

  // Returns the device placement policy for the current thread.
  ContextDevicePlacementPolicy GetDevicePlacementPolicy();

  Status AsyncWait() { return executor_.WaitForAllPendingNodes(); }

  Status GetStatus() { return executor_.status(); }

  void ClearAsyncError() { executor_.ClearError(); }

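  // Example (sketch): opting a thread into asynchronous execution and draining
  // queued work before reading results. This is a simplified pattern; the real
  // execute path also consults the executor status before enqueueing each op.
  //
  //   TF_CHECK_OK(ctx.SetAsyncForThread(true));
  //   // ... enqueue ops; each becomes an EagerNode owned by the executor ...
  //   Status s = ctx.AsyncWait();  // block until all pending nodes finish
  //   if (!s.ok()) {
  //     LOG(ERROR) << s;
  //     ctx.ClearAsyncError();  // reset the executor so new work can run
  //   }
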
  bool FindFunctionByName(const string& name);

  Status FindFunctionOpData(const string& name,
                            const tensorflow::OpRegistrationData** op_data);

  const FunctionDef* FindFunctionDef(const string& name);

  Status FindDeviceByName(const string& name, Device** result);

  Device* HostCPU() const { return devices_[0]; }

  GraphCollector* GetGraphCollector() { return &graph_collector_; }

  uint64 NextId() { return executor_.NextId(); }

  void ExecutorAdd(EagerNode* node) { executor_.Add(node); }

  Status AddFunctionDef(const FunctionDef& fdef);

  KernelAndDevice* GetCachedKernel(Fprint128 cache_key);

  void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel);

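  // Example (sketch): the lookup-or-build pattern used when executing an op.
  // How the Fprint128 key is derived (op name, attributes, input devices, ...)
  // is illustrative here, not the exact fingerprinting scheme.
  //
  //   Fprint128 cache_key = Fingerprint128(op_name);  // plus attrs, etc.
  //   KernelAndDevice* kernel = ctx.GetCachedKernel(cache_key);
  //   if (kernel == nullptr) {
  //     kernel = /* build and initialize a KernelAndDevice for the op */;
  //     ctx.AddKernelToCache(cache_key, kernel);
  //   }
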
  bool LogDevicePlacement() const { return log_device_placement_; }
  bool LogMemory() const { return log_memory_; }

  Rendezvous* GetRendezvous() const { return rendezvous_; }
  CollectiveExecutorMgrInterface* collective_executor_mgr() {
    return (collective_executor_mgr_ != nullptr)
               ? collective_executor_mgr_.get()
               : unowned_collective_executor_mgr_;
  }
  std::unique_ptr<CollectiveExecutor::Handle> GetCollectiveExecutorHandle() {
    return std::unique_ptr<CollectiveExecutor::Handle>(
        new CollectiveExecutor::Handle(
            collective_executor_mgr()->FindOrCreate(0), true /*inherit_ref*/));
  }

  const tensorflow::DeviceMgr* local_device_mgr() const {
    return (local_device_manager_ != nullptr) ? local_device_manager_.get()
                                              : local_unowned_device_manager_;
  }
  const tensorflow::DeviceMgr* remote_device_mgr() const {
    return remote_device_manager_.get();
  }

  // TODO(apassos) remove the need for this
  void ReleaseDeviceMgr() { local_device_manager_.release(); }

  // TODO(apassos) clean up RunMetadata storage.
  mutex* MetadataMu() LOCK_RETURNED(metadata_mu_) { return &metadata_mu_; }
  bool ShouldStoreStepStats() LOCKS_EXCLUDED(metadata_mu_);
  void SetShouldStoreStepStats(bool value);
  bool ShouldStoreGraphs() LOCKS_EXCLUDED(metadata_mu_);
  void SetShouldStoreGraphs(bool value);
  RunMetadata* RunMetadataProto() { return &run_metadata_; }
  void ClearRunMetadata() EXCLUSIVE_LOCKS_REQUIRED(metadata_mu_);

  Status RegisterRunMetadataListener(RunMetadataListener* listener)
      LOCKS_EXCLUDED(metadata_mu_);
  void ClearRunMetadataListener() LOCKS_EXCLUDED(metadata_mu_);

  void StartStep();
  void EndStep();
  ScopedStepContainer* StepContainer();

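  // Example (sketch): bracketing a logical step so per-step resources are kept
  // in the container returned by StepContainer(). The cleanup timing (when the
  // last active step ends) is the intended behavior, not a guarantee made by
  // this header.
  //
  //   ctx.StartStep();
  //   // ... execute the ops that make up the step ...
  //   ctx.EndStep();
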
  FunctionLibraryDefinition* FuncLibDef() { return &func_lib_def_; }

#ifndef __ANDROID__
  Status GetClientAndContextID(Device* device, eager::EagerClient** client,
                               uint64* context_id);

  // TODO(nareshmodi): Encapsulate remote state into a separate class/struct.
  //
  // Enables the eager context to communicate with remote devices.
  //
  // - server: A ServerInterface that exports the tensorflow.WorkerService.
  //   Note that this class expects the server to already have been started.
  // - remote_eager_workers: A cache from which we can get "EagerClient"s to
  //   communicate with remote eager services.
  // - remote_device_mgr: A DeviceMgr* which contains all remote devices
  //   (should contain no local devices).
  // - remote_contexts: A map from task name to remote context ID.
  Status InitializeRemote(
      std::unique_ptr<ServerInterface> server,
      std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
      std::unique_ptr<DeviceMgr> remote_device_manager,
      const gtl::FlatMap<string, uint64>& remote_contexts, Rendezvous* r,
      DeviceMgr* local_device_mgr, int keep_alive_secs);

  bool HasActiveRemoteContext(uint64 context_id) {
    return active_remote_contexts_.find(context_id) !=
           active_remote_contexts_.end();
  }

  Status StoreCollectiveOpsServer(
      std::unique_ptr<ServerInterface> server, DeviceMgr* device_mgr,
      CollectiveExecutorMgrInterface* rpc_collective_executor_mgr);
#endif

  // If true, then tensors should be shipped across processes via the
  // EagerService.SendTensor RPC. If false, _Send/_Recv ops should be used
  // instead (which in turn use WorkerService.RecvTensor RPCs).
  bool UseSendTensorRPC() { return use_send_tensor_rpc_; }
  bool PinSmallOpsToCPU() { return pin_small_ops_to_cpu_; }

  tensorflow::Env* TFEnv() const { return env_; }

  // All child threads are reset() when the EagerContext is destroyed.
  void AddChildThread(std::unique_ptr<Thread> thread);

 private:
  void InitDeviceMapAndAsync();
  Status MaybeRegisterFunctionRemotely(const FunctionDef& fdef);

  const ContextDevicePlacementPolicy policy_;

  // Note: we cannot use C++11 thread_local here as there is no concept of a
  // thread-local-object-local variable in C++11.
  mutex policy_map_mu_;
  std::unordered_map<std::thread::id, ContextDevicePlacementPolicy>
      thread_local_policies_ GUARDED_BY(policy_map_mu_);

  // Only one of the below is set.
  std::unique_ptr<const DeviceMgr> local_device_manager_;
  const DeviceMgr* local_unowned_device_manager_;
  std::unique_ptr<DeviceMgr> remote_device_manager_;

  // Devices owned by device_manager
  std::vector<Device*> devices_;
  std::vector<DeviceType> prioritized_device_type_list_;
  // None of the devices in this map are owned.
  gtl::FlatMap<string, Device*, StringPieceHasher> devices_map_;
  Rendezvous* rendezvous_;

  mutex functions_mu_;
  FunctionLibraryDefinition func_lib_def_ GUARDED_BY(functions_mu_){
      OpRegistry::Global(), {}};

  std::unique_ptr<thread::ThreadPool> thread_pool_;

  // One FunctionLibraryRuntime per device.
  // func_libs[i] is the FunctionLibraryRuntime corresponding to
  // session->devices[i].
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;

  std::function<void(std::function<void()>)> runner_;

  mutex cache_mu_;
  std::unordered_map<Fprint128, KernelAndDevice*, Fprint128Hasher>
      kernel_cache_ GUARDED_BY(cache_mu_);

  // Whether we should compute RunMetadata.
  std::atomic<bool> should_store_step_stats_{false};
  std::atomic<bool> should_store_graphs_{false};
  mutex metadata_mu_;
  RunMetadata run_metadata_ GUARDED_BY(metadata_mu_);
  RunMetadataListener* metadata_listener_ GUARDED_BY(metadata_mu_) = nullptr;
  GraphCollector graph_collector_;
  const bool log_device_placement_;
  // EagerExecutor for async execution.
  EagerExecutor executor_;

  // Information related to step containers.
  std::atomic<int> num_active_steps_;
  std::unique_ptr<ScopedStepContainer> step_container_ GUARDED_BY(metadata_mu_);

  // True if the default value for execution mode is async. Note that this
  // value can be overridden per thread based on `thread_local_async`
  // overrides.
  const bool async_default_;
  mutable mutex async_map_mu_;
  std::unordered_map<std::thread::id, bool> thread_local_async_
      GUARDED_BY(async_map_mu_);

  const bool log_memory_;

  Env* const env_;

  std::unique_ptr<CollectiveExecutorMgrInterface> collective_executor_mgr_;
  CollectiveExecutorMgrInterface* unowned_collective_executor_mgr_ = nullptr;

#ifndef __ANDROID__
  void CloseRemoteContexts();

  // server_ is released when the context is destroyed, so it cannot be marked
  // const (even though conceptually it should be).
  std::unique_ptr<ServerInterface> server_;
  std::unique_ptr<eager::EagerClientCache> remote_eager_workers_;

  mutex remote_state_mu_;

  gtl::FlatMap<string, uint64> remote_contexts_;
  gtl::FlatSet<uint64> active_remote_contexts_;
  gtl::FlatMap<Device*, std::pair<eager::EagerClient*, uint64>>
      device_to_client_cache_;

  int keep_alive_secs_ GUARDED_BY(remote_state_mu_);
  std::atomic<int> sleep_for_secs_;

  std::unique_ptr<Thread> keep_alive_thread_;
  mutex keep_alive_thread_shutdown_mu_;
  condition_variable keep_alive_thread_cv_;
  bool shutting_down_ GUARDED_BY(keep_alive_thread_shutdown_mu_) = false;
#endif

  bool use_send_tensor_rpc_;
  const bool pin_small_ops_to_cpu_;
  std::vector<std::unique_ptr<tensorflow::Thread>> child_threads_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_