/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_

#include <algorithm>
#include <cstddef>
#include <map>
#include <memory>
#include <queue>
#include <string>
#include <unordered_set>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/core/common_runtime/composite_device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/custom_device.h"
#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/device_name_utils.h"

// "tensorflow/core/platform/platform.h" must be included first before using
// IS_MOBILE_PLATFORM.
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#endif  // !IS_MOBILE_PLATFORM
namespace tensorflow {

namespace eager {
// We need this forward declaration because we have a circular dependency:
// Context -> RemoteMgr -> TensorHandle -> Context.
// TODO(fishx): Remove this once we remove the Context dependency in
// TensorHandle.
class RemoteMgr;
}  // namespace eager

class TensorHandle;
class EagerOperation;

class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
 public:
  static constexpr uint64 kInvalidContextId = 0;

  static uint64 NewContextId() {
    uint64 context_id = random::New64();
    while (context_id == kInvalidContextId) {
      context_id = random::New64();
    }
    return context_id;
  }

  EagerContext(
      const SessionOptions& opts,
      ContextDevicePlacementPolicy default_device_placement_policy, bool async,
      /*const*/ DeviceMgr* device_mgr, bool device_mgr_owned,
      /*const*/ Rendezvous* rendezvous,
      DistributedFunctionLibraryRuntime* cluster_flr = nullptr,
      CollectiveExecutorMgrInterface* collective_executor_mgr = nullptr,
      bool run_eager_op_as_function = false);

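  // A minimal construction sketch (illustrative only, not part of this
  // header): one plausible way to assemble the constructor's dependencies for
  // a local, single-process setup. The use of DeviceFactory::AddDevices,
  // StaticDeviceMgr, and IntraProcessRendezvous here is an assumption, not
  // the only supported wiring.
  //
  //   tensorflow::SessionOptions opts;
  //   std::vector<std::unique_ptr<tensorflow::Device>> devices;
  //   TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices(
  //       opts, "/job:localhost/replica:0/task:0", &devices));
  //   auto* device_mgr = new tensorflow::StaticDeviceMgr(std::move(devices));
  //   auto* rendezvous = new tensorflow::IntraProcessRendezvous(device_mgr);
  //   auto* ctx = new tensorflow::EagerContext(
  //       opts, tensorflow::DEVICE_PLACEMENT_SILENT, /*async=*/false,
  //       device_mgr, /*device_mgr_owned=*/true, rendezvous);
  //   ...
  //   ctx->Release();  // Drops the caller's reference.
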
  void Release() override { Unref(); }

  AbstractTensorInterface* CreateInt64Scalar(int64_t value) override;
  AbstractTensorInterface* CreateUint64Scalar(uint64 value) override;
  AbstractTensorInterface* CreateInt32Scalar(int32_t value) override;
  AbstractTensorInterface* CreateFloatScalar(float value) override;
  AbstractTensorInterface* CreateDoubleScalar(double value) override;
  AbstractTensorInterface* CreateHalfScalar(Eigen::half value) override;
  AbstractTensorInterface* CreateStringScalar(
      tensorflow::tstring value) override;
  AbstractTensorInterface* CreateComplex128Scalar(
      tensorflow::complex128 value) override;
  AbstractTensorInterface* CreateBoolScalar(bool value) override;

  AbstractTensorInterface* CreateTensor(
      DataType dtype, absl::Span<const int64> dim_sizes) override;
  AbstractTensorInterface* CreateTensor(DataType dtype, const int64_t* dims,
                                        int num_dims, void* data, size_t len,
                                        MemoryReleaser memory_releaser,
                                        void* memory_releaser_arg) override;

  ImmediateExecutionTensorHandle* CreateLocalHandle(
      AbstractTensorInterface* t) override;
  // Create an abstract tensor handle from a tensorflow::Tensor.
  ImmediateExecutionTensorHandle* CreateLocalHandleFromTFTensor(
      tensorflow::Tensor& t, const char* d_name) override;
  ImmediateExecutionTensorHandle* CopyTensorHandleToDevice(
      ImmediateExecutionTensorHandle* handle, const char* device_name,
      Status* status) override;
  ImmediateExecutionOperation* CreateOperation() override;

  // This is a virtual helper function to convert a TFRT TensorHandle to a
  // tensorflow::TensorHandle. In the current runtime EagerContext, it just
  // forwards the input, since the input tensor handle is already a
  // tensorflow::TensorHandle.
  ImmediateExecutionTensorHandle* TFTensorHandleFromInterface(
      ImmediateExecutionTensorHandle* handle) override;

  Status RegisterFunction(AbstractFunction* f) override;

  bool UsesTFRT() override;

  bool RunEagerOpAsFunction() const;

  void ListDevices(std::vector<DeviceAttributes>* devices) override;

  Status AddDevices(std::vector<std::unique_ptr<Device>> devices) override;

  // Returns the function library runtime for the given device.
  FunctionLibraryRuntime* func_lib(const Device* d) const {
    return pflr_->GetFLR(d->name());
  }

  ProcessFunctionLibraryRuntime* pflr() const { return pflr_.get(); }

  std::function<void(std::function<void()>)>* runner() { return &runner_; }

  // Specify an executor for this thread.
  void SetExecutorForThread(EagerExecutor* executor) override;

  const std::shared_ptr<std::vector<DeviceType>> prioritized_device_type_list()
      const {
    mutex_lock l(device_type_list_mu_);
    return prioritized_device_type_list_;
  }

  // Clear pending nodes in thread executors and kernel caches.
  void ClearCachesAndThreadExecutors() override;
  // Clear pending nodes in the default executor and kernel caches.
  void ClearCachesAndDefaultExecutor();

  // Sets the device placement policy for the current thread.
  void SetThreadLocalDevicePlacementPolicy(
      ContextDevicePlacementPolicy policy) override;

  // Returns the device placement policy for the current thread.
  ContextDevicePlacementPolicy GetDevicePlacementPolicy() const override;

  // Select an appropriate device for an operation.
  //
  // Given the preferred device for the operation and the node_def, finds the
  // most suitable device for the operation in this context.
  //
  // The preferred device is specified as a `ParsedName` containing the
  // elements (details) that the resulting device should match. If there are
  // no such devices, and the context currently allows soft device placement,
  // a suitable device not matching `preferred` will be chosen.
  //
  // The chosen device is stored in the `out` argument. The argument is not
  // modified unless this method returns `Status::OK()`.
  Status SelectDevice(DeviceNameUtils::ParsedName preferred,
                      const NodeDef& ndef, Device** out) const;

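  // A minimal usage sketch (illustrative only; assumes `ctx` is an
  // initialized EagerContext and `ndef` is the NodeDef of the op being
  // placed):
  //
  //   tensorflow::DeviceNameUtils::ParsedName preferred;
  //   CHECK(tensorflow::DeviceNameUtils::ParseFullName(
  //       "/job:localhost/replica:0/task:0/device:GPU:0", &preferred));
  //   tensorflow::Device* device = nullptr;
  //   TF_CHECK_OK(ctx->SelectDevice(preferred, ndef, &device));
  //   // With soft placement enabled, `device` may be a non-matching (e.g.
  //   // CPU) device if no device matching `preferred` exists.
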
  bool FindFunctionByName(const string& name) const;

  Status FindFunctionOpData(const string& name,
                            const tensorflow::OpRegistrationData** op_data);

  const FunctionDef* FindFunctionDef(const string& name) const override;

  Device* HostCPU() const { return host_cpu_device_; }
  Device* CanonicalDevice(Device* d) const {
    return HostCPU() == d ? nullptr : d;
  }
  const DeviceNameUtils::ParsedName& HostCPUParsedName() const override {
    return HostCPU()->parsed_name();
  }

  const string& HostCPUName() const override { return HostCPU()->name(); }

  GraphCollector* GetGraphCollector() { return &graph_collector_; }

  EagerExecutor& Executor() override;

  // Add the given `fdef` to the local FunctionLibraryDefinition, and add an
  // entry to the KernelAndDevice cache for it if one does not already exist.
  Status AddFunctionDef(const FunctionDef& fdef) override;

  Status AddFunctionDefWithStackTraces(
      const FunctionDef& fdef, const StackTracesMap& stack_traces) override;

  // `library` contains all FunctionDefs and GradientDefs needed to expand
  // `fdef`. Add them to the local FunctionLibraryDefinition as well, but
  // there is no need to add them to the KernelAndDevice cache since they
  // won't be executed as KernelAndDevices.
  Status AddFunctionDef(const FunctionDef& fdef,
                        const FunctionDefLibrary& library,
                        bool add_to_local_only = false,
                        const StackTracesMap& stack_traces = {});

  const FunctionDef* GetFunctionDef(const string& function_name);

  std::vector<string> ListFunctionNames() override;

  Status RemoveFunction(const string& func) override;

  // Wait for pending nodes to finish in local executors (including the
  // context default executor and thread executors) and in executors on remote
  // workers. Returns the combined status of the remote executors. If there
  // are multiple errors, the Status code will be the same as that of the
  // first remote executor with an error, and the error message will be
  // combined from all executors.
  Status SyncExecutors();

  Status AsyncWait() override { return SyncExecutors(); }

  core::RefCountPtr<KernelAndDevice> GetCachedKernel(Fprint128 cache_key);

  void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel);

  bool LogDevicePlacement() const { return log_device_placement_; }
  void SetLogDevicePlacement(bool enable) override {
    log_device_placement_ = enable;
  }

  // When tensor transfers across functions/eager executions using send/recv
  // ops are required, `reuse_rendezvous_for_functions_` can be set to true so
  // that function executions and eager executions use the same rendezvous
  // instance, instead of creating a new instance per function call.
  void SetReuseRendezvousForFunctions(
      bool reuse_rendezvous_for_functions) override {
    reuse_rendezvous_for_functions_ = reuse_rendezvous_for_functions;
  }
  bool GetReuseRendezvousForFunctions() const {
    return reuse_rendezvous_for_functions_;
  }

  bool AllowSoftPlacement() const { return allow_soft_placement_; }
  void SetAllowSoftPlacement(bool enable) override {
    allow_soft_placement_ = enable;
  }
  bool LogMemory() const { return log_memory_; }

  Rendezvous* GetRendezvous() const { return rendezvous_; }

  void ResetGlobalRendezvousForFunction() override {
    mutex_lock l(global_rendezvous_mu_);
    global_rendezvous_for_functions_ =
        core::RefCountPtr<Rendezvous>(CreateRendezvous(-1));
  }

  // Returns a function that maps from step_id to rendezvous. This closure
  // respects the value of `SetReuseRendezvousForFunctions` at the time the
  // closure was created, which allows the setting to be toggled around async
  // op launches.
  //
  // The caller of the returned function owns a reference to the resulting
  // Rendezvous.
  std::function<Rendezvous*(int64_t)> RendezvousCreator() {
    if (reuse_rendezvous_for_functions_) {
      return [this](int64_t step_id) {
        mutex_lock l(global_rendezvous_mu_);
        global_rendezvous_for_functions_->Ref();
        return global_rendezvous_for_functions_.get();
      };
    } else {
      return [this](int64_t step_id) { return CreateRendezvous(step_id); };
    }
  }

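  // A minimal usage sketch (illustrative only; assumes `ctx` is an
  // initialized EagerContext). The reuse setting is captured when the creator
  // is built, so a later toggle of SetReuseRendezvousForFunctions does not
  // affect an already-obtained closure:
  //
  //   auto create_rendezvous = ctx->RendezvousCreator();
  //   tensorflow::Rendezvous* r = create_rendezvous(/*step_id=*/42);
  //   ...  // Use `r` for send/recv traffic during this step.
  //   r->Unref();  // The caller owns one reference to the result.
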
  CollectiveExecutorMgrInterface* collective_executor_mgr() {
    return collective_executor_mgr_.Get();
  }
  std::unique_ptr<CollectiveExecutor::Handle> GetCollectiveExecutorHandle() {
    return std::unique_ptr<CollectiveExecutor::Handle>(
        new CollectiveExecutor::Handle(
            collective_executor_mgr()->FindOrCreate(0), true /*inherit_ref*/));
  }

  tensorflow::DeviceMgr* local_device_mgr() const {
    return local_device_manager_.Get();
  }
  const tensorflow::DynamicDeviceMgr* remote_device_mgr() const {
    return remote_device_manager_.Get();
  }

  tensorflow::DynamicDeviceMgr* GetOwnedRemoteDeviceMgr() {
    return remote_device_manager_.GetOwned();
  }

  std::vector<Device*> ListLocalTfDevices() override {
    return local_device_mgr()->ListDevices();
  }

  // TODO(apassos) clean up RunMetadata storage.
  mutex* MetadataMu() TF_LOCK_RETURNED(metadata_mu_) { return &metadata_mu_; }
  bool ShouldStoreGraphs() TF_LOCKS_EXCLUDED(metadata_mu_);
  void SetShouldStoreGraphs(bool value) override;
  RunMetadata* RunMetadataProto() TF_EXCLUSIVE_LOCKS_REQUIRED(metadata_mu_) {
    return run_metadata_.get();
  }
  std::unique_ptr<RunMetadata> ExportRunMetadata() override
      TF_LOCKS_EXCLUDED(metadata_mu_);

  void StartStep() override;
  void EndStep() override;
  ScopedStepContainer* StepContainer();

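  // A minimal usage sketch (illustrative only; assumes `ctx` is an
  // initialized EagerContext). Steps may be nested; the assumption here is
  // that per-step resources in the step container are cleaned up once the
  // last active step ends:
  //
  //   ctx->StartStep();
  //   ...  // Run ops; they may stash per-step state in ctx->StepContainer().
  //   ctx->EndStep();
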
  FunctionLibraryDefinition* FuncLibDef() { return &func_lib_def_; }

#if !defined(IS_MOBILE_PLATFORM)
  // Assign the EagerClient pointer to `client` based on the given device /
  // task name, and increment the refcount of the client. The reference
  // ownership is transferred to the caller, and the unref happens
  // automatically when the RefCountPtr object is destructed on the caller's
  // side.
  // `client` must not be initialized or holding a reference to another object
  // before calling this method.
  Status GetClient(Device* device,
                   core::RefCountPtr<eager::EagerClient>* client);
  Status GetClient(const DeviceNameUtils::ParsedName& device_name,
                   core::RefCountPtr<eager::EagerClient>* client);
  Status GetClient(const string& remote_task,
                   core::RefCountPtr<eager::EagerClient>* client);

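  // A minimal usage sketch (illustrative only; assumes `ctx` has remote
  // workers configured and `remote_device` is a Device on one of them):
  //
  //   tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> client;
  //   TF_CHECK_OK(ctx->GetClient(remote_device, &client));
  //   ...  // Issue eager service RPCs through `client`.
  //   // The reference is dropped when `client` goes out of scope.
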
  uint64 GetContextId() const;
  uint64 GetContextViewId() const;
  void IncrementContextViewId();

  Status EnableCollectiveOps(const ServerDef& server_def) override;

  // TODO(nareshmodi): Encapsulate remote state into a separate
  // class/struct.
  //
  // Enables the eager context to communicate with remote devices. When
  // initialized with this method, this context will be the primary context,
  // which will kill all its remote contexts in shutdown.
  //
  // - server: A ServerInterface that exports the tensorflow.WorkerService.
  // Note that this class expects the server to already have been started.
  // - remote_eager_workers: A cache from which we can get "EagerClient"s to
  // communicate with remote eager services.
  // - remote_device_mgr: A DeviceMgr* which contains all remote devices
  // (should contain no local devices).
  // - remote_contexts: A vector containing task names.
  // TODO(b/184375824): clean up parameter order for better readability.
  Status InitializeRemoteMaster(
      std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
      std::shared_ptr<WorkerSession> worker_session,
      std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
      std::unique_ptr<DynamicDeviceMgr> remote_device_manager,
      const std::vector<string>& remote_contexts, uint64 context_id,
      /*const*/ Rendezvous* r, /*const*/ DeviceMgr* local_device_mgr,
      int keep_alive_secs, DistributedFunctionLibraryRuntime* cluster_flr,
      std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
          remote_mgr);

  // Update an existing master context with a new set of remote workers (i.e.,
  // a new "view" of cluster membership). Similar to InitializeRemoteMaster,
  // but this will keep the current context_id and increment a
  // context_view_id, will keep the current resource manager so that resources
  // from the previous view can still be accessed, and will automatically
  // register existing functions if there are newly added hosts.
  Status UpdateRemoteMaster(
      uint64 context_id,
      std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
      const std::vector<string>& add_remote_contexts,
      const std::vector<string>& remove_remote_contexts);

  // Similar to InitializeRemoteMaster, but this context will not kill remote
  // contexts in shutdown.
  Status InitializeRemoteWorker(
      std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
      DynamicDeviceMgr* remote_device_mgr,
      const std::vector<string>& remote_contexts, uint64 context_id,
      uint64 context_view_id,
      std::function<Rendezvous*(const int64_t)> rendezvous_creator,
      DistributedFunctionLibraryRuntime* cluster_flr,
      std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
          remote_mgr,
      std::function<void()> resource_deallocator);

  // Similar to InitializeRemoteWorker, but reuses the existing context and
  // increments context_view_id.
  Status UpdateRemoteWorker(
      std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
      const std::vector<string>& remote_contexts, uint64 context_id);

  Status StoreCollectiveOpsServer(
      std::unique_ptr<ServerInterface> new_server, DeviceMgr* device_mgr,
      CollectiveExecutorMgrInterface* rpc_collective_executor_mgr);

  // For the specified remote worker, preprocess and set its device filters.
  Status SetRemoteDeviceFilters(const string& remote_worker,
                                const std::vector<string>& device_filters);

  // For the specified remote worker, apply the stored device filters to the
  // list of device attributes following these rules:
  // (1) if the remote worker does not have device filters, all devices are
  //     visible to the worker;
  // (2) if the device is on the remote worker, then it is visible;
  // (3) if the device matches at least one device filter, then it is visible.
  // The result is saved as a boolean vector of the same length (i.e.,
  // filtered_device_mask) indicating whether each of the devices is visible
  // to the remote worker. See the sketch below for an example.
  void FilterDevicesForRemoteWorkers(
      const string& remote_worker,
      const protobuf::RepeatedPtrField<DeviceAttributes>& device_attrs,
      std::vector<bool>* filtered_device_mask);

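  // A minimal sketch of the filtering rules above (illustrative only; the
  // task and device names are hypothetical):
  //
  //   std::vector<std::string> filters = {"/job:ps"};
  //   TF_CHECK_OK(ctx->SetRemoteDeviceFilters(
  //       "/job:worker/replica:0/task:1", filters));
  //   std::vector<bool> mask;
  //   ctx->FilterDevicesForRemoteWorkers(
  //       "/job:worker/replica:0/task:1", device_attrs, &mask);
  //
  // If `device_attrs` lists [ps:0/CPU:0, worker:1/CPU:0, worker:2/GPU:0],
  // the expected mask is [true, true, false]: the ps device matches the
  // filter (rule 3), the worker's own device is always visible (rule 2), and
  // the unrelated worker:2 device is filtered out.
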
  // TODO(fishx): Remove the custom deleter once we remove the forward
  // declaration.
  const std::unique_ptr<eager::RemoteMgr,
                        std::function<void(eager::RemoteMgr*)>>&
  RemoteMgr() {
    return remote_mgr_;
  }

  // If true, then tensors should be shipped across processes via the
  // EagerService.Enqueue(SendTensorOp). If false, _Send/_Recv ops should be
  // used instead (which in turn use WorkerService.RecvTensor RPCs).
  bool UseSendTensorRPC() { return use_send_tensor_rpc_; }

  tensorflow::ServerInterface* GetServer() { return server_.get(); }

  // For LLVM style RTTI.
  static bool classof(const AbstractContext* ptr) {
    return ptr->getKind() == kEager;
  }

  // Function to support the distributed C API.
  void SetDistributedManager(
      std::unique_ptr<ImmediateExecutionDistributedManager> distributed)
      override {
    distributed_manager_ = std::move(distributed);
  }
  ImmediateExecutionDistributedManager* GetDistributedManager() override {
    return distributed_manager_.get();
  }

  // May only be used during multi-client setup so that a RemoteRendezvous
  // can be initialized instead of defaulting to the IntraProcessRendezvous.
  void SetWorkerEnv(WorkerEnv* worker_env,
                    std::shared_ptr<WorkerSession> worker_session);
#endif  // IS_MOBILE_PLATFORM

  // Closes remote eager contexts, waits for all RPCs to finish, and
  // destroys the EagerClientCache. No RPCs can be made through this context
  // after this method has been called.
  // This method exists to aid a clean shutdown. It causes all RPCs to finish
  // and remote TensorHandles to release their references to this context.
  // To avoid deadlocks, this method must not be called on the thread
  // processing RPCs because it makes RPCs and waits for their completion.
  //
  // On mobile, it just cleans the caches.
  void WaitForAndCloseRemoteContexts();

  bool PinSmallOpsToCPU() const { return pin_small_ops_to_cpu_; }

  tensorflow::Env* TFEnv() const { return env_; }

  Status FindDeviceFromName(const char* device_name, Device** device) const;

  Status FindCompositeDeviceFromName(StringPiece device_name,
                                     CompositeDevice** device) const;

  Status RegisterCustomDevice(const string& name,
                              std::unique_ptr<CustomDevice> device) override;

  CustomDeviceOpHandler& GetCustomDeviceOpHandler() override {
    return custom_device_op_handler_;
  }

  // Find or create a composite device with the given `underlying_devices` and
  // `device_name` (if not empty).
  Status FindOrCreateCompositeDevice(
      const std::vector<string>& underlying_devices, const string& device_name,
      CompositeDevice** composite_device);

  bool OnSameTask(const Device* first, const Device* second) const;
  // Gets the CPU device on the task of the given device.
  Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const;

  const SessionOptions& session_options() const { return opts_; }
  void InitPrioritizedDeviceTypeList();

 private:
  Rendezvous* CreateRendezvous(int64_t step_id) const {
    if (rendezvous_creator_ != nullptr) {
      return rendezvous_creator_(step_id);
    }

#if !defined(IS_MOBILE_PLATFORM)
    if (worker_env_ != nullptr && worker_env_->rendezvous_mgr != nullptr) {
      auto* remote_r = worker_env_->rendezvous_mgr->Find(step_id);
      remote_r->Initialize(worker_session_.get()).IgnoreError();
      return remote_r;
    }
#endif

    if (remote_device_mgr() == nullptr) {
      return new IntraProcessRendezvous(local_device_mgr());
    }

    return nullptr;
  }

  ~EagerContext() override;

  Status MaybeRegisterFunctionRemotely(const FunctionDef& fdef);
  Status RegisterExistingFunctionsOnRemoteWorkers(
      const std::vector<string>& remote_workers);

  void ResetPFLR(const DeviceMgr* device_mgr, Env* env,
                 const ConfigProto* config, int graph_def_version,
                 const FunctionLibraryDefinition* lib_def,
                 const OptimizerOptions& optimizer_options,
                 thread::ThreadPool* thread_pool = nullptr,
                 DistributedFunctionLibraryRuntime* cluster_flr = nullptr);

  void ResetClusterFLR(DistributedFunctionLibraryRuntime* cluster_flr);

  void ClearResourceContainer(const string& name);

  template <typename T>
  struct OwnedOrUnownedHelper {
   public:
    OwnedOrUnownedHelper() {}
    explicit OwnedOrUnownedHelper(T* object, const bool owned = false) {
      Reset(object, owned);
    }

    void Reset(std::unique_ptr<T> object) {
      owned_object = std::move(object);
      unowned_object_ptr = nullptr;
    }

    void Reset(T* object, const bool owned = false) {
      if (owned) {
        owned_object.reset(object);
        unowned_object_ptr = nullptr;
      } else {
        owned_object.reset(nullptr);
        unowned_object_ptr = object;
      }
    }

    bool Owned() const { return owned_object != nullptr; }

    T* GetOwned() const { return owned_object.get(); }
    T* Get() const {
      return owned_object ? owned_object.get() : unowned_object_ptr;
    }

    std::unique_ptr<T> owned_object = nullptr;
    T* unowned_object_ptr = nullptr;
  };

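  // A minimal sketch of how OwnedOrUnownedHelper behaves (illustrative only):
  // it holds a T either by owning it via unique_ptr or by borrowing a raw
  // pointer owned elsewhere.
  //
  //   OwnedOrUnownedHelper<DeviceMgr> mgr;
  //   mgr.Reset(device_mgr, /*owned=*/false);  // Borrow: caller keeps
  //                                            // ownership. Owned() == false.
  //   mgr.Reset(std::unique_ptr<DeviceMgr>(new_mgr));  // Own: freed with the
  //                                                    // helper. Owned() ==
  //                                                    // true.
  //   DeviceMgr* raw = mgr.Get();  // Works in either mode.
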
  SessionOptions opts_;
  const ContextDevicePlacementPolicy default_device_placement_policy_;

  // Note: we cannot use C++11 thread_local here as there is no concept of a
  // thread-local-object-local variable in C++11.
  mutable mutex policy_map_mu_;
  std::unordered_map<std::thread::id, ContextDevicePlacementPolicy>
      device_placement_policy_ TF_GUARDED_BY(policy_map_mu_);

  OwnedOrUnownedHelper<DeviceMgr> local_device_manager_;
  // Maintains a copy of all previously created local device managers.
  std::vector<std::unique_ptr<DeviceMgr>> old_local_device_managers_;

  // An unowned DynamicDeviceMgr is set on a remote worker to allow running
  // multi-device functions on the remote worker.
  OwnedOrUnownedHelper<DynamicDeviceMgr> remote_device_manager_;

  Device* host_cpu_device_;  // Owned by device_manager
  mutable mutex device_type_list_mu_;
  std::shared_ptr<std::vector<DeviceType>> prioritized_device_type_list_
      TF_GUARDED_BY(device_type_list_mu_);
  Rendezvous* rendezvous_;
  std::function<Rendezvous*(const int64_t)> rendezvous_creator_;
  CustomDeviceOpHandler custom_device_op_handler_;

  mutable mutex composite_devices_mu_;
  // Maps from the fingerprint of a set of device names to a virtual
  // CompositeDevice.
  // TODO(b/145922293): Consider taking device names as keys.
  absl::flat_hash_map<uint64, std::unique_ptr<CompositeDevice>>
      composite_devices_ ABSL_GUARDED_BY(composite_devices_mu_);

  FunctionLibraryDefinition func_lib_def_{OpRegistry::Global(), {}};

  std::unique_ptr<thread::ThreadPool> thread_pool_;

  // EagerContext owns the DistributedFunctionLibraryRuntime
  // (EagerClusterFunctionLibraryRuntime) if using EagerService for remote
  // function execution (lazy_copy_function_remote_inputs_=true).
  OwnedOrUnownedHelper<DistributedFunctionLibraryRuntime> cluster_flr_;
  // One FunctionLibraryRuntime per device.
  // func_libs[i] is the FunctionLibraryRuntime corresponding to
  // session->devices[i].
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;

  std::function<void(std::function<void()>)> runner_;

  mutex cache_mu_;
  struct RegisteredFunction : public core::RefCounted {
    ~RegisteredFunction() override {}

    std::unique_ptr<std::vector<Fprint128>> cached_kernel_keys;
  };
  std::unordered_map<Fprint128, core::RefCountPtr<KernelAndDevice>,
                     Fprint128Hasher>
      kernel_cache_ TF_GUARDED_BY(cache_mu_);
  std::unordered_map<string, RegisteredFunction*> registered_functions_
      TF_GUARDED_BY(cache_mu_);

  // Whether we should compute RunMetadata.
  std::atomic<bool> should_store_graphs_{false};
  mutex metadata_mu_;
  std::unique_ptr<RunMetadata> run_metadata_ TF_GUARDED_BY(metadata_mu_);
  GraphCollector graph_collector_;
  std::atomic<bool> log_device_placement_;
  std::atomic<bool> allow_soft_placement_;

  // Information related to step containers.
  std::atomic<int> num_active_steps_;
  std::unique_ptr<ScopedStepContainer> step_container_
      TF_GUARDED_BY(metadata_mu_);

  EagerExecutor default_executor_;
  mutable mutex executor_map_mu_;
  // Not owned.
  std::unordered_map<std::thread::id, EagerExecutor*> thread_local_executor_
      TF_GUARDED_BY(executor_map_mu_);
  std::unordered_map<std::thread::id, std::unordered_set<EagerExecutor*>>
      has_cleanup_ TF_GUARDED_BY(executor_map_mu_);

  const bool log_memory_;

  // Whether to use the same rendezvous instance across function/eager
  // executions.
  bool reuse_rendezvous_for_functions_ = false;
  mutable mutex global_rendezvous_mu_;
  core::RefCountPtr<Rendezvous> global_rendezvous_for_functions_
      TF_GUARDED_BY(global_rendezvous_mu_);

  Env* const env_;

  OwnedOrUnownedHelper<CollectiveExecutorMgrInterface> collective_executor_mgr_;

#if !defined(IS_MOBILE_PLATFORM)
  std::vector<string> GetRemoteContexts() TF_LOCKS_EXCLUDED(remote_state_mu_);
  bool IsRemoteContextsEmpty() TF_LOCKS_EXCLUDED(remote_state_mu_);
  void CloseAndClearAllRemoteContexts();
  void CloseRemoteContexts(const std::vector<string>& remote_contexts,
                           uint64 context_id, uint64 context_view_id);

  // TODO(b/184375824): clean up parameter order for better readability.
  Status SetMasterContextState(
      std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
      std::shared_ptr<WorkerSession> worker_session,
      std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
      std::unique_ptr<DynamicDeviceMgr> remote_device_manager,
      uint64 context_id, uint64 context_view_id, /*const*/ Rendezvous* r,
      /*const*/ DeviceMgr* local_device_mgr, int keep_alive_secs,
      DistributedFunctionLibraryRuntime* cluster_flr,
      std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
          remote_mgr);

  // server_ is not marked as const because we release it when the context is
  // destroyed (even though conceptually it should be const).
  std::unique_ptr<ServerInterface> server_;
  WorkerEnv* worker_env_ = nullptr;
  std::shared_ptr<WorkerSession> worker_session_;

  mutable mutex remote_state_mu_;

  uint64 context_id_ TF_GUARDED_BY(remote_state_mu_);
  // The view id of an eager context should be set to 0 when the context is
  // created, and continuously incremented when a context with the same
  // context_id gets updated. The view id should be consistent between the
  // master and workers.
  uint64 context_view_id_ TF_GUARDED_BY(remote_state_mu_);
  std::vector<string> remote_contexts_ TF_GUARDED_BY(remote_state_mu_);
  std::unique_ptr<eager::EagerClientCache> remote_eager_workers_
      TF_GUARDED_BY(remote_state_mu_);

  int keep_alive_secs_ TF_GUARDED_BY(remote_state_mu_);
  std::atomic<int> sleep_for_secs_;

  std::unique_ptr<Thread> keep_alive_thread_;
  mutex keep_alive_thread_shutdown_mu_;
  condition_variable keep_alive_thread_cv_;
  bool shutting_down_ TF_GUARDED_BY(keep_alive_thread_shutdown_mu_) = false;

  std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
      remote_mgr_;
  bool is_master_ TF_GUARDED_BY(remote_state_mu_);

  // Maps from a remote worker to a list of parsed device filters.
  std::unordered_map<string, std::vector<DeviceNameUtils::ParsedName>>
      cluster_device_filters_ TF_GUARDED_BY(remote_state_mu_);

  // A distributed manager that helps set up, update, and check liveness of
  // member tasks in the cluster.
  std::unique_ptr<ImmediateExecutionDistributedManager> distributed_manager_;

#endif  // IS_MOBILE_PLATFORM

  // For a multi-device function, the target device of each input is unknown
  // until the function is instantiated on the default function device.
  // If false, eagerly copy all remote inputs to the default function device;
  // if true, lazily copy remote inputs to their target devices to avoid
  // redundant copies.
  bool lazy_copy_function_remote_inputs_ = false;
  bool use_send_tensor_rpc_;
  const bool pin_small_ops_to_cpu_;

  // Function that will be invoked in the destructor to deallocate resources
  // related to this context.
  std::function<void()> resource_deallocator_ = nullptr;
  bool run_eager_op_as_function_;
};

inline EagerContext* ContextFromInterface(ImmediateExecutionContext* context) {
  return down_cast<EagerContext*>(context);
}

namespace internal {
struct EagerContextDeleter {
  void operator()(EagerContext* p) const {
    if (p != nullptr) {
      p->Release();
    }
  }
};
}  // namespace internal

using EagerContextPtr =
    std::unique_ptr<EagerContext, internal::EagerContextDeleter>;

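// A minimal usage sketch (illustrative only): EagerContextPtr pairs the
// context with a deleter that calls Release() rather than `delete`, so the
// refcount is dropped correctly when the pointer goes out of scope. The
// `opts`, `device_mgr`, and `rendezvous` setup is assumed to exist as in the
// constructor sketch above.
//
//   tensorflow::EagerContextPtr ctx(new tensorflow::EagerContext(
//       opts, tensorflow::DEVICE_PLACEMENT_SILENT, /*async=*/false,
//       device_mgr, /*device_mgr_owned=*/false, rendezvous));
//   ...  // Use ctx->Executor(), ctx->AddFunctionDef(...), etc.
//   // Release() is invoked automatically on scope exit.
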
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_