1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_ 16 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_ 17 18 #include <algorithm> 19 #include <cstddef> 20 #include <map> 21 #include <memory> 22 #include <queue> 23 #include <string> 24 #include <vector> 25 26 // clang-format off 27 // Required for IS_MOBILE_PLATFORM 28 #include "tensorflow/core/common_runtime/process_function_library_runtime.h" 29 #include "tensorflow/core/lib/core/status.h" 30 #include "tensorflow/core/platform/platform.h" 31 // clang-format on 32 33 #include "tensorflow/core/common_runtime/device_factory.h" 34 #include "tensorflow/core/common_runtime/device_mgr.h" 35 #include "tensorflow/core/common_runtime/eager/eager_executor.h" 36 #include "tensorflow/core/common_runtime/eager/kernel_and_device.h" 37 #include "tensorflow/core/common_runtime/function.h" 38 #include "tensorflow/core/common_runtime/rendezvous_mgr.h" 39 #include "tensorflow/core/example/example.pb.h" 40 #include "tensorflow/core/framework/function.h" 41 #include "tensorflow/core/platform/env.h" 42 #include "tensorflow/core/util/device_name_utils.h" 43 #if !defined(IS_MOBILE_PLATFORM) 44 #include "tensorflow/core/distributed_runtime/eager/eager_client.h" 45 #include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h" 46 #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" 47 #include "tensorflow/core/distributed_runtime/server_lib.h" 48 #include "tensorflow/core/distributed_runtime/worker_cache.h" 49 #include "tensorflow/core/distributed_runtime/worker_env.h" 50 #endif // !IS_MOBILE_PLATFORM 51 #include "tensorflow/core/framework/collective.h" 52 #include "tensorflow/core/framework/log_memory.h" 53 #include "tensorflow/core/framework/rendezvous.h" 54 #include "tensorflow/core/lib/core/stringpiece.h" 55 #include "tensorflow/core/lib/core/threadpool.h" 56 #include "tensorflow/core/lib/gtl/flatmap.h" 57 #include "tensorflow/core/lib/gtl/flatset.h" 58 #include "tensorflow/core/lib/gtl/inlined_vector.h" 59 #include "tensorflow/core/lib/gtl/map_util.h" 60 61 #include "tensorflow/core/platform/fingerprint.h" 62 #include "tensorflow/core/platform/mutex.h" 63 #include "tensorflow/core/platform/thread_annotations.h" 64 #include "tensorflow/core/public/session_options.h" 65 #include "tensorflow/core/public/version.h" 66 67 namespace tensorflow { 68 69 namespace eager { 70 // We need this forward declaration because we have circular dependency: 71 // Context -> RemoteMgr -> TensorHandle -> Context. 72 // TODO(fishx): Remove this once we remove Context dependency in TensorHandle. 73 class RemoteMgr; 74 } // namespace eager 75 76 // LINT.IfChange 77 // Note: Keep in sync with exported copy of enum in eager/c_api.h. 78 enum ContextDevicePlacementPolicy { 79 // Running operations with input tensors on the wrong device will fail. 80 DEVICE_PLACEMENT_EXPLICIT = 0, 81 // Copy the tensor to the right device but log a warning. 82 DEVICE_PLACEMENT_WARN = 1, 83 // Silently copy the tensor, which has a performance cost since the operation 84 // will be blocked till the copy completes. This is the default policy. 85 DEVICE_PLACEMENT_SILENT = 2, 86 // Placement policy which silently copies int32 tensors but not other dtypes. 87 DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3, 88 }; 89 // LINT.ThenChange(//tensorflow/c/eager/c_api.h) 90 91 // LINT.IfChange 92 // Note: Keep in sync with exported copy of enum in eager/c_api_experimental.h. 93 enum ContextMirroringPolicy { 94 // Do not maintain mirrors in a TensorHandle, instead make new TensorHandle 95 // copies with their own lifetime. 96 MIRRORING_NONE = 0, 97 // Mirroring any remote tensor handles, associating them with the lifetime of 98 // the local TensorHandle. 99 MIRRORING_ALL = 1, 100 }; 101 // LINT.ThenChange(//tensorflow/c/eager/c_api_experimental.h) 102 103 class RunMetadataListener { 104 public: ~RunMetadataListener()105 virtual ~RunMetadataListener() {} 106 virtual void BeforeClearRunMetadata() = 0; 107 }; 108 109 class EagerContext : public core::RefCounted { 110 public: 111 static const uint64 kInvalidContextId = 0; 112 NewContextId()113 static uint64 NewContextId() { 114 uint64 context_id = random::New64(); 115 while (context_id == kInvalidContextId) { 116 context_id = random::New64(); 117 } 118 return context_id; 119 } 120 121 EagerContext(const SessionOptions& opts, 122 ContextDevicePlacementPolicy default_device_placement_policy, 123 ContextMirroringPolicy default_mirroring_policy, bool async, 124 const bool lazy_copy_function_remote_inputs, 125 const DeviceMgr* device_mgr, bool device_mgr_owned, 126 Rendezvous* rendezvous, 127 const CustomKernelCreator* custom_kernel_creator, 128 DistributedFunctionLibraryRuntime* cluster_flr = nullptr); 129 130 ~EagerContext() override; 131 132 // Returns the function library runtime for the given device. func_lib(const Device * d)133 FunctionLibraryRuntime* func_lib(const Device* d) const { 134 return pflr_->GetFLR(d->name()); 135 } 136 pflr()137 ProcessFunctionLibraryRuntime* pflr() const { return pflr_.get(); } 138 runner()139 std::function<void(std::function<void()>)>* runner() { return &runner_; } 140 141 // Specify a executor for this thread. 142 void SetExecutorForThread(EagerExecutor* executor); 143 prioritized_device_type_list()144 const std::vector<DeviceType>& prioritized_device_type_list() const { 145 return prioritized_device_type_list_; 146 } 147 148 // Clear pending nodes in thread executors and kernel caches. 149 void ClearCachesAndThreadExecutors(); 150 // Clear pending nodes in default executor and kernel caches. 151 void ClearCachesAndDefaultExecutor(); 152 153 // Sets the device placement policy for the current thread. 154 void SetThreadLocalDevicePlacementPolicy(ContextDevicePlacementPolicy policy); 155 156 // Returns the device placement policy for the current thread. 157 ContextDevicePlacementPolicy GetDevicePlacementPolicy() const; 158 159 // Select an appropriate device for an operation. 160 // 161 // Given the preferred device for the operation, and the list of devices the 162 // operation supports, finds the best suitable device for the operation in 163 // this context. 164 // 165 // The preferred device is specified as a `ParsedName` containing the elements 166 // (details) that the resulting device should match. If there are no such 167 // devices, and the context currently allows soft device placement, a suitable 168 // device not matching `preferred` will be chosen. 169 // 170 // The `dtype` parameter specifies the operation's result data type, if 171 // known. Setting it to DT_INVALID will make this method not use the data type 172 // for its decisions. 173 // 174 // The chosen device is stored in the `device` argument. The argument is not 175 // modified unless this method returns `Status::OK()`. 176 Status SelectDevice(DeviceNameUtils::ParsedName preferred, 177 const PrioritizedDeviceTypeVector& supported, 178 const DataType dtype, Device** device) const; 179 180 // Sets the implicit copy policy for the current thread. 181 void SetThreadLocalMirroringPolicy(ContextMirroringPolicy); 182 183 // Returns the implicit copy policy for the current thread. 184 ContextMirroringPolicy GetMirroringPolicy() const; 185 186 bool MirrorTensors() const; 187 188 bool LazyCopyFunctionRemoteInputs() const; 189 190 bool FindFunctionByName(const string& name) const; 191 192 Status FindFunctionOpData(const string& name, 193 const tensorflow::OpRegistrationData** op_data); 194 195 const FunctionDef* FindFunctionDef(const string& name); 196 HostCPU()197 Device* HostCPU() const { return host_cpu_device_; } CanonicalDevice(Device * d)198 Device* CanonicalDevice(Device* d) const { 199 return HostCPU() == d ? nullptr : d; 200 } 201 GetGraphCollector()202 GraphCollector* GetGraphCollector() { return &graph_collector_; } 203 204 EagerExecutor& Executor(); 205 206 // Add the given `fdef` to the local FunctionLibraryDefinition. And add an 207 // entry to the KernelAndDevice cache for it if it's not exist. 208 Status AddFunctionDef(const FunctionDef& fdef); 209 // `library` contains all FunctionDefs and GradientDefs to expand `fdef`. Add 210 // it to the local FunctionLibraryDefinition as well, but no need to add it 211 // to the KernelAndDevice cache since they won't be executed as 212 // KernelAndDevices. 213 Status AddFunctionDef(const FunctionDef& fdef, 214 const FunctionDefLibrary& library, 215 const bool add_to_local_only = false); 216 217 Status RemoveFunction(const string& func); 218 219 core::RefCountPtr<KernelAndDevice> GetCachedKernel(Fprint128 cache_key); 220 221 void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel); 222 LogDevicePlacement()223 bool LogDevicePlacement() const { return log_device_placement_; } AllowSoftPlacement()224 bool AllowSoftPlacement() const { return allow_soft_placement_; } LogMemory()225 bool LogMemory() const { return log_memory_; } 226 GetRendezvous()227 Rendezvous* GetRendezvous() const { return rendezvous_; } CreateRendezvous(const int64 step_id)228 Rendezvous* CreateRendezvous(const int64 step_id) const { 229 if (rendezvous_creator_ != nullptr) { 230 return rendezvous_creator_(step_id); 231 } 232 233 #if !defined(IS_MOBILE_PLATFORM) 234 if (worker_env_ != nullptr && worker_env_->rendezvous_mgr != nullptr) { 235 auto* remote_r = worker_env_->rendezvous_mgr->Find(step_id); 236 remote_r->Initialize(worker_session_.get()).IgnoreError(); 237 return remote_r; 238 } 239 #endif 240 241 if (remote_device_mgr() == nullptr) { 242 return new IntraProcessRendezvous(local_device_mgr()); 243 } 244 245 return nullptr; 246 } 247 collective_executor_mgr()248 CollectiveExecutorMgrInterface* collective_executor_mgr() { 249 return collective_executor_mgr_.Get(); 250 } GetCollectiveExecutorHandle()251 std::unique_ptr<CollectiveExecutor::Handle> GetCollectiveExecutorHandle() { 252 return std::unique_ptr<CollectiveExecutor::Handle>( 253 new CollectiveExecutor::Handle( 254 collective_executor_mgr()->FindOrCreate(0), true /*inherit_ref*/)); 255 } 256 local_device_mgr()257 const tensorflow::DeviceMgr* local_device_mgr() const { 258 return local_device_manager_.Get(); 259 } remote_device_mgr()260 const tensorflow::DynamicDeviceMgr* remote_device_mgr() const { 261 return remote_device_manager_.Get(); 262 } 263 GetOwnedRemoteDeviceMgr()264 tensorflow::DynamicDeviceMgr* GetOwnedRemoteDeviceMgr() { 265 return remote_device_manager_.GetOwned(); 266 } 267 268 // TODO(apassos) clean up RunMetadata storage. MetadataMu()269 mutex* MetadataMu() LOCK_RETURNED(metadata_mu_) { return &metadata_mu_; } 270 bool ShouldStoreGraphs() LOCKS_EXCLUDED(metadata_mu_); 271 void SetShouldStoreGraphs(bool value); RunMetadataProto()272 RunMetadata* RunMetadataProto() { return &run_metadata_; } 273 void ClearRunMetadata() EXCLUSIVE_LOCKS_REQUIRED(metadata_mu_); 274 275 void ListDevices(std::vector<tensorflow::DeviceAttributes>* devices); 276 277 void StartStep(); 278 void EndStep(); 279 ScopedStepContainer* StepContainer(); 280 FuncLibDef()281 FunctionLibraryDefinition* FuncLibDef() { return &func_lib_def_; } 282 283 #if !defined(IS_MOBILE_PLATFORM) 284 // Assign the EagerClient pointer to `client` based on the given device / task 285 // name, and increment the refcount of the client. The reference ownership is 286 // transferred to the caller, and the unref should automatically happen when 287 // destructing the RefCountPtr object at the caller's side. 288 // `client` must not be initialized or holding a reference of another object 289 // before calling this method. 290 Status GetClient(Device* device, 291 core::RefCountPtr<eager::EagerClient>* client); 292 Status GetClient(const DeviceNameUtils::ParsedName& device_name, 293 core::RefCountPtr<eager::EagerClient>* client); 294 Status GetClient(const string& remote_task, 295 core::RefCountPtr<eager::EagerClient>* client); 296 297 uint64 GetContextId(); 298 uint64 GetContextViewId(); 299 void IncrementContextViewId(); 300 301 // TODO(nareshmodi): Encapsulate remote state into a separate 302 // class/struct. 303 // 304 // Enables the eager context to communicate with remote devices. When 305 // initializing with this method, this context will be the master context, 306 // which will kill all its slaves in shutdown. 307 // 308 // - server: A ServerInterface that exports the tensorflow.WorkerService. 309 // Note that this class expects the server to already have been started. 310 // - remote_eager_workers: A cache from which we can get "EagerClient"s to 311 // communicate with remote eager services. 312 // - remote_device_mgr: A DeviceMgr* which contains all remote devices 313 // (should contain no local devices). 314 // - remote_contexts: A vector containing task names. 315 Status InitializeRemoteMaster( 316 std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env, 317 std::shared_ptr<WorkerSession> worker_session, 318 std::unique_ptr<eager::EagerClientCache> remote_eager_workers, 319 std::unique_ptr<DynamicDeviceMgr> remote_device_manager, 320 const std::vector<string>& remote_contexts, uint64 context_id, 321 Rendezvous* r, DeviceMgr* local_device_mgr, int keep_alive_secs, 322 DistributedFunctionLibraryRuntime* cluster_flr, 323 std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>> 324 remote_mgr); 325 326 // Update an existing master context with a new set of remote workers (i.e., a 327 // new "view" of cluster membership. Similar to InitializeRemoteMaster but 328 // this will keep the current context_id and increment a context_view_id, will 329 // keep the current resource manager so that resources from the previous view 330 // can still be accessed, and will automatically register existing functions 331 // if there are newly added hosts. 332 Status UpdateRemoteMaster( 333 WorkerEnv* worker_env, 334 std::unique_ptr<eager::EagerClientCache> remote_eager_workers, 335 const std::vector<string>& add_remote_contexts, 336 const std::vector<string>& remove_remote_contexts, uint64 context_id, 337 Rendezvous* r, DeviceMgr* local_device_mgr, int keep_alive_secs, 338 DistributedFunctionLibraryRuntime* cluster_flr); 339 340 // Similar with InitializeRemoteMaster but this context will not kill remote 341 // contexts in shutdown. 342 Status InitializeRemoteWorker( 343 std::unique_ptr<eager::EagerClientCache> remote_eager_workers, 344 DynamicDeviceMgr* remote_device_mgr, 345 const std::vector<string>& remote_contexts, uint64 context_id, 346 uint64 context_view_id, 347 std::function<Rendezvous*(const int64)> rendezvous_creator, 348 DistributedFunctionLibraryRuntime* cluster_flr, 349 std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>> 350 remote_mgr, 351 std::function<void()> resource_deallocator); 352 353 // Similar with InitializeRemoteWorker but will reuse existing context and 354 // increment context_view_id. 355 Status UpdateRemoteWorker( 356 const DeviceMgr* worker_session_device_mgr, 357 std::unique_ptr<eager::EagerClientCache> remote_eager_workers, 358 DynamicDeviceMgr* remote_device_mgr, 359 const std::vector<string>& remote_contexts, uint64 context_id, 360 DistributedFunctionLibraryRuntime* cluster_flr); 361 362 Status StoreCollectiveOpsServer( 363 std::unique_ptr<ServerInterface> new_server, DeviceMgr* device_mgr, 364 CollectiveExecutorMgrInterface* rpc_collective_executor_mgr); 365 366 // For the specified remote worker, preprocess and set its device filters. 367 Status SetRemoteDeviceFilters(const string& remote_worker, 368 const std::vector<string>& device_filters); 369 370 // For the specified remote worker, apply the stored device filters to the 371 // list of device attributes following these rules: 372 // (1) if the remote worker does not have device filters, all devices are 373 // visible to the worker; 374 // (2) if the device is on the remote worker, then it is visible; 375 // (3) if the device matches at least one device filter, then it is visible. 376 // The result is saved as a boolean vector of the same length (i.e., 377 // filtered_device_mask) indicating whether each of the devices is visible to 378 // the remote worker. 379 void FilterDevicesForRemoteWorkers( 380 const string& remote_worker, 381 const protobuf::RepeatedPtrField<DeviceAttributes>& device_attrs, 382 std::vector<bool>* filtered_device_mask); 383 384 // TODO(fishx): Remove the custom deleter once we remove forward declaration. 385 const std::unique_ptr<eager::RemoteMgr, 386 std::function<void(eager::RemoteMgr*)>>& RemoteMgr()387 RemoteMgr() { 388 return remote_mgr_; 389 } 390 391 // If true, then tensors should be shipped across processes via the 392 // EagerService.Enqueue(SendTensorOp). If false, _Send/_Recv ops should be 393 // used instead (which in-turn use WorkerService.RecvTensor RPCs). UseSendTensorRPC()394 bool UseSendTensorRPC() { return use_send_tensor_rpc_; } 395 GetServer()396 tensorflow::ServerInterface* GetServer() { return server_.get(); } 397 398 #endif // IS_MOBILE_PLATFORM 399 400 // Closes remote eager contexts, waits for all RPCs to finish, and 401 // destroys the EagerClientCache. No RPCs can be made through this context 402 // after this method has been called. 403 // This method exists to aid a clean shutdown. It causes all RPCs to finish 404 // and remote TensorHandles to release their references to this context. 405 // To avoid deadlocks, this method must not be called on the thread 406 // processing RPCs because it makes RPCs and waits for their completion. 407 // 408 // On mobile, it just cleans the caches. 409 void WaitForAndCloseRemoteContexts(); 410 PinSmallOpsToCPU()411 bool PinSmallOpsToCPU() { return pin_small_ops_to_cpu_; } 412 TFEnv()413 tensorflow::Env* TFEnv() const { return env_; } 414 415 std::vector<const FunctionDef*> ListRegisteredFunctions(); 416 417 Status FindDeviceFromName(const char* device_name, Device** device) const; 418 419 bool OnSameTask(const Device* first, const Device* second) const; 420 // Gets the CPU device on the task of device. 421 Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const; 422 423 private: 424 void InitPrioritizedDeviceTypeList(); 425 Status MaybeRegisterFunctionRemotely(const FunctionDef& fdef); 426 Status RegisterExistingFunctionsOnRemoteWorkers( 427 const std::vector<const FunctionDef*>& function_defs, 428 const std::vector<string>& remote_workers); 429 430 void ResetPFLR(const DeviceMgr* device_mgr, Env* env, 431 const ConfigProto* config, int graph_def_version, 432 const FunctionLibraryDefinition* lib_def, 433 const OptimizerOptions& optimizer_options, 434 thread::ThreadPool* thread_pool = nullptr, 435 DistributedFunctionLibraryRuntime* cluster_flr = nullptr, 436 const CustomKernelCreator* custom_kernel_creator = nullptr); 437 438 void ResetClusterFLR(DistributedFunctionLibraryRuntime* cluster_flr); 439 440 template <typename T> 441 struct OwnedOrUnownedHelper { 442 public: OwnedOrUnownedHelperOwnedOrUnownedHelper443 OwnedOrUnownedHelper() {} 444 explicit OwnedOrUnownedHelper(T* object, const bool owned = false) { 445 Reset(object, owned); 446 } 447 ResetOwnedOrUnownedHelper448 void Reset(std::unique_ptr<T> object) { 449 owned_object = std::move(object); 450 unowned_object_ptr = nullptr; 451 } 452 453 void Reset(T* object, const bool owned = false) { 454 if (owned) { 455 owned_object.reset(object); 456 unowned_object_ptr = nullptr; 457 } else { 458 owned_object.reset(nullptr); 459 unowned_object_ptr = object; 460 } 461 } 462 OwnedOwnedOrUnownedHelper463 bool Owned() const { return owned_object != nullptr; } 464 GetOwnedOwnedOrUnownedHelper465 T* GetOwned() const { return owned_object.get(); } GetOwnedOrUnownedHelper466 T* Get() const { 467 return owned_object ? owned_object.get() : unowned_object_ptr; 468 } 469 470 std::unique_ptr<T> owned_object = nullptr; 471 T* unowned_object_ptr = nullptr; 472 }; 473 474 const ContextDevicePlacementPolicy default_device_placement_policy_; 475 const ContextMirroringPolicy default_mirroring_policy_; 476 477 // Note: we cannot use C++11 thread_local here as there is no concept of a 478 // thread-local-object-local variable in C++11. 479 mutable mutex policy_map_mu_; 480 std::unordered_map<std::thread::id, ContextDevicePlacementPolicy> 481 device_placement_policy_ GUARDED_BY(policy_map_mu_); 482 std::unordered_map<std::thread::id, ContextMirroringPolicy> mirroring_policy_ 483 GUARDED_BY(policy_map_mu_); 484 485 OwnedOrUnownedHelper<const DeviceMgr> local_device_manager_; 486 487 // Unowned DynamicDeviceMgr is set on remote worker to allow running 488 // multi-device function on remote worker. 489 OwnedOrUnownedHelper<DynamicDeviceMgr> remote_device_manager_; 490 491 Device* host_cpu_device_; // Owned by device_manager 492 std::vector<DeviceType> prioritized_device_type_list_; 493 Rendezvous* rendezvous_; 494 std::function<Rendezvous*(const int64)> rendezvous_creator_; 495 Global()496 FunctionLibraryDefinition func_lib_def_{OpRegistry::Global(), {}}; 497 498 std::unique_ptr<thread::ThreadPool> thread_pool_; 499 500 const CustomKernelCreator* const custom_kernel_creator_; 501 502 // EagerContext owns the DistributedFunctionLibraryRuntime( 503 // EagerClusterFunctionLibraryRuntime) if using EagerService for remote 504 // function execution (lazy_copy_function_remote_inputs_=true). 505 OwnedOrUnownedHelper<DistributedFunctionLibraryRuntime> cluster_flr_; 506 // One FunctionLibraryRuntime per device. 507 // func_libs[i] is the FunctionLibraryRuntime corresponding to 508 // session->devices[i]. 509 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_; 510 511 std::function<void(std::function<void()>)> runner_; 512 513 mutex cache_mu_; 514 struct RegisteredFunction : public core::RefCounted { ~RegisteredFunctionRegisteredFunction515 ~RegisteredFunction() override {} 516 517 std::unique_ptr<std::vector<Fprint128>> cached_kernel_keys; 518 }; 519 std::unordered_map<Fprint128, core::RefCountPtr<KernelAndDevice>, 520 Fprint128Hasher> 521 kernel_cache_ GUARDED_BY(cache_mu_); 522 std::unordered_map<string, RegisteredFunction*> registered_functions_ 523 GUARDED_BY(cache_mu_); 524 525 // Whether we should compute RunMetadata. 526 std::atomic<bool> should_store_graphs_{false}; 527 mutex metadata_mu_; 528 RunMetadata run_metadata_ GUARDED_BY(metadata_mu_); 529 GraphCollector graph_collector_; 530 // TODO(fishx): Allow update following two bool after context creation. 531 const bool log_device_placement_; 532 const bool allow_soft_placement_; 533 534 // Information related to step containers. 535 std::atomic<int> num_active_steps_; 536 std::unique_ptr<ScopedStepContainer> step_container_ GUARDED_BY(metadata_mu_); 537 538 EagerExecutor default_executor_; 539 mutable mutex executor_map_mu_; 540 // Not owned. 541 std::unordered_map<std::thread::id, EagerExecutor*> thread_local_executor_ 542 GUARDED_BY(executor_map_mu_); 543 544 const bool log_memory_; 545 546 Env* const env_; 547 548 OwnedOrUnownedHelper<CollectiveExecutorMgrInterface> collective_executor_mgr_; 549 550 #if !defined(IS_MOBILE_PLATFORM) 551 void CloseAndClearAllRemoteContexts(); 552 void CloseRemoteContexts(const std::vector<string>& remote_contexts, 553 uint64 context_id, uint64 context_view_id); 554 555 Status SetMasterContextState( 556 std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env, 557 std::shared_ptr<WorkerSession> worker_session, 558 std::unique_ptr<eager::EagerClientCache> remote_eager_workers, 559 std::unique_ptr<DynamicDeviceMgr> remote_device_manager, 560 uint64 context_id, uint64 context_view_id, Rendezvous* r, 561 DeviceMgr* local_device_mgr, int keep_alive_secs, 562 DistributedFunctionLibraryRuntime* cluster_flr, 563 std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>> 564 remote_mgr); 565 566 // The server_ is not const since we release it when the context is destroyed. 567 // Therefore the server_ object is not marked as const (even though it should 568 // be). 569 std::unique_ptr<ServerInterface> server_; 570 WorkerEnv* worker_env_ = nullptr; 571 std::shared_ptr<WorkerSession> worker_session_; 572 std::unique_ptr<eager::EagerClientCache> remote_eager_workers_; 573 574 mutex remote_state_mu_; 575 576 uint64 context_id_ GUARDED_BY(remote_state_mu_); 577 // The view id of an eager context should be set to 0 when context is created, 578 // and continously incremented when context with the same context_id gets 579 // updated. The view id should be consistent between master and workers. 580 uint64 context_view_id_ GUARDED_BY(remote_state_mu_); 581 std::vector<string> remote_contexts_; 582 583 int keep_alive_secs_ GUARDED_BY(remote_state_mu_); 584 std::atomic<int> sleep_for_secs_; 585 586 std::unique_ptr<Thread> keep_alive_thread_; 587 mutex keep_alive_thread_shutdown_mu_; 588 condition_variable keep_alive_thread_cv_; 589 bool shutting_down_ GUARDED_BY(keep_alive_thread_shutdown_mu_) = false; 590 591 std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>> 592 remote_mgr_; 593 bool is_master_ GUARDED_BY(remote_state_mu_); 594 595 // Maps from a remote worker to a list of parsed device filters. 596 std::unordered_map<string, std::vector<DeviceNameUtils::ParsedName>> 597 cluster_device_filters_ GUARDED_BY(remote_state_mu_); 598 599 #endif // IS_MOBILE_PLATFORM 600 601 // For a multi device function, the target device of each input is unknown 602 // until the function is instantiated on the default function device. 603 // If false, eagerly copy all remote inputs to the default function device; 604 // if true, lazily copy remote inputs to their target devices to avoid 605 // redundant copies. 606 bool lazy_copy_function_remote_inputs_ = false; 607 bool use_send_tensor_rpc_; 608 const bool pin_small_ops_to_cpu_; 609 610 // Function that will be invoked in destructor to deallocate resources related 611 // to this context. 612 std::function<void()> resource_deallocator_ = nullptr; 613 }; 614 615 } // namespace tensorflow 616 617 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_ 618