/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_

#include <algorithm>
#include <cstddef>
#include <map>
#include <memory>
#include <queue>
#include <string>
#include <unordered_set>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/core/common_runtime/composite_device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/custom_device.h"
#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/threadpool.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/device_name_utils.h"

// "tensorflow/core/platform/platform.h" must be included first before using
// IS_MOBILE_PLATFORM.
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#endif  // !IS_MOBILE_PLATFORM

namespace tensorflow {

namespace eager {
// We need this forward declaration because of a circular dependency:
// Context -> RemoteMgr -> TensorHandle -> Context.
// TODO(fishx): Remove this once we remove the Context dependency in
// TensorHandle.
class RemoteMgr;
}  // namespace eager

class TensorHandle;
class EagerOperation;

class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
 public:
  static constexpr uint64 kInvalidContextId = 0;

  static uint64 NewContextId() {
    uint64 context_id = random::New64();
    while (context_id == kInvalidContextId) {
      context_id = random::New64();
    }
    return context_id;
  }

  EagerContext(
      const SessionOptions& opts,
      ContextDevicePlacementPolicy default_device_placement_policy, bool async,
      /*const*/ DeviceMgr* device_mgr, bool device_mgr_owned,
      /*const*/ Rendezvous* rendezvous,
      DistributedFunctionLibraryRuntime* cluster_flr = nullptr,
      CollectiveExecutorMgrInterface* collective_executor_mgr = nullptr,
      bool run_eager_op_as_function = false, bool jit_compile_rewrite = false);

  void Release() override { Unref(); }

  AbstractTensorInterface* CreateInt64Scalar(int64_t value) override;
  AbstractTensorInterface* CreateUint64Scalar(uint64 value) override;
  AbstractTensorInterface* CreateInt32Scalar(int32_t value) override;
  AbstractTensorInterface* CreateFloatScalar(float value) override;
  AbstractTensorInterface* CreateDoubleScalar(double value) override;
  AbstractTensorInterface* CreateHalfScalar(Eigen::half value) override;
  AbstractTensorInterface* CreateStringScalar(
      tensorflow::tstring value) override;
  AbstractTensorInterface* CreateComplex128Scalar(
      tensorflow::complex128 value) override;
  AbstractTensorInterface* CreateBoolScalar(bool value) override;

  AbstractTensorInterface* CreateTensor(
      DataType dtype, absl::Span<const int64_t> dim_sizes) override;
  AbstractTensorInterface* CreateTensor(DataType dtype, const int64_t* dims,
                                        int num_dims, void* data, size_t len,
                                        MemoryReleaser memory_releaser,
                                        void* memory_releaser_arg) override;

  ImmediateExecutionTensorHandle* CreateLocalHandle(
      AbstractTensorInterface* t) override;
  // Creates an abstract tensor handle from a tensorflow::Tensor.
  ImmediateExecutionTensorHandle* CreateLocalHandleFromTFTensor(
      tensorflow::Tensor& t, const char* d_name) override;
  ImmediateExecutionTensorHandle* CopyTensorHandleToDevice(
      ImmediateExecutionTensorHandle* handle, const char* device_name,
      Status* status) override;
  ImmediateExecutionOperation* CreateOperation() override;

  // This is a virtual helper function to convert a TFRT TensorHandle to a
  // tensorflow::TensorHandle. In the current runtime's EagerContext it just
  // forwards the input, since the input tensor handle is already a
  // tensorflow::TensorHandle.
  ImmediateExecutionTensorHandle* TFTensorHandleFromInterface(
      ImmediateExecutionTensorHandle* handle) override;

  Status RegisterFunction(AbstractFunction* f) override;

  bool UsesTFRT() override;

  bool RunEagerOpAsFunction() const;

  void SetRunEagerOpAsFunction(bool enable) override;

  bool JitCompileRewrite() const;

  void SetJitCompileRewrite(bool enable) override;

  void ListDevices(std::vector<DeviceAttributes>* devices) override;

  Status AddDevices(std::vector<std::unique_ptr<Device>> devices) override;

  thread::ThreadPool* GetThreadPool() { return thread_pool_.get(); }

  // Returns the function library runtime for the given device.
  FunctionLibraryRuntime* func_lib(const Device* d) const {
    return pflr_->GetFLR(d->name());
  }

  ProcessFunctionLibraryRuntime* pflr() const { return pflr_.get(); }

  std::function<void(std::function<void()>)>* runner() { return &runner_; }

  // Specifies an executor for this thread.
  void SetExecutorForThread(EagerExecutor* executor) override;

  const std::shared_ptr<std::vector<DeviceType>> prioritized_device_type_list()
      const {
    mutex_lock l(device_type_list_mu_);
    return prioritized_device_type_list_;
  }

  // Clears pending nodes in thread executors and kernel caches.
  void ClearCachesAndThreadExecutors() override;
  // Clears pending nodes in the default executor and kernel caches.
  void ClearCachesAndDefaultExecutor();

  // Sets the device placement policy for the current thread.
  void SetThreadLocalDevicePlacementPolicy(
      ContextDevicePlacementPolicy policy) override;

  // Returns the device placement policy for the current thread.
  ContextDevicePlacementPolicy GetDevicePlacementPolicy() const override;

  // Selects an appropriate device for an operation.
  //
  // Given the preferred device for the operation and the node_def, finds the
  // best suitable device for the operation in this context.
  //
  // The preferred device is specified as a `ParsedName` containing the
  // elements (details) that the resulting device should match. If there are
  // no such devices, and the context currently allows soft device placement,
  // a suitable device not matching `preferred` will be chosen.
  //
  // The chosen device is stored in the `out` argument. The argument is not
  // modified unless this method returns `Status::OK()`.
  Status SelectDevice(DeviceNameUtils::ParsedName preferred,
                      const NodeDef& ndef, Device** out) const;
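
  // A minimal usage sketch of SelectDevice (illustrative only;
  // `preferred_name` and `ndef` are hypothetical locals):
  //
  //   Device* device = nullptr;
  //   TF_RETURN_IF_ERROR(ctx->SelectDevice(preferred_name, ndef, &device));
  //   // `device` is assigned only when the call returns OK.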

  // TODO(mdan): Rename to ContainsFunction.
  bool FindFunctionByName(const string& name) const;

  Status FindFunctionOpData(const string& name,
                            const tensorflow::OpRegistrationData** op_data);

  const FunctionDef* FindFunctionDef(const string& name) const override;

  Device* HostCPU() const { return host_cpu_device_; }
  Device* CanonicalDevice(Device* d) const {
    return HostCPU() == d ? nullptr : d;
  }
  const DeviceNameUtils::ParsedName& HostCPUParsedName() const override {
    return HostCPU()->parsed_name();
  }

  const string& HostCPUName() const override { return HostCPU()->name(); }

  GraphCollector* GetGraphCollector() { return &graph_collector_; }

  EagerExecutor& Executor() override;

  // Adds the given `fdef` to the local FunctionLibraryDefinition, and adds an
  // entry to the KernelAndDevice cache for it if one does not already exist.
  Status AddFunctionDef(const FunctionDef& fdef) override;

  Status AddFunctionDefWithStackTraces(
      const FunctionDef& fdef, const StackTracesMap& stack_traces) override;

  // `library` contains all FunctionDefs and GradientDefs needed to expand
  // `fdef`. Add them to the local FunctionLibraryDefinition as well, but
  // there is no need to add them to the KernelAndDevice cache since they
  // won't be executed as KernelAndDevices.
  Status AddFunctionDef(const FunctionDef& fdef,
                        const FunctionDefLibrary& library,
                        bool add_to_local_only = false,
                        const StackTracesMap& stack_traces = {});

  const FunctionDef* GetFunctionDef(const string& function_name);

  std::vector<string> ListFunctionNames() override;

  Status RemoveFunction(const string& func) override;

  // Waits for pending nodes to finish in local executors (including the
  // context's default executor and thread executors) as well as executors on
  // remote workers. Returns the combined status of the remote executors; if
  // there are multiple errors, the Status code will be that of the first
  // remote executor with an error, and the error message will be combined
  // from all executors.
  Status SyncExecutors();

  Status AsyncWait() override { return SyncExecutors(); }

  core::RefCountPtr<KernelAndDevice> GetCachedKernel(Fprint128 cache_key);
  Device* GetCachedDevice(Fprint128 device_cache_key);

  void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel);
  void AddDeviceToCache(Fprint128 device_cache_key, Device* device);

  bool LogDevicePlacement() const { return log_device_placement_; }
  void SetLogDevicePlacement(bool enable) override {
    log_device_placement_ = enable;
  }

  // When tensor transfer across functions/eager executions using send/recv
  // ops is required, `reuse_rendezvous_for_functions_` can be set to true so
  // that function executions and eager executions use the same rendezvous
  // instance, instead of creating a new instance per function call.
  void SetReuseRendezvousForFunctions(
      bool reuse_rendezvous_for_functions) override {
    reuse_rendezvous_for_functions_ = reuse_rendezvous_for_functions;
  }
  bool GetReuseRendezvousForFunctions() const {
    return reuse_rendezvous_for_functions_;
  }
  mutex* reuse_rendezvous_for_functions_mu() {
    return &reuse_rendezvous_for_functions_mu_;
  }
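
  // A minimal sketch (illustrative only) of toggling the reuse setting
  // around a launch while holding the mutex above, so that concurrent
  // launches observe a consistent value; `ctx` is a hypothetical
  // EagerContext*:
  //
  //   {
  //     mutex_lock l(*ctx->reuse_rendezvous_for_functions_mu());
  //     ctx->SetReuseRendezvousForFunctions(true);
  //     auto creator = ctx->RendezvousCreator();  // captures current value
  //     ctx->SetReuseRendezvousForFunctions(false);
  //   }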

  bool AllowSoftPlacement() const { return allow_soft_placement_; }
  void SetAllowSoftPlacement(bool enable) override {
    allow_soft_placement_ = enable;
  }
  bool LogMemory() const { return log_memory_; }

  Rendezvous* GetRendezvous() const { return rendezvous_; }

  void ResetGlobalRendezvousForFunction() override {
    mutex_lock l(global_rendezvous_mu_);
    // Remove the global rendezvous instance from the local rendezvous table
    // if it is of a local rendezvous type; this forces EagerContext to create
    // a new local rendezvous instance in the table.
    local_rendezvous_table_->Remove(-1);
    global_rendezvous_for_functions_ =
        core::RefCountPtr<Rendezvous>(CreateRendezvous(-1));
  }

  // Returns the status of global_rendezvous_for_functions_'s underlying
  // LocalRendezvous. If the underlying Rendezvous is not in the
  // local_rendezvous_table_, returns OK.
  Status GetGlobalRendezvousForFunctionLocalRendezvousStatus();

  // Returns a function which maps from step_id to rendezvous. This closure
  // respects the value of `SetReuseRendezvousForFunctions` at the time the
  // closure was created, which allows the setting to be toggled around async
  // op launches.
  //
  // The caller of the returned function owns a reference to the resulting
  // Rendezvous.
  std::function<Rendezvous*(int64_t)> RendezvousCreator() {
    // There is an implicit assumption that the
    // global_rendezvous_for_functions_ is always an IntraProcessRendezvous to
    // match the behaviour of the EagerContext's rendezvous.
    // Ref: tensorflow/c/eager/c_api.cc;l=143;rcl=396387348
    // If a cross-process kernel needs a rendezvous, a new
    // InterProcessRendezvous should be created.
    if (reuse_rendezvous_for_functions_ && rendezvous_creator_ == nullptr &&
#if !defined(IS_MOBILE_PLATFORM)
        worker_env_ == nullptr &&
#endif
        remote_device_mgr() == nullptr) {
      return [this](int64_t step_id) {
        mutex_lock l(global_rendezvous_mu_);
        global_rendezvous_for_functions_->Ref();
        return global_rendezvous_for_functions_.get();
      };
    } else {
      return [this](int64_t step_id) { return CreateRendezvous(step_id); };
    }
  }
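
  // A minimal usage sketch (illustrative only): the caller owns one
  // reference to each rendezvous produced by the returned closure and must
  // drop it when done:
  //
  //   std::function<Rendezvous*(int64_t)> create = ctx->RendezvousCreator();
  //   Rendezvous* r = create(/*step_id=*/42);
  //   // ... use `r` for send/recv ...
  //   r->Unref();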

  CollectiveExecutorMgrInterface* collective_executor_mgr() {
    return collective_executor_mgr_.Get();
  }
  std::unique_ptr<CollectiveExecutor::Handle> GetCollectiveExecutorHandle() {
    return std::unique_ptr<CollectiveExecutor::Handle>(
        new CollectiveExecutor::Handle(
            collective_executor_mgr()->FindOrCreate(0), true /*inherit_ref*/));
  }
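
  // A minimal usage sketch (illustrative only, assuming
  // CollectiveExecutor::Handle exposes the wrapped executor via get()):
  //
  //   auto handle = ctx->GetCollectiveExecutorHandle();
  //   CollectiveExecutor* ce = handle->get();
  //   // ... use `ce`; the reference is released when `handle` is destroyed.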

  tensorflow::DeviceMgr* local_device_mgr() const {
    return local_device_manager_.Get();
  }
  const tensorflow::DynamicDeviceMgr* remote_device_mgr() const {
    return remote_device_manager_.Get();
  }

  tensorflow::DynamicDeviceMgr* GetOwnedRemoteDeviceMgr() {
    return remote_device_manager_.GetOwned();
  }

  std::vector<Device*> ListLocalTfDevices() override {
    return local_device_mgr()->ListDevices();
  }

  std::vector<Device*> ListAllTfDevices() override;

  // TODO(apassos) clean up RunMetadata storage.
  mutex* MetadataMu() TF_LOCK_RETURNED(metadata_mu_) { return &metadata_mu_; }
  bool ShouldStoreGraphs() TF_LOCKS_EXCLUDED(metadata_mu_);
  void SetShouldStoreGraphs(bool value) override;
  RunMetadata* RunMetadataProto() TF_EXCLUSIVE_LOCKS_REQUIRED(metadata_mu_) {
    return run_metadata_.get();
  }
  std::unique_ptr<RunMetadata> ExportRunMetadata() override
      TF_LOCKS_EXCLUDED(metadata_mu_);

  void StartStep() override;
  void EndStep() override;
  ScopedStepContainer* StepContainer();

  FunctionLibraryDefinition* FuncLibDef() { return &func_lib_def_; }

#if !defined(IS_MOBILE_PLATFORM)
  // Assigns the EagerClient pointer to `client` based on the given device /
  // task name, and increments the refcount of the client. Reference ownership
  // is transferred to the caller, and the unref happens automatically when
  // the RefCountPtr object is destructed at the caller's side.
  // `client` must not already be initialized or hold a reference to another
  // object before calling this method.
  Status GetClient(Device* device,
                   core::RefCountPtr<eager::EagerClient>* client);
  Status GetClient(const DeviceNameUtils::ParsedName& device_name,
                   core::RefCountPtr<eager::EagerClient>* client);
  Status GetClient(const string& remote_task,
                   core::RefCountPtr<eager::EagerClient>* client);
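
  // A minimal usage sketch (illustrative only; `remote_task` is a
  // hypothetical task name such as "/job:worker/replica:0/task:1"):
  //
  //   core::RefCountPtr<eager::EagerClient> client;
  //   TF_RETURN_IF_ERROR(ctx->GetClient(remote_task, &client));
  //   // ... use `client`; the unref happens when `client` is destroyed.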

  uint64 GetContextId() const;
  uint64 GetContextViewId() const;
  void IncrementContextViewId();

  Status EnableCollectiveOps(const ServerDef& server_def) override;

  // TODO(nareshmodi): Encapsulate remote state into a separate
  // class/struct.
  //
  // Enables the eager context to communicate with remote devices. When
  // initializing with this method, this context will be the primary context,
  // which will kill all its remote contexts in shutdown.
  //
  // - server: A ServerInterface that exports the tensorflow.WorkerService.
  // Note that this class expects the server to already have been started.
  // - remote_eager_workers: A cache from which we can get "EagerClient"s to
  // communicate with remote eager services.
  // - remote_device_mgr: A DeviceMgr* which contains all remote devices
  // (should contain no local devices).
  // - remote_contexts: A vector containing task names.
  // TODO(b/184375824): clean up parameter order for better readability.
  Status InitializeRemoteMaster(
      std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
      std::shared_ptr<WorkerSession> worker_session,
      std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
      std::unique_ptr<DynamicDeviceMgr> remote_device_manager,
      const std::vector<string>& remote_contexts, uint64 context_id,
      /*const*/ Rendezvous* r, /*const*/ DeviceMgr* local_device_mgr,
      int keep_alive_secs, DistributedFunctionLibraryRuntime* cluster_flr,
      std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
          remote_mgr);

  // Updates an existing master context with a new set of remote workers
  // (i.e., a new "view" of cluster membership). Similar to
  // InitializeRemoteMaster, but this will keep the current context_id and
  // increment a context_view_id, will keep the current resource manager so
  // that resources from the previous view can still be accessed, and will
  // automatically register existing functions if there are newly added hosts.
  Status UpdateRemoteMaster(
      uint64 context_id,
      std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
      const std::vector<string>& add_remote_contexts,
      const std::vector<string>& remove_remote_contexts);

  // Similar to InitializeRemoteMaster, but this context will not kill remote
  // contexts in shutdown.
  Status InitializeRemoteWorker(
      std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
      DynamicDeviceMgr* remote_device_mgr,
      const std::vector<string>& remote_contexts, uint64 context_id,
      uint64 context_view_id,
      std::function<Rendezvous*(const int64_t)> rendezvous_creator,
      DistributedFunctionLibraryRuntime* cluster_flr,
      std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
          remote_mgr,
      std::function<void()> resource_deallocator);

  // Similar to InitializeRemoteWorker, but will reuse the existing context
  // and increment context_view_id.
  Status UpdateRemoteWorker(
      std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
      const std::vector<string>& remote_contexts, uint64 context_id);

  Status StoreCollectiveOpsServer(
      std::unique_ptr<ServerInterface> new_server, DeviceMgr* device_mgr,
      CollectiveExecutorMgrInterface* rpc_collective_executor_mgr);

  // For the specified remote worker, preprocesses and sets its device
  // filters.
  Status SetRemoteDeviceFilters(const string& remote_worker,
                                const std::vector<string>& device_filters);

  // For the specified remote worker, applies the stored device filters to the
  // list of device attributes following these rules:
  // (1) if the remote worker does not have device filters, all devices are
  //     visible to the worker;
  // (2) if the device is on the remote worker, then it is visible;
  // (3) if the device matches at least one device filter, then it is visible.
  // The result is saved as a boolean vector of the same length (i.e.,
  // filtered_device_mask) indicating whether each of the devices is visible
  // to the remote worker.
  void FilterDevicesForRemoteWorkers(
      const string& remote_worker,
      const protobuf::RepeatedPtrField<DeviceAttributes>& device_attrs,
      std::vector<bool>* filtered_device_mask);
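
  // For example (hypothetical task names): with a stored filter
  // "/job:worker/task:1" for remote worker "/job:worker/task:0", a device on
  // task:0 is visible by rule (2), a device on task:1 is visible by rule (3),
  // and a device on task:2 (matching no filter) is masked out.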

  // TODO(fishx): Remove the custom deleter once we remove the forward
  // declaration.
  const std::unique_ptr<eager::RemoteMgr,
                        std::function<void(eager::RemoteMgr*)>>&
  RemoteMgr() {
    return remote_mgr_;
  }

  // If true, then tensors should be shipped across processes via the
  // EagerService.Enqueue(SendTensorOp). If false, _Send/_Recv ops should be
  // used instead (which in turn use WorkerService.RecvTensor RPCs).
  bool UseSendTensorRPC() { return use_send_tensor_rpc_; }

  tensorflow::ServerInterface* GetServer() { return server_.get(); }

  // For LLVM style RTTI.
  static bool classof(const AbstractContext* ptr) {
    return ptr->getKind() == kEager;
  }

  // Function to support the distributed C API.
  void SetDistributedManager(
      std::unique_ptr<ImmediateExecutionDistributedManager> distributed)
      override {
    distributed_manager_ = std::move(distributed);
  }
  ImmediateExecutionDistributedManager* GetDistributedManager() override {
    return distributed_manager_.get();
  }

  // May only be used during multi-client setup so that a RemoteRendezvous
  // can be initialized instead of defaulting to the IntraProcessRendezvous.
  void SetWorkerEnv(WorkerEnv* worker_env,
                    std::shared_ptr<WorkerSession> worker_session);
#endif  // IS_MOBILE_PLATFORM

  // Closes remote eager contexts, waits for all RPCs to finish, and destroys
  // the EagerClientCache. No RPCs can be made through this context after this
  // method has been called.
  // This method exists to aid a clean shutdown. It causes all RPCs to finish
  // and remote TensorHandles to release their references to this context.
  // To avoid deadlocks, this method must not be called on the thread
  // processing RPCs, because it makes RPCs and waits for their completion.
  //
  // On mobile, it just cleans the caches.
  void WaitForAndCloseRemoteContexts();

  bool PinSmallOpsToCPU() const { return pin_small_ops_to_cpu_; }

  tensorflow::Env* TFEnv() const { return env_; }

  Status FindDeviceFromName(const char* device_name, Device** device) const;

  Status FindCompositeDeviceFromName(StringPiece device_name,
                                     CompositeDevice** device) const;

  Status RegisterCustomDevice(const string& name,
                              std::unique_ptr<CustomDevice> device) override;

  CustomDeviceOpHandler& GetCustomDeviceOpHandler() override {
    return custom_device_op_handler_;
  }

  // Finds or creates a composite device with the given `underlying_devices`
  // and `device_name` (if not empty).
  Status FindOrCreateCompositeDevice(
      const std::vector<string>& underlying_devices, const string& device_name,
      CompositeDevice** composite_device);

  bool OnSameTask(const Device* first, const Device* second) const;
  // Gets the CPU device on the same task as `device`.
  Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const;

  const SessionOptions& session_options() const { return opts_; }
  void InitPrioritizedDeviceTypeList();

  // Reassigns the cluster FLR and reinitializes devices and the FLR in the
  // process FLR.
  void UpdateClusterFLRAndInitDevices(
      DistributedFunctionLibraryRuntime* cluster_flr);

  // A constant representing the step id used for the global rendezvous.
  // This is used to distinguish whether a user-specified step id should be
  // set. The step id value of kGlobalRendezvousId is reserved and should not
  // be specified by the user.
  static const int64_t kGlobalRendezvousId;

 private:
  // A class wrapping a map from step_id to local rendezvous instances.
  class LocalRendezvousTable {
   public:
    LocalRendezvousTable() = default;
    ~LocalRendezvousTable();

    IntraProcessRendezvous* FindOrCreate(int64_t step_id,
                                         DeviceMgr* device_mgr);
    IntraProcessRendezvous* Find(int64_t step_id);
    void Remove(int64_t step_id);
    void CleanUpAll();

   private:
    mutable mutex table_lock_;
    absl::flat_hash_map<int64_t, IntraProcessRendezvous*> table_
        TF_GUARDED_BY(table_lock_);
  };
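
  // Creates a rendezvous for the given step id, trying in order: the
  // caller-provided rendezvous_creator_, the worker environment's rendezvous
  // manager (on non-mobile builds), and finally the local rendezvous table;
  // returns nullptr if none of these applies.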
  Rendezvous* CreateRendezvous(int64_t step_id) const {
    if (rendezvous_creator_ != nullptr) {
      VLOG(6) << "Creating rendezvous using the rendezvous_creator_.";
      return rendezvous_creator_(step_id);
    }

#if !defined(IS_MOBILE_PLATFORM)
    if (worker_env_ != nullptr && worker_env_->rendezvous_mgr != nullptr) {
      VLOG(6) << "Creating rendezvous using the worker_env's rendezvous_mgr.";
      auto* remote_r = worker_env_->rendezvous_mgr->Find(step_id);
      remote_r->Initialize(worker_session_.get()).IgnoreError();
      return remote_r;
    }
#endif

    if (remote_device_mgr() == nullptr) {
      VLOG(6) << "Creating rendezvous using local_device_mgr.";
      return local_rendezvous_table_->FindOrCreate(step_id, local_device_mgr());
    }

    return nullptr;
  }

  ~EagerContext() override;

  Status MaybeRegisterFunctionRemotely(const FunctionDef& fdef);
  Status RegisterExistingFunctionsOnRemoteWorkers(
      const std::vector<string>& remote_workers);

  void ResetPFLR(const DeviceMgr* device_mgr, Env* env,
                 const ConfigProto* config, int graph_def_version,
                 const FunctionLibraryDefinition* lib_def,
                 const OptimizerOptions& optimizer_options,
                 thread::ThreadPool* thread_pool = nullptr,
                 DistributedFunctionLibraryRuntime* cluster_flr = nullptr);

  void ResetClusterFLR(DistributedFunctionLibraryRuntime* cluster_flr);
  void UpdateGlobalRendezvousDeviceManager(tensorflow::DeviceMgr* device_mgr);

  void ClearResourceContainer(const string& name);

  template <typename T>
  struct OwnedOrUnownedHelper {
   public:
    OwnedOrUnownedHelper() {}
    explicit OwnedOrUnownedHelper(T* object, const bool owned = false) {
      Reset(object, owned);
    }

    void Reset(std::unique_ptr<T> object) {
      owned_object = std::move(object);
      unowned_object_ptr = nullptr;
    }

    void Reset(T* object, const bool owned = false) {
      if (owned) {
        owned_object.reset(object);
        unowned_object_ptr = nullptr;
      } else {
        owned_object.reset(nullptr);
        unowned_object_ptr = object;
      }
    }

    bool Owned() const { return owned_object != nullptr; }

    T* GetOwned() const { return owned_object.get(); }
    T* Get() const {
      return owned_object ? owned_object.get() : unowned_object_ptr;
    }

    std::unique_ptr<T> owned_object = nullptr;
    T* unowned_object_ptr = nullptr;
  };
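
  // A minimal usage sketch (illustrative only; `mgr` and `borrowed` are
  // hypothetical DeviceMgr pointers): Get() returns the stored pointer in
  // either mode, but only the owned mode deletes the object when the helper
  // is reset or destroyed.
  //
  //   OwnedOrUnownedHelper<DeviceMgr> owned(mgr, /*owned=*/true);
  //   DeviceMgr* p = owned.Get();        // == mgr; owned.Owned() is true
  //
  //   OwnedOrUnownedHelper<DeviceMgr> unowned(borrowed);  // not owned
  //   DeviceMgr* q = unowned.Get();      // == borrowed; Owned() is false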

  SessionOptions opts_;
  const ContextDevicePlacementPolicy default_device_placement_policy_;

  // Note: we cannot use C++11 thread_local here, as there is no concept of a
  // thread-local-object-local variable in C++11.
  mutable mutex policy_map_mu_;
  std::unordered_map<std::thread::id, ContextDevicePlacementPolicy>
      device_placement_policy_ TF_GUARDED_BY(policy_map_mu_);

  // This device manager maintains only the local devices on this worker.
  OwnedOrUnownedHelper<DeviceMgr> local_device_manager_;
  // Maintains copies of all previously created local device managers.
  std::vector<std::unique_ptr<DeviceMgr>> old_local_device_managers_;

  // An unowned DynamicDeviceMgr is set on a remote worker to allow running
  // multi-device functions on the remote worker.
  // This device manager maintains all the devices (both local and remote to
  // this worker) in the cluster.
  OwnedOrUnownedHelper<DynamicDeviceMgr> remote_device_manager_;

  Device* host_cpu_device_;  // Owned by device_manager
  mutable mutex device_type_list_mu_;
  std::shared_ptr<std::vector<DeviceType>> prioritized_device_type_list_
      TF_GUARDED_BY(device_type_list_mu_);
  Rendezvous* rendezvous_;
  std::function<Rendezvous*(const int64_t)> rendezvous_creator_;
  CustomDeviceOpHandler custom_device_op_handler_;

  mutable mutex composite_devices_mu_;
  // Maps from the fingerprint of a set of device names to a virtual
  // CompositeDevice.
  // TODO(b/145922293): Consider taking device names as keys.
  absl::flat_hash_map<uint64, std::unique_ptr<CompositeDevice>>
      composite_devices_ ABSL_GUARDED_BY(composite_devices_mu_);

  FunctionLibraryDefinition func_lib_def_{OpRegistry::Global(), {}};

  std::unique_ptr<thread::ThreadPool> thread_pool_;

  // EagerContext owns the DistributedFunctionLibraryRuntime
  // (EagerClusterFunctionLibraryRuntime) if using EagerService for remote
  // function execution (lazy_copy_function_remote_inputs_=true).
  OwnedOrUnownedHelper<DistributedFunctionLibraryRuntime> cluster_flr_;
  // One FunctionLibraryRuntime per device.
  // func_libs[i] is the FunctionLibraryRuntime corresponding to
  // session->devices[i].
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;

  std::function<void(std::function<void()>)> runner_;

  mutex cache_mu_;
  mutex device_cache_mu_;
  struct RegisteredFunction : public core::RefCounted {
    ~RegisteredFunction() override {}

    std::unique_ptr<std::vector<Fprint128>> cached_kernel_keys;
  };
  std::unordered_map<Fprint128, core::RefCountPtr<KernelAndDevice>,
                     Fprint128Hasher>
      kernel_cache_ TF_GUARDED_BY(cache_mu_);
  std::unordered_map<string, RegisteredFunction*> registered_functions_
      TF_GUARDED_BY(cache_mu_);
  absl::flat_hash_map<Fprint128, Device*, Fprint128Hasher> device_cache_
      TF_GUARDED_BY(device_cache_mu_);

  // Whether we should compute RunMetadata.
  std::atomic<bool> should_store_graphs_{false};
  mutex metadata_mu_;
  std::unique_ptr<RunMetadata> run_metadata_ TF_GUARDED_BY(metadata_mu_);
  GraphCollector graph_collector_;
  std::atomic<bool> log_device_placement_;
  std::atomic<bool> allow_soft_placement_;

  // Information related to step containers.
  std::atomic<int> num_active_steps_;
  std::unique_ptr<ScopedStepContainer> step_container_
      TF_GUARDED_BY(metadata_mu_);

  EagerExecutor default_executor_;
  mutable mutex executor_map_mu_;
  // Not owned.
  std::unordered_map<std::thread::id, EagerExecutor*> thread_local_executor_
      TF_GUARDED_BY(executor_map_mu_);
  std::unordered_map<std::thread::id, std::unordered_set<EagerExecutor*>>
      has_cleanup_ TF_GUARDED_BY(executor_map_mu_);

  const bool log_memory_;

  // The table of local rendezvous instances for intra-process communication.
  // This makes sure only one local rendezvous instance exists per step id.
  std::unique_ptr<LocalRendezvousTable> local_rendezvous_table_;

  // Whether to use the same rendezvous instance across function/eager
  // executions.
  std::atomic<bool> reuse_rendezvous_for_functions_{false};
  mutable mutex global_rendezvous_mu_;
  core::RefCountPtr<Rendezvous> global_rendezvous_for_functions_
      TF_GUARDED_BY(global_rendezvous_mu_);
  mutex reuse_rendezvous_for_functions_mu_;

  Env* const env_;

  OwnedOrUnownedHelper<CollectiveExecutorMgrInterface> collective_executor_mgr_;

#if !defined(IS_MOBILE_PLATFORM)
  std::vector<string> GetRemoteContexts() TF_LOCKS_EXCLUDED(remote_state_mu_);
  bool IsRemoteContextsEmpty() TF_LOCKS_EXCLUDED(remote_state_mu_);
  void CloseAndClearAllRemoteContexts();
  void CloseRemoteContexts(const std::vector<string>& remote_contexts,
                           uint64 context_id, uint64 context_view_id);

  // TODO(b/184375824): clean up parameter order for better readability.
  Status SetMasterContextState(
      std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
      std::shared_ptr<WorkerSession> worker_session,
      std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
      std::unique_ptr<DynamicDeviceMgr> remote_device_manager,
      uint64 context_id, uint64 context_view_id, /*const*/ Rendezvous* r,
      /*const*/ DeviceMgr* local_device_mgr, int keep_alive_secs,
      DistributedFunctionLibraryRuntime* cluster_flr,
      std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
          remote_mgr);

  // server_ is not marked const because we release it when the context is
  // destroyed (conceptually it should be const).
  std::unique_ptr<ServerInterface> server_;
  WorkerEnv* worker_env_ = nullptr;
  std::shared_ptr<WorkerSession> worker_session_;

  mutable mutex remote_state_mu_;

  uint64 context_id_ TF_GUARDED_BY(remote_state_mu_);
  // The view id of an eager context should be set to 0 when the context is
  // created, and continuously incremented when the context with the same
  // context_id gets updated. The view id should be consistent between the
  // master and workers.
  uint64 context_view_id_ TF_GUARDED_BY(remote_state_mu_);
  std::vector<string> remote_contexts_ TF_GUARDED_BY(remote_state_mu_);
  std::unique_ptr<eager::EagerClientCache> remote_eager_workers_
      TF_GUARDED_BY(remote_state_mu_);

  int keep_alive_secs_ TF_GUARDED_BY(remote_state_mu_);
  std::atomic<int> sleep_for_secs_;

  std::unique_ptr<Thread> keep_alive_thread_;
  mutex keep_alive_thread_shutdown_mu_;
  condition_variable keep_alive_thread_cv_;
  bool shutting_down_ TF_GUARDED_BY(keep_alive_thread_shutdown_mu_) = false;

  std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
      remote_mgr_;
  bool is_master_ TF_GUARDED_BY(remote_state_mu_);

  // Maps from a remote worker to a list of parsed device filters.
  std::unordered_map<string, std::vector<DeviceNameUtils::ParsedName>>
      cluster_device_filters_ TF_GUARDED_BY(remote_state_mu_);

  // A distributed manager that helps set up, update, and check the liveness
  // of member tasks in the cluster.
  std::unique_ptr<ImmediateExecutionDistributedManager> distributed_manager_;

#endif  // IS_MOBILE_PLATFORM

  // For a multi-device function, the target device of each input is unknown
  // until the function is instantiated on the default function device.
  // If false, eagerly copy all remote inputs to the default function device;
  // if true, lazily copy remote inputs to their target devices to avoid
  // redundant copies.
  bool lazy_copy_function_remote_inputs_ = false;
  bool use_send_tensor_rpc_;
  const bool pin_small_ops_to_cpu_;

  // Function that will be invoked in the destructor to deallocate resources
  // related to this context.
  std::function<void()> resource_deallocator_ = nullptr;
  bool run_eager_op_as_function_;
  bool jit_compile_rewrite_;
};

inline EagerContext* ContextFromInterface(ImmediateExecutionContext* context) {
  return down_cast<EagerContext*>(context);
}

namespace internal {
struct EagerContextDeleter {
  void operator()(EagerContext* p) const {
    if (p != nullptr) {
      p->Release();
    }
  }
};
}  // namespace internal

using EagerContextPtr =
    std::unique_ptr<EagerContext, internal::EagerContextDeleter>;
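
// A minimal usage sketch (illustrative only; constructor arguments are
// elided): EagerContextPtr releases the context's reference on destruction
// via EagerContextDeleter, which calls Release() (an Unref()):
//
//   EagerContextPtr ctx(new EagerContext(/*...*/));
//   // ... use ctx->Executor(), ctx->HostCPU(), ...
//   // The reference is dropped automatically when `ctx` goes out of scope.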

// Sets the EagerContext owned by the current Python eager Context (see
// TFE_Py_SetEagerContext in python/eager/pywrap_tfe.h). This is always called
// in tandem with TFE_Py_SetEagerContext (but not called by it, because its
// py_context argument is opaque).
//
// Do not use this function in production. It is only intended for testing.
// (see _reset_context in context.py).
//
// Not thread-safe.
void SetCEagerContext(EagerContext* ctx);

// Returns the EagerContext owned by the current Python eager Context (see
// TFE_Py_SetEagerContext in pywrap_tfe.h).
//
// Not thread-safe.
EagerContext* GetCEagerContext();

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_