• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_
17 
18 #include <algorithm>
19 #include <cstddef>
20 #include <map>
21 #include <memory>
22 #include <queue>
23 #include <string>
24 #include <vector>
25 
26 // clang-format off
27 // Required for IS_MOBILE_PLATFORM
28 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
29 #include "tensorflow/core/lib/core/status.h"
30 #include "tensorflow/core/platform/platform.h"
31 // clang-format on
32 
33 #include "tensorflow/core/common_runtime/device_factory.h"
34 #include "tensorflow/core/common_runtime/device_mgr.h"
35 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
36 #include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
37 #include "tensorflow/core/common_runtime/function.h"
38 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
39 #include "tensorflow/core/example/example.pb.h"
40 #include "tensorflow/core/framework/function.h"
41 #include "tensorflow/core/platform/env.h"
42 #include "tensorflow/core/util/device_name_utils.h"
43 #if !defined(IS_MOBILE_PLATFORM)
44 #include "tensorflow/core/distributed_runtime/eager/eager_client.h"
45 #include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h"
46 #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
47 #include "tensorflow/core/distributed_runtime/server_lib.h"
48 #include "tensorflow/core/distributed_runtime/worker_cache.h"
49 #include "tensorflow/core/distributed_runtime/worker_env.h"
50 #endif  // !IS_MOBILE_PLATFORM
51 #include "tensorflow/core/framework/collective.h"
52 #include "tensorflow/core/framework/log_memory.h"
53 #include "tensorflow/core/framework/rendezvous.h"
54 #include "tensorflow/core/lib/core/stringpiece.h"
55 #include "tensorflow/core/lib/core/threadpool.h"
56 #include "tensorflow/core/lib/gtl/flatmap.h"
57 #include "tensorflow/core/lib/gtl/flatset.h"
58 #include "tensorflow/core/lib/gtl/inlined_vector.h"
59 #include "tensorflow/core/lib/gtl/map_util.h"
60 
61 #include "tensorflow/core/platform/fingerprint.h"
62 #include "tensorflow/core/platform/mutex.h"
63 #include "tensorflow/core/platform/thread_annotations.h"
64 #include "tensorflow/core/public/session_options.h"
65 #include "tensorflow/core/public/version.h"
66 
67 namespace tensorflow {
68 
namespace eager {
// Forward-declared (rather than included) to break the include cycle
// Context -> RemoteMgr -> TensorHandle -> Context.
// TODO(fishx): Remove this once we remove Context dependency in TensorHandle.
class RemoteMgr;
}  // namespace eager
75 
76 // LINT.IfChange
77 // Note: Keep in sync with exported copy of enum in eager/c_api.h.
enum ContextDevicePlacementPolicy {
  // Fail any operation whose input tensors are on the wrong device.
  DEVICE_PLACEMENT_EXPLICIT = 0,
  // Copy mismatched input tensors to the right device, logging a warning.
  DEVICE_PLACEMENT_WARN = 1,
  // Copy mismatched input tensors without logging. This is the default
  // policy; it has a performance cost because the operation is blocked until
  // the copy completes.
  DEVICE_PLACEMENT_SILENT = 2,
  // Silently copy int32 tensors only; tensors of other dtypes are not copied.
  DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
};
89 // LINT.ThenChange(//tensorflow/c/eager/c_api.h)
90 
91 // LINT.IfChange
92 // Note: Keep in sync with exported copy of enum in eager/c_api_experimental.h.
enum ContextMirroringPolicy {
  // Keep no mirrors inside a TensorHandle; each copy becomes a new
  // TensorHandle with an independent lifetime.
  MIRRORING_NONE = 0,
  // Mirror every remote tensor handle and tie the mirror's lifetime to the
  // local TensorHandle.
  MIRRORING_ALL = 1,
};
101 // LINT.ThenChange(//tensorflow/c/eager/c_api_experimental.h)
102 
// Abstract listener interface. Implementations override
// BeforeClearRunMetadata(), which (as its name indicates) is intended to be
// called before run metadata is cleared; the code that invokes it is not part
// of this header.
class RunMetadataListener {
 public:
  // Virtual so listeners can be destroyed through a base-class pointer.
  virtual ~RunMetadataListener() = default;

  // Hook invoked ahead of clearing run metadata.
  virtual void BeforeClearRunMetadata() = 0;
};
108 
109 class EagerContext : public core::RefCounted {
110  public:
111   static const uint64 kInvalidContextId = 0;
112 
NewContextId()113   static uint64 NewContextId() {
114     uint64 context_id = random::New64();
115     while (context_id == kInvalidContextId) {
116       context_id = random::New64();
117     }
118     return context_id;
119   }
120 
121   EagerContext(const SessionOptions& opts,
122                ContextDevicePlacementPolicy default_device_placement_policy,
123                ContextMirroringPolicy default_mirroring_policy, bool async,
124                const bool lazy_copy_function_remote_inputs,
125                const DeviceMgr* device_mgr, bool device_mgr_owned,
126                Rendezvous* rendezvous,
127                const CustomKernelCreator* custom_kernel_creator,
128                DistributedFunctionLibraryRuntime* cluster_flr = nullptr);
129 
130   ~EagerContext() override;
131 
132   // Returns the function library runtime for the given device.
func_lib(const Device * d)133   FunctionLibraryRuntime* func_lib(const Device* d) const {
134     return pflr_->GetFLR(d->name());
135   }
136 
pflr()137   ProcessFunctionLibraryRuntime* pflr() const { return pflr_.get(); }
138 
runner()139   std::function<void(std::function<void()>)>* runner() { return &runner_; }
140 
141   // Specify a executor for this thread.
142   void SetExecutorForThread(EagerExecutor* executor);
143 
prioritized_device_type_list()144   const std::vector<DeviceType>& prioritized_device_type_list() const {
145     return prioritized_device_type_list_;
146   }
147 
148   // Clear pending nodes in thread executors and kernel caches.
149   void ClearCachesAndThreadExecutors();
150   // Clear pending nodes in default executor and kernel caches.
151   void ClearCachesAndDefaultExecutor();
152 
153   // Sets the device placement policy for the current thread.
154   void SetThreadLocalDevicePlacementPolicy(ContextDevicePlacementPolicy policy);
155 
156   // Returns the device placement policy for the current thread.
157   ContextDevicePlacementPolicy GetDevicePlacementPolicy() const;
158 
159   // Select an appropriate device for an operation.
160   //
161   // Given the preferred device for the operation, and the list of devices the
162   // operation supports, finds the best suitable device for the operation in
163   // this context.
164   //
165   // The preferred device is specified as a `ParsedName` containing the elements
166   // (details) that the resulting device should match. If there are no such
167   // devices, and the context currently allows soft device placement, a suitable
168   // device not matching `preferred` will be chosen.
169   //
170   // The `dtype` parameter specifies the operation's result data type, if
171   // known. Setting it to DT_INVALID will make this method not use the data type
172   // for its decisions.
173   //
174   // The chosen device is stored in the `device` argument. The argument is not
175   // modified unless this method returns `Status::OK()`.
176   Status SelectDevice(DeviceNameUtils::ParsedName preferred,
177                       const PrioritizedDeviceTypeVector& supported,
178                       const DataType dtype, Device** device) const;
179 
180   // Sets the implicit copy policy for the current thread.
181   void SetThreadLocalMirroringPolicy(ContextMirroringPolicy);
182 
183   // Returns the implicit copy policy for the current thread.
184   ContextMirroringPolicy GetMirroringPolicy() const;
185 
186   bool MirrorTensors() const;
187 
188   bool LazyCopyFunctionRemoteInputs() const;
189 
190   bool FindFunctionByName(const string& name) const;
191 
192   Status FindFunctionOpData(const string& name,
193                             const tensorflow::OpRegistrationData** op_data);
194 
195   const FunctionDef* FindFunctionDef(const string& name);
196 
HostCPU()197   Device* HostCPU() const { return host_cpu_device_; }
CanonicalDevice(Device * d)198   Device* CanonicalDevice(Device* d) const {
199     return HostCPU() == d ? nullptr : d;
200   }
201 
GetGraphCollector()202   GraphCollector* GetGraphCollector() { return &graph_collector_; }
203 
204   EagerExecutor& Executor();
205 
206   // Add the given `fdef` to the local FunctionLibraryDefinition. And add an
207   // entry to the KernelAndDevice cache for it if it's not exist.
208   Status AddFunctionDef(const FunctionDef& fdef);
209   // `library` contains all FunctionDefs and GradientDefs to expand `fdef`. Add
210   // it to the local FunctionLibraryDefinition as well, but no need to add it
211   // to the KernelAndDevice cache since they won't be executed as
212   // KernelAndDevices.
213   Status AddFunctionDef(const FunctionDef& fdef,
214                         const FunctionDefLibrary& library,
215                         const bool add_to_local_only = false);
216 
217   Status RemoveFunction(const string& func);
218 
219   core::RefCountPtr<KernelAndDevice> GetCachedKernel(Fprint128 cache_key);
220 
221   void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel);
222 
LogDevicePlacement()223   bool LogDevicePlacement() const { return log_device_placement_; }
AllowSoftPlacement()224   bool AllowSoftPlacement() const { return allow_soft_placement_; }
LogMemory()225   bool LogMemory() const { return log_memory_; }
226 
GetRendezvous()227   Rendezvous* GetRendezvous() const { return rendezvous_; }
CreateRendezvous(const int64 step_id)228   Rendezvous* CreateRendezvous(const int64 step_id) const {
229     if (rendezvous_creator_ != nullptr) {
230       return rendezvous_creator_(step_id);
231     }
232 
233 #if !defined(IS_MOBILE_PLATFORM)
234     if (worker_env_ != nullptr && worker_env_->rendezvous_mgr != nullptr) {
235       auto* remote_r = worker_env_->rendezvous_mgr->Find(step_id);
236       remote_r->Initialize(worker_session_.get()).IgnoreError();
237       return remote_r;
238     }
239 #endif
240 
241     if (remote_device_mgr() == nullptr) {
242       return new IntraProcessRendezvous(local_device_mgr());
243     }
244 
245     return nullptr;
246   }
247 
collective_executor_mgr()248   CollectiveExecutorMgrInterface* collective_executor_mgr() {
249     return collective_executor_mgr_.Get();
250   }
GetCollectiveExecutorHandle()251   std::unique_ptr<CollectiveExecutor::Handle> GetCollectiveExecutorHandle() {
252     return std::unique_ptr<CollectiveExecutor::Handle>(
253         new CollectiveExecutor::Handle(
254             collective_executor_mgr()->FindOrCreate(0), true /*inherit_ref*/));
255   }
256 
local_device_mgr()257   const tensorflow::DeviceMgr* local_device_mgr() const {
258     return local_device_manager_.Get();
259   }
remote_device_mgr()260   const tensorflow::DynamicDeviceMgr* remote_device_mgr() const {
261     return remote_device_manager_.Get();
262   }
263 
GetOwnedRemoteDeviceMgr()264   tensorflow::DynamicDeviceMgr* GetOwnedRemoteDeviceMgr() {
265     return remote_device_manager_.GetOwned();
266   }
267 
268   // TODO(apassos) clean up RunMetadata storage.
MetadataMu()269   mutex* MetadataMu() LOCK_RETURNED(metadata_mu_) { return &metadata_mu_; }
270   bool ShouldStoreGraphs() LOCKS_EXCLUDED(metadata_mu_);
271   void SetShouldStoreGraphs(bool value);
RunMetadataProto()272   RunMetadata* RunMetadataProto() { return &run_metadata_; }
273   void ClearRunMetadata() EXCLUSIVE_LOCKS_REQUIRED(metadata_mu_);
274 
275   void ListDevices(std::vector<tensorflow::DeviceAttributes>* devices);
276 
277   void StartStep();
278   void EndStep();
279   ScopedStepContainer* StepContainer();
280 
FuncLibDef()281   FunctionLibraryDefinition* FuncLibDef() { return &func_lib_def_; }
282 
283 #if !defined(IS_MOBILE_PLATFORM)
284   // Assign the EagerClient pointer to `client` based on the given device / task
285   // name, and increment the refcount of the client. The reference ownership is
286   // transferred to the caller, and the unref should automatically happen when
287   // destructing the RefCountPtr object at the caller's side.
288   // `client` must not be initialized or holding a reference of another object
289   // before calling this method.
290   Status GetClient(Device* device,
291                    core::RefCountPtr<eager::EagerClient>* client);
292   Status GetClient(const DeviceNameUtils::ParsedName& device_name,
293                    core::RefCountPtr<eager::EagerClient>* client);
294   Status GetClient(const string& remote_task,
295                    core::RefCountPtr<eager::EagerClient>* client);
296 
297   uint64 GetContextId();
298   uint64 GetContextViewId();
299   void IncrementContextViewId();
300 
301   // TODO(nareshmodi): Encapsulate remote state into a separate
302   // class/struct.
303   //
304   // Enables the eager context to communicate with remote devices. When
305   // initializing with this method, this context will be the master context,
306   // which will kill all its slaves in shutdown.
307   //
308   // - server: A ServerInterface that exports the tensorflow.WorkerService.
309   // Note that this class expects the server to already have been started.
310   // - remote_eager_workers: A cache from which we can get "EagerClient"s to
311   // communicate with remote eager services.
312   // - remote_device_mgr: A DeviceMgr* which contains all remote devices
313   // (should contain no local devices).
314   // - remote_contexts: A vector containing task names.
315   Status InitializeRemoteMaster(
316       std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
317       std::shared_ptr<WorkerSession> worker_session,
318       std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
319       std::unique_ptr<DynamicDeviceMgr> remote_device_manager,
320       const std::vector<string>& remote_contexts, uint64 context_id,
321       Rendezvous* r, DeviceMgr* local_device_mgr, int keep_alive_secs,
322       DistributedFunctionLibraryRuntime* cluster_flr,
323       std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
324           remote_mgr);
325 
326   // Update an existing master context with a new set of remote workers (i.e., a
327   // new "view" of cluster membership. Similar to InitializeRemoteMaster but
328   // this will keep the current context_id and increment a context_view_id, will
329   // keep the current resource manager so that resources from the previous view
330   // can still be accessed, and will automatically register existing functions
331   // if there are newly added hosts.
332   Status UpdateRemoteMaster(
333       WorkerEnv* worker_env,
334       std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
335       const std::vector<string>& add_remote_contexts,
336       const std::vector<string>& remove_remote_contexts, uint64 context_id,
337       Rendezvous* r, DeviceMgr* local_device_mgr, int keep_alive_secs,
338       DistributedFunctionLibraryRuntime* cluster_flr);
339 
340   // Similar with InitializeRemoteMaster but this context will not kill remote
341   // contexts in shutdown.
342   Status InitializeRemoteWorker(
343       std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
344       DynamicDeviceMgr* remote_device_mgr,
345       const std::vector<string>& remote_contexts, uint64 context_id,
346       uint64 context_view_id,
347       std::function<Rendezvous*(const int64)> rendezvous_creator,
348       DistributedFunctionLibraryRuntime* cluster_flr,
349       std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
350           remote_mgr,
351       std::function<void()> resource_deallocator);
352 
353   // Similar with InitializeRemoteWorker but will reuse existing context and
354   // increment context_view_id.
355   Status UpdateRemoteWorker(
356       const DeviceMgr* worker_session_device_mgr,
357       std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
358       DynamicDeviceMgr* remote_device_mgr,
359       const std::vector<string>& remote_contexts, uint64 context_id,
360       DistributedFunctionLibraryRuntime* cluster_flr);
361 
362   Status StoreCollectiveOpsServer(
363       std::unique_ptr<ServerInterface> new_server, DeviceMgr* device_mgr,
364       CollectiveExecutorMgrInterface* rpc_collective_executor_mgr);
365 
366   // For the specified remote worker, preprocess and set its device filters.
367   Status SetRemoteDeviceFilters(const string& remote_worker,
368                                 const std::vector<string>& device_filters);
369 
370   // For the specified remote worker, apply the stored device filters to the
371   // list of device attributes following these rules:
372   // (1) if the remote worker does not have device filters, all devices are
373   //     visible to the worker;
374   // (2) if the device is on the remote worker, then it is visible;
375   // (3) if the device matches at least one device filter, then it is visible.
376   // The result is saved as a boolean vector of the same length (i.e.,
377   // filtered_device_mask) indicating whether each of the devices is visible to
378   // the remote worker.
379   void FilterDevicesForRemoteWorkers(
380       const string& remote_worker,
381       const protobuf::RepeatedPtrField<DeviceAttributes>& device_attrs,
382       std::vector<bool>* filtered_device_mask);
383 
384   // TODO(fishx): Remove the custom deleter once we remove forward declaration.
385   const std::unique_ptr<eager::RemoteMgr,
386                         std::function<void(eager::RemoteMgr*)>>&
RemoteMgr()387   RemoteMgr() {
388     return remote_mgr_;
389   }
390 
391   // If true, then tensors should be shipped across processes via the
392   // EagerService.Enqueue(SendTensorOp). If false, _Send/_Recv ops should be
393   // used instead (which in-turn use WorkerService.RecvTensor RPCs).
UseSendTensorRPC()394   bool UseSendTensorRPC() { return use_send_tensor_rpc_; }
395 
GetServer()396   tensorflow::ServerInterface* GetServer() { return server_.get(); }
397 
398 #endif  // IS_MOBILE_PLATFORM
399 
400   // Closes remote eager contexts, waits for all RPCs to finish, and
401   // destroys the EagerClientCache. No RPCs can be made through this context
402   // after this method has been called.
403   // This method exists to aid a clean shutdown. It causes all RPCs to finish
404   // and remote TensorHandles to release their references to this context.
405   // To avoid deadlocks, this method must not be called on the thread
406   // processing RPCs because it makes RPCs and waits for their completion.
407   //
408   // On mobile, it just cleans the caches.
409   void WaitForAndCloseRemoteContexts();
410 
PinSmallOpsToCPU()411   bool PinSmallOpsToCPU() { return pin_small_ops_to_cpu_; }
412 
TFEnv()413   tensorflow::Env* TFEnv() const { return env_; }
414 
415   std::vector<const FunctionDef*> ListRegisteredFunctions();
416 
417   Status FindDeviceFromName(const char* device_name, Device** device) const;
418 
419   bool OnSameTask(const Device* first, const Device* second) const;
420   // Gets the CPU device on the task of device.
421   Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const;
422 
423  private:
424   void InitPrioritizedDeviceTypeList();
425   Status MaybeRegisterFunctionRemotely(const FunctionDef& fdef);
426   Status RegisterExistingFunctionsOnRemoteWorkers(
427       const std::vector<const FunctionDef*>& function_defs,
428       const std::vector<string>& remote_workers);
429 
430   void ResetPFLR(const DeviceMgr* device_mgr, Env* env,
431                  const ConfigProto* config, int graph_def_version,
432                  const FunctionLibraryDefinition* lib_def,
433                  const OptimizerOptions& optimizer_options,
434                  thread::ThreadPool* thread_pool = nullptr,
435                  DistributedFunctionLibraryRuntime* cluster_flr = nullptr,
436                  const CustomKernelCreator* custom_kernel_creator = nullptr);
437 
438   void ResetClusterFLR(DistributedFunctionLibraryRuntime* cluster_flr);
439 
440   template <typename T>
441   struct OwnedOrUnownedHelper {
442    public:
OwnedOrUnownedHelperOwnedOrUnownedHelper443     OwnedOrUnownedHelper() {}
444     explicit OwnedOrUnownedHelper(T* object, const bool owned = false) {
445       Reset(object, owned);
446     }
447 
ResetOwnedOrUnownedHelper448     void Reset(std::unique_ptr<T> object) {
449       owned_object = std::move(object);
450       unowned_object_ptr = nullptr;
451     }
452 
453     void Reset(T* object, const bool owned = false) {
454       if (owned) {
455         owned_object.reset(object);
456         unowned_object_ptr = nullptr;
457       } else {
458         owned_object.reset(nullptr);
459         unowned_object_ptr = object;
460       }
461     }
462 
OwnedOwnedOrUnownedHelper463     bool Owned() const { return owned_object != nullptr; }
464 
GetOwnedOwnedOrUnownedHelper465     T* GetOwned() const { return owned_object.get(); }
GetOwnedOrUnownedHelper466     T* Get() const {
467       return owned_object ? owned_object.get() : unowned_object_ptr;
468     }
469 
470     std::unique_ptr<T> owned_object = nullptr;
471     T* unowned_object_ptr = nullptr;
472   };
473 
474   const ContextDevicePlacementPolicy default_device_placement_policy_;
475   const ContextMirroringPolicy default_mirroring_policy_;
476 
477   // Note: we cannot use C++11 thread_local here as there is no concept of a
478   // thread-local-object-local variable in C++11.
479   mutable mutex policy_map_mu_;
480   std::unordered_map<std::thread::id, ContextDevicePlacementPolicy>
481       device_placement_policy_ GUARDED_BY(policy_map_mu_);
482   std::unordered_map<std::thread::id, ContextMirroringPolicy> mirroring_policy_
483       GUARDED_BY(policy_map_mu_);
484 
485   OwnedOrUnownedHelper<const DeviceMgr> local_device_manager_;
486 
487   // Unowned DynamicDeviceMgr is set on remote worker to allow running
488   // multi-device function on remote worker.
489   OwnedOrUnownedHelper<DynamicDeviceMgr> remote_device_manager_;
490 
491   Device* host_cpu_device_;  // Owned by device_manager
492   std::vector<DeviceType> prioritized_device_type_list_;
493   Rendezvous* rendezvous_;
494   std::function<Rendezvous*(const int64)> rendezvous_creator_;
495 
Global()496   FunctionLibraryDefinition func_lib_def_{OpRegistry::Global(), {}};
497 
498   std::unique_ptr<thread::ThreadPool> thread_pool_;
499 
500   const CustomKernelCreator* const custom_kernel_creator_;
501 
502   // EagerContext owns the DistributedFunctionLibraryRuntime(
503   // EagerClusterFunctionLibraryRuntime) if using EagerService for remote
504   // function execution (lazy_copy_function_remote_inputs_=true).
505   OwnedOrUnownedHelper<DistributedFunctionLibraryRuntime> cluster_flr_;
506   // One FunctionLibraryRuntime per device.
507   // func_libs[i] is the FunctionLibraryRuntime corresponding to
508   // session->devices[i].
509   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
510 
511   std::function<void(std::function<void()>)> runner_;
512 
513   mutex cache_mu_;
514   struct RegisteredFunction : public core::RefCounted {
~RegisteredFunctionRegisteredFunction515     ~RegisteredFunction() override {}
516 
517     std::unique_ptr<std::vector<Fprint128>> cached_kernel_keys;
518   };
519   std::unordered_map<Fprint128, core::RefCountPtr<KernelAndDevice>,
520                      Fprint128Hasher>
521       kernel_cache_ GUARDED_BY(cache_mu_);
522   std::unordered_map<string, RegisteredFunction*> registered_functions_
523       GUARDED_BY(cache_mu_);
524 
525   // Whether we should compute RunMetadata.
526   std::atomic<bool> should_store_graphs_{false};
527   mutex metadata_mu_;
528   RunMetadata run_metadata_ GUARDED_BY(metadata_mu_);
529   GraphCollector graph_collector_;
530   // TODO(fishx): Allow update following two bool after context creation.
531   const bool log_device_placement_;
532   const bool allow_soft_placement_;
533 
534   // Information related to step containers.
535   std::atomic<int> num_active_steps_;
536   std::unique_ptr<ScopedStepContainer> step_container_ GUARDED_BY(metadata_mu_);
537 
538   EagerExecutor default_executor_;
539   mutable mutex executor_map_mu_;
540   // Not owned.
541   std::unordered_map<std::thread::id, EagerExecutor*> thread_local_executor_
542       GUARDED_BY(executor_map_mu_);
543 
544   const bool log_memory_;
545 
546   Env* const env_;
547 
548   OwnedOrUnownedHelper<CollectiveExecutorMgrInterface> collective_executor_mgr_;
549 
550 #if !defined(IS_MOBILE_PLATFORM)
551   void CloseAndClearAllRemoteContexts();
552   void CloseRemoteContexts(const std::vector<string>& remote_contexts,
553                            uint64 context_id, uint64 context_view_id);
554 
555   Status SetMasterContextState(
556       std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
557       std::shared_ptr<WorkerSession> worker_session,
558       std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
559       std::unique_ptr<DynamicDeviceMgr> remote_device_manager,
560       uint64 context_id, uint64 context_view_id, Rendezvous* r,
561       DeviceMgr* local_device_mgr, int keep_alive_secs,
562       DistributedFunctionLibraryRuntime* cluster_flr,
563       std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
564           remote_mgr);
565 
566   // The server_ is not const since we release it when the context is destroyed.
567   // Therefore the server_ object is not marked as const (even though it should
568   // be).
569   std::unique_ptr<ServerInterface> server_;
570   WorkerEnv* worker_env_ = nullptr;
571   std::shared_ptr<WorkerSession> worker_session_;
572   std::unique_ptr<eager::EagerClientCache> remote_eager_workers_;
573 
574   mutex remote_state_mu_;
575 
576   uint64 context_id_ GUARDED_BY(remote_state_mu_);
577   // The view id of an eager context should be set to 0 when context is created,
578   // and continously incremented when context with the same context_id gets
579   // updated. The view id should be consistent between master and workers.
580   uint64 context_view_id_ GUARDED_BY(remote_state_mu_);
581   std::vector<string> remote_contexts_;
582 
583   int keep_alive_secs_ GUARDED_BY(remote_state_mu_);
584   std::atomic<int> sleep_for_secs_;
585 
586   std::unique_ptr<Thread> keep_alive_thread_;
587   mutex keep_alive_thread_shutdown_mu_;
588   condition_variable keep_alive_thread_cv_;
589   bool shutting_down_ GUARDED_BY(keep_alive_thread_shutdown_mu_) = false;
590 
591   std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
592       remote_mgr_;
593   bool is_master_ GUARDED_BY(remote_state_mu_);
594 
595   // Maps from a remote worker to a list of parsed device filters.
596   std::unordered_map<string, std::vector<DeviceNameUtils::ParsedName>>
597       cluster_device_filters_ GUARDED_BY(remote_state_mu_);
598 
599 #endif  // IS_MOBILE_PLATFORM
600 
601   // For a multi device function, the target device of each input is unknown
602   // until the function is instantiated on the default function device.
603   // If false, eagerly copy all remote inputs to the default function device;
604   // if true, lazily copy remote inputs to their target devices to avoid
605   // redundant copies.
606   bool lazy_copy_function_remote_inputs_ = false;
607   bool use_send_tensor_rpc_;
608   const bool pin_small_ops_to_cpu_;
609 
610   // Function that will be invoked in destructor to deallocate resources related
611   // to this context.
612   std::function<void()> resource_deallocator_ = nullptr;
613 };
614 
615 }  // namespace tensorflow
616 
617 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_
618