• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/eager/context.h"
17 
18 #include <algorithm>
19 #include <functional>
20 #include <memory>
21 #include <utility>
22 #include <vector>
23 
24 // clang-format off
25 // Required for IS_MOBILE_PLATFORM
26 #include "tensorflow/c/eager/immediate_execution_context.h"
27 #include "tensorflow/core/common_runtime/function.h"
28 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
29 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
30 #include "tensorflow/core/framework/device_attributes.pb.h"
31 #include "tensorflow/core/lib/core/refcount.h"
32 #include "tensorflow/core/lib/gtl/map_util.h"
33 #include "tensorflow/core/nccl/collective_communicator.h"
34 #include "tensorflow/core/platform/errors.h"
35 #include "tensorflow/core/platform/mutex.h"
36 #include "tensorflow/core/platform/platform.h"
37 // clang-format on
38 
39 #include "tensorflow/c/tf_tensor.h"
40 #include "tensorflow/c/tf_tensor_internal.h"
41 #include "tensorflow/core/common_runtime/collective_executor_mgr.h"
42 #include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
43 #include "tensorflow/core/common_runtime/colocation_graph.h"
44 #include "tensorflow/core/common_runtime/device_resolver_local.h"
45 #include "tensorflow/core/common_runtime/device_set.h"
46 #include "tensorflow/core/common_runtime/process_util.h"
47 #include "tensorflow/core/framework/function.h"
48 #include "tensorflow/core/framework/graph_def_util.h"
49 #include "tensorflow/core/lib/core/errors.h"
50 #include "tensorflow/core/protobuf/config.pb.h"
51 #include "tensorflow/core/public/version.h"
52 #include "tensorflow/core/util/device_name_utils.h"
53 #if !defined(IS_MOBILE_PLATFORM)
54 #include "tensorflow/core/distributed_runtime/cluster_function_library_runtime.h"
55 #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
56 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
57 #include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
58 #include "tensorflow/core/distributed_runtime/session_mgr.h"
59 #endif  // !IS_MOBILE_PLATFORM
60 #include "tensorflow/core/framework/resource_mgr.h"
61 #include "tensorflow/core/lib/core/blocking_counter.h"
62 #include "tensorflow/core/lib/monitoring/gauge.h"
63 #include "tensorflow/core/util/env_var.h"
64 
65 namespace tensorflow {
66 
67 namespace {
68 // This object tracks the EagerContext owned by global_py_eager_context in
69 // pywrap_tfe_src.cc. Since the vast majority of the Python API is dependent on
70 // that global_py_eager_context (including memory management), the Py object
71 // owns the C object, so this pointer is non-owning.
72 EagerContext* global_c_eager_context = nullptr;
73 
74 }  // namespace
75 
SetCEagerContext(EagerContext * ctx)76 void SetCEagerContext(EagerContext* ctx) { global_c_eager_context = ctx; }
77 
GetCEagerContext()78 EagerContext* GetCEagerContext() { return global_c_eager_context; }
79 
80 namespace {
81 
ReadBoolFromEnvVar(StringPiece env_var_name,bool default_val)82 bool ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val) {
83   bool val;
84   if (tensorflow::ReadBoolFromEnvVar(env_var_name, default_val, &val).ok()) {
85     return val;
86   }
87   return default_val;
88 }
89 
90 auto* eager_context_created =
91     monitoring::Gauge<bool, 0>::New("/tensorflow/core/eager_context_created",
92                                     "True if an eager context was created.");
93 
94 }  // namespace
95 
96 const int64_t EagerContext::kGlobalRendezvousId = -1;
97 
98 // Find the rendezvous instance corresponding to the step id, or create a
99 // new instance if not existing.
FindOrCreate(int64_t step_id,DeviceMgr * device_mgr)100 IntraProcessRendezvous* EagerContext::LocalRendezvousTable::FindOrCreate(
101     int64_t step_id, DeviceMgr* device_mgr) {
102   mutex_lock l(table_lock_);
103   auto iter = table_.find(step_id);
104   if (iter == table_.end()) {
105     iter =
106         table_.insert({step_id, new IntraProcessRendezvous(device_mgr)}).first;
107     // Global rendezvous: ref-count should be 1 upon creation.
108     if (step_id == EagerContext::kGlobalRendezvousId) {
109       return iter->second;
110     }
111   }
112   iter->second->Ref();
113   return iter->second;
114 }
115 
Find(int64_t step_id)116 IntraProcessRendezvous* EagerContext::LocalRendezvousTable::Find(
117     int64_t step_id) {
118   mutex_lock l(table_lock_);
119   auto iter = table_.find(step_id);
120   if (iter == table_.end()) return nullptr;
121   iter->second->Ref();
122   return iter->second;
123 }
124 
Remove(int64_t step_id)125 void EagerContext::LocalRendezvousTable::Remove(int64_t step_id) {
126   mutex_lock l(table_lock_);
127   auto iter = table_.find(step_id);
128   if (iter != table_.end()) {
129     table_.erase(iter);
130   }
131 }
132 
CleanUpAll()133 void EagerContext::LocalRendezvousTable::CleanUpAll() {
134   mutex_lock l(table_lock_);
135   for (auto iter = table_.begin(); iter != table_.end(); iter++) {
136     // Unref all redezvous instance, except for global rendezvous,
137     // which is cleaned up elsewhere when necessary.
138     if (iter->first == -1) {
139       continue;
140     }
141     iter->second->Unref();
142   }
143 }
144 
~LocalRendezvousTable()145 EagerContext::LocalRendezvousTable::~LocalRendezvousTable() { CleanUpAll(); }
146 
EagerContext(const SessionOptions & opts,ContextDevicePlacementPolicy default_device_placement_policy,bool async,DeviceMgr * device_mgr,bool device_mgr_owned,Rendezvous * rendezvous,DistributedFunctionLibraryRuntime * cluster_flr,CollectiveExecutorMgrInterface * collective_executor_mgr,bool run_eager_op_as_function,bool jit_compile_rewrite)147 EagerContext::EagerContext(
148     const SessionOptions& opts,
149     ContextDevicePlacementPolicy default_device_placement_policy, bool async,
150     DeviceMgr* device_mgr, bool device_mgr_owned, Rendezvous* rendezvous,
151     DistributedFunctionLibraryRuntime* cluster_flr,
152     CollectiveExecutorMgrInterface* collective_executor_mgr,
153     bool run_eager_op_as_function, bool jit_compile_rewrite)
154     : ImmediateExecutionContext(kEager),
155       opts_(opts),
156       default_device_placement_policy_(default_device_placement_policy),
157       local_device_manager_(device_mgr, device_mgr_owned),
158       host_cpu_device_(device_mgr->HostCPU()),
159       rendezvous_(rendezvous),
160       thread_pool_(NewThreadPoolFromSessionOptions(opts)),
161       cluster_flr_(cluster_flr),
162       log_device_placement_(opts.config.log_device_placement()),
163       allow_soft_placement_(opts.config.allow_soft_placement()),
164       num_active_steps_(0),
165       step_container_(std::make_unique<ScopedStepContainer>(
166           0, [this](const string& name) { ClearResourceContainer(name); })),
167       default_executor_(async, /*enable_streaming_enqueue=*/true),
168       log_memory_(LogMemory::IsEnabled()),
169       env_(opts.env),
170       collective_executor_mgr_(collective_executor_mgr, /*owned=*/false),
171       use_send_tensor_rpc_(false),
172       pin_small_ops_to_cpu_(ReadBoolFromEnvVar(
173           "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING", false)),
174       run_eager_op_as_function_(run_eager_op_as_function),
175       jit_compile_rewrite_(jit_compile_rewrite) {
176   ResetPFLR(device_mgr, opts.env, &opts.config, TF_GRAPH_DEF_VERSION,
177             &func_lib_def_, opts.config.graph_options().optimizer_options(),
178             thread_pool_.get(), cluster_flr);
179   // Starts exporting metrics through a platform-specific monitoring API (if
180   // provided). For builds using "tensorflow/tsl/platform/default", this is
181   // currently a no-op.
182   eager_context_created->GetCell()->Set(true);
183   InitPrioritizedDeviceTypeList();
__anon3eb91ef90402(std::function<void()> closure) 184   runner_ = [this](std::function<void()> closure) {
185     this->thread_pool_->Schedule(std::move(closure));
186   };
187 
188   run_metadata_ = std::make_unique<RunMetadata>();
189 
190 #if !defined(IS_MOBILE_PLATFORM)
191   context_id_ = kInvalidContextId;
192   context_view_id_ = 0;
193 #endif  // IS_MOBILE_PLATFORM
194 
195   // TODO(yuefengz): consider creating a new RpcCollectiveExecutorMgr each
196   // time.
197   if (collective_executor_mgr_.Get() == nullptr) {
198     collective_executor_mgr_.Reset(CreateProdLocalCollectiveExecutorMgr(
199         opts.config, local_device_mgr(),
200         MaybeCreateNcclCommunicator(opts.config)));
201   }
202 
203   // Initialization of local_rendezvous_table_ needs to happen before the
204   // initialization of global_rendezvous_for_functions_ because the latter
205   // depends on the former.
206   local_rendezvous_table_ = std::make_unique<LocalRendezvousTable>();
207   global_rendezvous_for_functions_ =
208       core::RefCountPtr<Rendezvous>(CreateRendezvous(-1));
209 }
210 
CreateInt64Scalar(int64_t value)211 AbstractTensorInterface* EagerContext::CreateInt64Scalar(int64_t value) {
212   return new TensorInterface(Tensor(value));
213 }
214 
CreateUint64Scalar(uint64 value)215 AbstractTensorInterface* EagerContext::CreateUint64Scalar(uint64 value) {
216   return new TensorInterface(Tensor(value));
217 }
218 
CreateInt32Scalar(int32_t value)219 AbstractTensorInterface* EagerContext::CreateInt32Scalar(int32_t value) {
220   return new TensorInterface(Tensor(value));
221 }
222 
CreateFloatScalar(float value)223 AbstractTensorInterface* EagerContext::CreateFloatScalar(float value) {
224   return new TensorInterface(Tensor(value));
225 }
226 
CreateDoubleScalar(double value)227 AbstractTensorInterface* EagerContext::CreateDoubleScalar(double value) {
228   return new TensorInterface(Tensor(value));
229 }
230 
CreateHalfScalar(Eigen::half value)231 AbstractTensorInterface* EagerContext::CreateHalfScalar(Eigen::half value) {
232   return new TensorInterface(Tensor(value));
233 }
234 
CreateStringScalar(tstring value)235 AbstractTensorInterface* EagerContext::CreateStringScalar(tstring value) {
236   return new TensorInterface(Tensor(value));
237 }
238 
CreateComplex128Scalar(complex128 value)239 AbstractTensorInterface* EagerContext::CreateComplex128Scalar(
240     complex128 value) {
241   return new TensorInterface(Tensor(value));
242 }
243 
CreateBoolScalar(bool value)244 AbstractTensorInterface* EagerContext::CreateBoolScalar(bool value) {
245   return new TensorInterface(Tensor(value));
246 }
247 
CreateTensor(DataType dtype,absl::Span<const int64_t> dim_sizes)248 AbstractTensorInterface* EagerContext::CreateTensor(
249     DataType dtype, absl::Span<const int64_t> dim_sizes) {
250   return new TensorInterface(Tensor(dtype, TensorShape(dim_sizes)));
251 }
252 
CreateTensor(DataType dtype,const int64_t * dims,int num_dims,void * data,size_t len,MemoryReleaser memory_releaser,void * memory_releaser_arg)253 AbstractTensorInterface* EagerContext::CreateTensor(
254     DataType dtype, const int64_t* dims, int num_dims, void* data, size_t len,
255     MemoryReleaser memory_releaser, void* memory_releaser_arg) {
256   TF_Tensor* tensor_wrapper =
257       TF_NewTensor(static_cast<TF_DataType>(dtype), dims, num_dims, data, len,
258                    memory_releaser, memory_releaser_arg);
259 
260   AbstractTensorInterface* result = nullptr;
261   std::swap(result, tensor_wrapper->tensor);
262   TF_DeleteTensor(tensor_wrapper);
263   return result;
264 }
265 
ResetPFLR(const DeviceMgr * device_mgr,Env * env,const ConfigProto * config,int graph_def_version,const FunctionLibraryDefinition * lib_def,const OptimizerOptions & optimizer_options,thread::ThreadPool * thread_pool,DistributedFunctionLibraryRuntime * cluster_flr)266 void EagerContext::ResetPFLR(const DeviceMgr* device_mgr, Env* env,
267                              const ConfigProto* config, int graph_def_version,
268                              const FunctionLibraryDefinition* lib_def,
269                              const OptimizerOptions& optimizer_options,
270                              thread::ThreadPool* thread_pool,
271                              DistributedFunctionLibraryRuntime* cluster_flr) {
272   Rendezvous::Factory rendezvous_factory{
273       [this](const int64_t step_id, const DeviceMgr*, Rendezvous** r) {
274         *r = CreateRendezvous(step_id);
275         return OkStatus();
276       }};
277   pflr_.reset(new ProcessFunctionLibraryRuntime(
278       device_mgr, env, config, graph_def_version, lib_def, optimizer_options,
279       thread_pool, cluster_flr,
280       /*session_metadata=*/nullptr, std::move(rendezvous_factory)));
281 }
282 
InitPrioritizedDeviceTypeList()283 void EagerContext::InitPrioritizedDeviceTypeList() {
284   DeviceSet ds;
285   for (Device* d : local_device_mgr()->ListDevices()) {
286     ds.AddDevice(d);
287   }
288   auto remote_device_manager = remote_device_mgr();
289   if (remote_device_manager != nullptr) {
290     for (Device* d : remote_device_manager->ListDevices()) {
291       ds.AddDevice(d);
292     }
293   }
294   mutex_lock l(device_type_list_mu_);
295   prioritized_device_type_list_ =
296       std::make_shared<std::vector<DeviceType>>(ds.PrioritizedDeviceTypeList());
297 }
298 
299 namespace {
300 // Using absl::StrJoin with lambda does not work in tf-lite builds.
301 // TODO(b/148160441): Replace with absl::StrJoin once DeviceBase has operator<<.
DevicesToString(const PrioritizedDeviceVector & devices)302 std::vector<string> DevicesToString(const PrioritizedDeviceVector& devices) {
303   std::vector<string> v;
304   v.reserve(devices.size());
305   for (const auto& p : devices) {
306     v.push_back(p.first->name());
307   }
308   return v;
309 }
310 
DeviceTypesToString(const PrioritizedDeviceTypeVector & types)311 std::vector<string> DeviceTypesToString(
312     const PrioritizedDeviceTypeVector& types) {
313   std::vector<string> v;
314   v.reserve(types.size());
315   for (const auto& p : types) {
316     v.push_back(p.first.type_string());
317   }
318   return v;
319 }
320 
321 // Selects the "best" device that both exists and is supported.
322 //
323 // The `existing` argument specifies the available devices in the system, in
324 // priority order. The `supported` argument specifies the supported device types
325 // and their priorities, lower index types having higher priority.
326 // Currently the type priority defined by the `supported` parameter takes
327 // precedence over system device priorities from `existing`.
328 //
329 // TODO(b/148213212): Allow setting default device in eager context.
SelectBestMatchingDevice(const DeviceNameUtils::ParsedName & pattern,const PrioritizedDeviceVector & existing,const PrioritizedDeviceTypeVector & supported)330 Device* SelectBestMatchingDevice(const DeviceNameUtils::ParsedName& pattern,
331                                  const PrioritizedDeviceVector& existing,
332                                  const PrioritizedDeviceTypeVector& supported) {
333   for (const std::pair<DeviceType, int32>& prioritized_type : supported) {
334     for (const std::pair<Device*, int32>& prioritized_device : existing) {
335       Device* dev = prioritized_device.first;
336       if (DeviceType(dev->attributes().device_type()) ==
337               prioritized_type.first &&
338           DeviceNameUtils::IsCompleteSpecification(pattern,
339                                                    dev->parsed_name())) {
340         return dev;
341       }
342     }
343   }
344   return nullptr;
345 }
346 
347 }  // namespace
348 
SelectDevice(DeviceNameUtils::ParsedName preferred,const NodeDef & ndef,Device ** out) const349 Status EagerContext::SelectDevice(DeviceNameUtils::ParsedName preferred,
350                                   const NodeDef& ndef, Device** out) const {
351   DCHECK(out != nullptr);
352 
353   PrioritizedDeviceTypeVector supported_devs;
354   auto device_type_list = prioritized_device_type_list();
355   TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
356       *device_type_list, ndef, &supported_devs, &HostCPU()->parsed_name()));
357   if (supported_devs.empty()) {
358     return errors::NotFound("Could not find device for node: ",
359                             errors::FormatNodeNameForError(ndef.name()), " = ",
360                             ndef.op(), "[", SummarizeAttrs(ndef), "]",
361                             "\nAll kernels registered for op ", ndef.op(),
362                             ":\n", KernelsRegisteredForOp(ndef.op()));
363   }
364 
365   // Select the first matching registered device from the supported device
366   // list. If nothing matches and soft placement is enabled, pick a suitable
367   // device from the available ones.
368   const auto pflr_device_set = pflr()->device_set();
369   const PrioritizedDeviceVector& existing =
370       pflr_device_set->prioritized_devices();
371   *out = SelectBestMatchingDevice(preferred, existing, supported_devs);
372   if (*out != nullptr) {
373     return OkStatus();
374   }
375 
376   if (AllowSoftPlacement()) {
377     DeviceNameUtils::ParsedName soft_device_name = preferred;
378     soft_device_name.type.clear();
379     soft_device_name.has_type = false;
380     soft_device_name.has_id = false;
381     // TODO(b/148213746): Soft placement logic picks up another task if the
382     // requested does not exist.
383     *out = SelectBestMatchingDevice(soft_device_name, existing, supported_devs);
384     if (*out != nullptr) {
385       return OkStatus();
386     }
387   }
388 
389   if (DeviceNameUtils::HasSomeDetails(preferred)) {
390     return errors::InvalidArgument(
391         "Could not satisfy device specification '", preferred,
392         "'. enable_soft_placement=", AllowSoftPlacement(),
393         ". Supported device types [",
394         absl::StrJoin(DeviceTypesToString(supported_devs), ", "),
395         "]. All available devices [",
396         absl::StrJoin(DevicesToString(existing), ", "), "].");
397   }
398   return errors::InvalidArgument(
399       "No supported device found in available devices [",
400       absl::StrJoin(DevicesToString(existing), ", "),
401       "]. enable_soft_placement=", AllowSoftPlacement(),
402       ". Supported devices types [",
403       absl::StrJoin(DeviceTypesToString(supported_devs), ", "), "].");
404 }
405 
ResetClusterFLR(DistributedFunctionLibraryRuntime * cluster_flr)406 void EagerContext::ResetClusterFLR(
407     DistributedFunctionLibraryRuntime* cluster_flr) {
408   cluster_flr_.Reset(cluster_flr, /*owned=*/true);
409 }
410 
UpdateClusterFLRAndInitDevices(DistributedFunctionLibraryRuntime * cluster_flr)411 void EagerContext::UpdateClusterFLRAndInitDevices(
412     DistributedFunctionLibraryRuntime* cluster_flr) {
413   ResetClusterFLR(cluster_flr);
414 
415   const ConfigProto* config = pflr_ ? pflr_->config() : nullptr;
416   ResetPFLR(
417       local_device_manager_.Get(), env_, /*config=*/config,
418       TF_GRAPH_DEF_VERSION, &func_lib_def_,
419       /*optimizer_options=*/
420       config ? config->graph_options().optimizer_options() : OptimizerOptions(),
421       thread_pool_.get(), cluster_flr_.Get());
422 }
423 
Executor()424 EagerExecutor& EagerContext::Executor() {
425   tf_shared_lock l(executor_map_mu_);
426   return *gtl::FindWithDefault(thread_local_executor_,
427                                std::this_thread::get_id(), &default_executor_);
428 }
429 
SetExecutorForThread(EagerExecutor * executor)430 void EagerContext::SetExecutorForThread(EagerExecutor* executor) {
431   tensorflow::mutex_lock l(executor_map_mu_);
432   if (executor == &default_executor_) {
433     thread_local_executor_.erase(std::this_thread::get_id());
434   } else {
435     auto thread_id = std::this_thread::get_id();
436     thread_local_executor_[thread_id] = executor;
437     auto& executors_with_cleanups = has_cleanup_[thread_id];
438     if (executors_with_cleanups.find(executor) ==
439         executors_with_cleanups.end()) {
440       executors_with_cleanups.insert(executor);
441       // If the executor is deleted before this context, we need to remove it
442       // from the map to avoid attempting to sync it in our destructor.
443       std::function<void()> cleanup([this, thread_id, executor]() {
444         {
445           tensorflow::mutex_lock l(executor_map_mu_);
446           auto existing = thread_local_executor_.find(thread_id);
447           if (existing != thread_local_executor_.end() &&
448               existing->second == executor) {
449             thread_local_executor_.erase(thread_id);
450           }
451           has_cleanup_[thread_id].erase(executor);
452           // Clears the global rendezvous after cleaning up the executor. This
453           // is needed when running in eager op as function mode because it
454           // re-uses the EagerContext's global_rendezvous_for_functions. The
455           // global rendezvous can end up in a bad state if any op ends in a
456           // bad state after execution.
457           if (!GetGlobalRendezvousForFunctionLocalRendezvousStatus().ok()) {
458             VLOG(6) << "global_rendezvous_for_functions_ is in bad state. "
459                        "Resetting.";
460             ResetGlobalRendezvousForFunction();
461           }
462         }
463       });
464       executor->AddCleanup(reinterpret_cast<intptr_t>(this),
465                            std::move(cleanup));
466     }
467   }
468 }
469 
ClearCachesAndThreadExecutors()470 void EagerContext::ClearCachesAndThreadExecutors() {
471   std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
472   {
473     mutex_lock l(executor_map_mu_);
474     executors_copy = thread_local_executor_;
475   }
476   for (const auto& entry : executors_copy) {
477     entry.second->WaitForAllPendingNodes().IgnoreError();
478   }
479   ClearCachesAndDefaultExecutor();
480 }
481 
ClearCachesAndDefaultExecutor()482 void EagerContext::ClearCachesAndDefaultExecutor() {
483   // The executor stores pointers to kernels, so we need to make sure that no
484   // async eager ops are still executing. We lock the cache during this time
485   // as well.
486   mutex_lock ml(cache_mu_);
487   default_executor_.WaitForAllPendingNodes().IgnoreError();
488   kernel_cache_.clear();
489   for (auto& entry : registered_functions_) {
490     entry.second->cached_kernel_keys->clear();
491   }
492   {
493     mutex_lock dl(device_cache_mu_);
494     device_cache_.clear();
495   }
496   {
497     mutex_lock ml(metadata_mu_);
498     step_container_.reset(new ScopedStepContainer(
499         0, [this](const string& name) { ClearResourceContainer(name); }));
500   }
501 }
502 
SetThreadLocalDevicePlacementPolicy(ContextDevicePlacementPolicy policy)503 void EagerContext::SetThreadLocalDevicePlacementPolicy(
504     ContextDevicePlacementPolicy policy) {
505   mutex_lock ml(policy_map_mu_);
506   VLOG(6) << "Setting device placement policy to: " << policy;
507   device_placement_policy_[std::this_thread::get_id()] = policy;
508 }
509 
GetDevicePlacementPolicy() const510 ContextDevicePlacementPolicy EagerContext::GetDevicePlacementPolicy() const {
511   tf_shared_lock l(policy_map_mu_);
512   auto policy_map_it =
513       device_placement_policy_.find(std::this_thread::get_id());
514   if (policy_map_it != device_placement_policy_.end()) {
515     VLOG(6) << "ContextDevicePlacementPolicy: " << policy_map_it->second;
516     return policy_map_it->second;
517   }
518   VLOG(6) << "ContextDevicePlacementPolicy not found; returning default.";
519   return default_device_placement_policy_;
520 }
521 
522 #if !defined(IS_MOBILE_PLATFORM)
GetRemoteContexts()523 std::vector<string> EagerContext::GetRemoteContexts() {
524   tf_shared_lock l(remote_state_mu_);
525   return remote_contexts_;
526 }
527 
IsRemoteContextsEmpty()528 bool EagerContext::IsRemoteContextsEmpty() {
529   tf_shared_lock l(remote_state_mu_);
530   return remote_contexts_.empty();
531 }
532 
CloseAndClearAllRemoteContexts()533 void EagerContext::CloseAndClearAllRemoteContexts() {
534   uint64 context_id;
535   uint64 context_view_id;
536   std::vector<string> remote_contexts_copy;
537   {
538     mutex_lock l(remote_state_mu_);
539     if (!is_master_) return;
540     context_id = context_id_;
541     context_view_id = context_view_id_;
542     context_id_ = kInvalidContextId;
543     // Forget the current view id and reset to the starting value 0.
544     context_view_id_ = 0;
545 
546     // Make a copy of remote targets to avoid holding the lock when sending
547     // close context requests.
548     remote_contexts_copy = remote_contexts_;
549     remote_contexts_.clear();
550   }
551   CloseRemoteContexts(remote_contexts_copy, context_id, context_view_id);
552 }
553 
CloseRemoteContexts(const std::vector<string> & remote_contexts,uint64 context_id,uint64 context_view_id)554 void EagerContext::CloseRemoteContexts(
555     const std::vector<string>& remote_contexts, uint64 context_id,
556     uint64 context_view_id) {
557   // Close all remote contexts.
558   eager::CloseContextRequest request;
559   request.set_context_id(context_id);
560   request.set_context_view_id(context_view_id);
561   // Setting context_id to a new value can avoid us issuing DestroyTensorHandle
562   // request to closed remote workers.
563   std::vector<eager::CloseContextResponse> responses(remote_contexts.size());
564   BlockingCounter counter(static_cast<int>(remote_contexts.size()));
565 
566   int i = 0;
567   for (const auto& worker : remote_contexts) {
568     core::RefCountPtr<eager::EagerClient> client;
569     Status s = GetClient(worker, &client);
570 
571     client->CloseContextAsync(
572         &request, &responses[i],
573         [&worker, &counter, context_id](const Status& s) {
574           if (!s.ok()) {
575             LOG(ERROR) << "Unable to close remote context with ID "
576                        << context_id << " for worker: " << worker << " due to "
577                        << s.error_message();
578           }
579           counter.DecrementCount();
580         });
581     i++;
582   }
583 
584   counter.Wait();
585 }
586 
587 #endif  // !IS_MOBILE_PLATFORM
588 
WaitForAndCloseRemoteContexts()589 void EagerContext::WaitForAndCloseRemoteContexts() {
590   ClearCachesAndThreadExecutors();
591 
592 #if !defined(IS_MOBILE_PLATFORM)
593   {
594     mutex_lock l(keep_alive_thread_shutdown_mu_);
595     shutting_down_ = true;
596     keep_alive_thread_cv_.notify_all();
597   }
598   keep_alive_thread_.reset();
599 
600   if (!IsRemoteContextsEmpty()) {
601     CloseAndClearAllRemoteContexts();
602   }
603 
604   {
605     mutex_lock l(remote_state_mu_);
606 
607     default_executor_.ShutDown().IgnoreError();
608     std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
609     {
610       mutex_lock l(executor_map_mu_);
611       executors_copy = thread_local_executor_;
612     }
613     for (const auto& it : executors_copy) {
614       it.second->ShutDown().IgnoreError();
615     }
616 
617     // This shuts down the completion queue and joins the thread polling it.
618     // The thread exits only after the completion queue has been drained of all
619     // the events. These events' completion should invoke all remaining RPC
620     // callbacks.
621     // This also deletes all EagerClient instances. There should not be any
622     // references to EagerClients left after all RPCs and async ops have been
623     // finished.
624     remote_eager_workers_ = nullptr;
625   }
626 #endif  // !IS_MOBILE_PLATFORM
627 }
628 
~EagerContext()629 EagerContext::~EagerContext() {
630   // TODO(iga): Add a separate API method to shutdown EagerContext so that we
631   // don't send RPCs and block in destructor.
632   WaitForAndCloseRemoteContexts();
633 
634   // Custom devices may have obtained references to various context components
635   // (executors, thread pool). It's safer to run their destructors early.
636   custom_device_op_handler_.Clear();
637 
638   ClearCachesAndThreadExecutors();
639   std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
640   {
641     mutex_lock l(executor_map_mu_);
642     executors_copy = thread_local_executor_;
643   }
644   for (const auto& entry : executors_copy) {
645     // Let the executor know that its cleanup closure is no longer valid.
646     entry.second->RemoveCleanups(reinterpret_cast<intptr_t>(this));
647   }
648   for (auto& entry : registered_functions_) {
649     while (!entry.second->Unref()) {
650       // remove all references.
651     }
652   }
653   registered_functions_.clear();
654 
655 #if !defined(IS_MOBILE_PLATFORM)
656   if (server_) {
657     // TODO(b/136478427): Fix this.
658     LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
659                     "Servers don't support clean shutdown.";
660     // TODO(hanyangtay): Remove this teardown logic once gRPC server clean
661     // shutdown is supported.
662     if (server_->worker_env()->session_mgr != nullptr) {
663       // Tear down coordination service.
664       Status s = server_->StopCoordinationService();
665       if (!s.ok()) {
666         LOG(ERROR) << "Failed to stop coordination service: " << s;
667       }
668     }
669     server_.release();
670   }
671 
672   {
673     mutex_lock l(keep_alive_thread_shutdown_mu_);
674     shutting_down_ = true;
675     keep_alive_thread_cv_.notify_all();
676   }
677   keep_alive_thread_.reset();
678   if (!remote_contexts_.empty()) {
679     CloseAndClearAllRemoteContexts();
680   }
681 
682   // Clean up all the rendezvous instances created via EagerContext.
683   // Currently there are 3 cases in which a rendezvous instances is created:
684   // (1). Created through a rendezvous_creator passed to EagerContext.
685   // (2). Created through rendezvous_mgr.
686   // (3). Created within EagerContext using LocalRendezvousTable.
687   //
688   // Currently case-(3) is taken care of automatically when an EagerContext
689   // instance is deleted. The following code takes care of case-(2). Case-(1)
690   // is tricky as EagerContext does not have a way to access those rendezvous
691   // instances.
692   // TODO (tfrt-dev): Take care of case-(1) mentioned above.
693   if (worker_env_ != nullptr && worker_env_->rendezvous_mgr != nullptr) {
694     worker_env_->rendezvous_mgr->CleanupAll();
695   }
696 #endif  // !IS_MOBILE_PLATFORM
697 
698   if (rendezvous_) {
699     rendezvous_->Unref();
700   }
701   if (resource_deallocator_ != nullptr) {
702     resource_deallocator_();
703   }
704 }
705 
FindFunctionByName(const string & name) const706 bool EagerContext::FindFunctionByName(const string& name) const {
707   return func_lib_def_.Find(name) != nullptr;
708 }
709 
FindFunctionOpData(const string & name,const tensorflow::OpRegistrationData ** op_data)710 Status EagerContext::FindFunctionOpData(
711     const string& name, const tensorflow::OpRegistrationData** op_data) {
712   return func_lib_def_.LookUp(name, op_data);
713 }
714 
FindFunctionDef(const string & name) const715 const FunctionDef* EagerContext::FindFunctionDef(const string& name) const {
716   return func_lib_def_.Find(name);
717 }
718 
ExportRunMetadata()719 std::unique_ptr<RunMetadata> EagerContext::ExportRunMetadata() {
720   mutex_lock ml(metadata_mu_);
721   auto result = std::make_unique<RunMetadata>();
722   run_metadata_.swap(result);
723   return result;
724 }
725 
UsesTFRT()726 bool EagerContext::UsesTFRT() { return false; }
727 
RunEagerOpAsFunction() const728 bool EagerContext::RunEagerOpAsFunction() const {
729   VLOG(3) << "RunEagerOpAsFunction: " << run_eager_op_as_function_;
730   return run_eager_op_as_function_;
731 }
732 
SetRunEagerOpAsFunction(bool enable)733 void EagerContext::SetRunEagerOpAsFunction(bool enable) {
734   run_eager_op_as_function_ = enable;
735 }
736 
JitCompileRewrite() const737 bool EagerContext::JitCompileRewrite() const {
738   VLOG(3) << "JitCompileRewrite: " << jit_compile_rewrite_;
739   return jit_compile_rewrite_;
740 }
741 
SetJitCompileRewrite(bool enable)742 void EagerContext::SetJitCompileRewrite(bool enable) {
743   jit_compile_rewrite_ = enable;
744 }
745 
ListDevices(std::vector<tensorflow::DeviceAttributes> * device_attributes)746 void EagerContext::ListDevices(
747     std::vector<tensorflow::DeviceAttributes>* device_attributes) {
748   std::vector<Device*> devices = ListAllTfDevices();
749   device_attributes->reserve(devices.size());
750   for (const auto& dev : devices) {
751     device_attributes->emplace_back(dev->attributes());
752   }
753 }
754 
ListAllTfDevices()755 std::vector<Device*> EagerContext::ListAllTfDevices() {
756   // Since remote_device_mgr may also contain local devices, make sure no
757   // duplicated device is returned.
758   std::vector<Device*> devices;
759   std::unordered_set<string> dev_names;
760 
761   if (local_device_mgr()) {
762     for (const auto& dev : local_device_mgr()->ListDevices()) {
763       devices.emplace_back(dev);
764       dev_names.emplace(dev->attributes().name());
765     }
766   }
767 
768   // TODO (b/197281777): Include local devices in remote_device_mgr on the
769   // client-side in single-client deployment.
770   if (remote_device_mgr()) {
771     for (const auto& dev : remote_device_mgr()->ListDevices()) {
772       Device* device = nullptr;
773       if (local_device_mgr()->LookupDevice(dev->name(), &device) !=
774           OkStatus()) {
775         // Include this device from remote_device_mgr only if it does not exist
776         // in local_device_mgr.
777         devices.emplace_back(dev);
778       }
779     }
780   }
781 
782   return devices;
783 }
784 
AddDevices(std::vector<std::unique_ptr<Device>> devices)785 Status EagerContext::AddDevices(std::vector<std::unique_ptr<Device>> devices) {
786   std::vector<std::unique_ptr<Device>> local_devices, remote_devices;
787   while (!devices.empty()) {
788     if (devices.front()->IsLocal()) {
789       local_devices.push_back(std::move(devices.front()));
790     } else {
791       remote_devices.push_back(std::move(devices.front()));
792     }
793     devices.erase(devices.begin());
794   }
795   TF_RETURN_IF_ERROR(
796       reinterpret_cast<DynamicDeviceMgr*>(local_device_manager_.Get())
797           ->AddDevices(std::move(local_devices)));
798 
799   if (!remote_devices.empty()) {
800     if (!remote_device_mgr()) {
801       remote_device_manager_.Reset(
802           std::make_unique<tensorflow::DynamicDeviceMgr>());
803     }
804     TF_RETURN_IF_ERROR(
805         reinterpret_cast<DynamicDeviceMgr*>(remote_device_manager_.Get())
806             ->AddDevices(std::move(remote_devices)));
807   }
808 
809   // Add the devices to pflr's device set.
810   pflr_->InitializeDeviceAndFlr();
811   InitPrioritizedDeviceTypeList();
812   return OkStatus();
813 }
814 
StartStep()815 void EagerContext::StartStep() {
816   mutex_lock ml(metadata_mu_);
817   num_active_steps_++;
818 }
819 
EndStep()820 void EagerContext::EndStep() {
821   mutex_lock ml(metadata_mu_);
822   num_active_steps_--;
823   if (num_active_steps_ == 0) {
824     // TODO(b/139809335): This does not properly clean up remote resources
825     // Clean up the previous step container and create a new one.
826     step_container_.reset(new ScopedStepContainer(
827         0, [this](const string& name) { ClearResourceContainer(name); }));
828   }
829 }
830 
StepContainer()831 ScopedStepContainer* EagerContext::StepContainer() {
832   mutex_lock ml(metadata_mu_);
833   return step_container_.get();
834 }
835 
MaybeRegisterFunctionRemotely(const FunctionDef & fdef)836 Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) {
837   // Only client context can register function on remote worker context.
838   if (!remote_device_manager_.Owned()) return OkStatus();
839 #if !defined(IS_MOBILE_PLATFORM)
840   std::shared_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
841   request->set_context_id(GetContextId());
842 
843   eager::RegisterFunctionOp* register_function =
844       request->add_queue()->mutable_register_function();
845   *register_function->mutable_function_def() = fdef;
846   StripDefaultAttributes(
847       *OpRegistry::Global(),
848       register_function->mutable_function_def()->mutable_node_def());
849 
850   auto remote_contexts = GetRemoteContexts();
851   for (const auto& target : remote_contexts) {
852     core::RefCountPtr<eager::EagerClient> eager_client;
853     TF_RETURN_IF_ERROR(GetClient(target, &eager_client));
854 
855     eager::EnqueueResponse* response = new eager::EnqueueResponse();
856     eager_client->StreamingEnqueueAsync(
857         this->Executor().StreamingEnqueue(),
858         /*call_opts=*/nullptr, request.get(), response,
859         [request, response](const Status& status) {
860           if (!status.ok()) {
861             LOG(ERROR) << "Failed to register function remotely due to "
862                        << status.error_message()
863                        << "\nThis could happen if the remote target has been "
864                           "disconnected from the client.";
865           }
866           delete response;
867         });
868   }
869 #endif  // !IS_MOBILE_PLATFORM
870   return OkStatus();
871 }
872 
RegisterExistingFunctionsOnRemoteWorkers(const std::vector<string> & remote_workers)873 Status EagerContext::RegisterExistingFunctionsOnRemoteWorkers(
874     const std::vector<string>& remote_workers) {
875 #if !defined(IS_MOBILE_PLATFORM)
876   // Register multiple functions on selected remote workers.
877   uint64 context_id = GetContextId();
878   FunctionDefLibrary function_defs = func_lib_def_.ToProto();
879   std::vector<std::shared_ptr<eager::EnqueueRequest>> requests(
880       function_defs.function_size());
881   for (int i = 0; i < function_defs.function_size(); i++) {
882     requests[i] = std::make_shared<eager::EnqueueRequest>();
883     requests[i]->set_context_id(context_id);
884     eager::RegisterFunctionOp* register_function =
885         requests[i]->add_queue()->mutable_register_function();
886     *register_function->mutable_function_def() =
887         std::move(*function_defs.mutable_function(i));
888     StripDefaultAttributes(
889         *OpRegistry::Global(),
890         register_function->mutable_function_def()->mutable_node_def());
891   }
892 
893   for (auto& remote_worker : remote_workers) {
894     core::RefCountPtr<eager::EagerClient> eager_client;
895     Status s = GetClient(remote_worker, &eager_client);
896     if (!s.ok()) {
897       continue;
898     }
899     for (int i = 0; i < requests.size(); i++) {
900       auto response = std::make_shared<eager::EnqueueResponse>();
901       eager_client->StreamingEnqueueAsync(
902           this->Executor().StreamingEnqueue(),
903           /*call_opts=*/nullptr, requests[i].get(), response.get(),
904           [request = requests[i], response](const Status& s) {
905             if (!s.ok()) {
906               LOG(ERROR) << "Failed to register function remotely due to "
907                          << s.error_message()
908                          << "\nThis could happen if the remote target has been "
909                             "disconnected from the client.";
910             }
911           });
912     }
913   }
914 #endif  // !IS_MOBILE_PLATFORM
915   return OkStatus();
916 }
917 
AddFunctionDefWithStackTraces(const FunctionDef & fdef,const StackTracesMap & stack_traces)918 Status EagerContext::AddFunctionDefWithStackTraces(
919     const FunctionDef& fdef, const StackTracesMap& stack_traces) {
920   return AddFunctionDef(fdef, FunctionDefLibrary(),
921                         /* add_to_local_only=*/false, stack_traces);
922 }
923 
AddFunctionDef(const FunctionDef & fdef)924 Status EagerContext::AddFunctionDef(const FunctionDef& fdef) {
925   return AddFunctionDef(fdef, FunctionDefLibrary(),
926                         /* add_to_local_only=*/false);
927 }
928 
AddFunctionDef(const FunctionDef & fdef,const FunctionDefLibrary & library,const bool add_to_local_only,const StackTracesMap & stack_traces)929 Status EagerContext::AddFunctionDef(const FunctionDef& fdef,
930                                     const FunctionDefLibrary& library,
931                                     const bool add_to_local_only,
932                                     const StackTracesMap& stack_traces) {
933   bool is_first_ref = false;
934   {
935     mutex_lock l(cache_mu_);
936     auto* registered_function =
937         gtl::FindPtrOrNull(registered_functions_, fdef.signature().name());
938     if (registered_function == nullptr) {
939       registered_function = new RegisteredFunction;
940       registered_function->cached_kernel_keys =
941           std::make_unique<std::vector<Fprint128>>();
942       gtl::InsertOrUpdate(&registered_functions_, fdef.signature().name(),
943                           registered_function);
944     } else {
945       // The function has been registered before. If the function is the same,
946       // then we take a Ref() otherwise we error out.
947       const FunctionDef* prev_fdef =
948           func_lib_def_.Find(fdef.signature().name());
949       if (prev_fdef == nullptr) {
950         return errors::Internal("Function: ", fdef.signature().name(),
951                                 " is in the cache but not in the library");
952       }
953       if (!FunctionDefsEqual(fdef, *prev_fdef)) {
954         return errors::InvalidArgument(
955             "Attempting to add a duplicate function with name: ",
956             fdef.signature().name(), " where the previous and current ",
957             "definitions differ. Previous definition: ",
958             prev_fdef->DebugString(),
959             " and current definition: ", fdef.DebugString());
960       }
961       registered_function->Ref();
962     }
963     is_first_ref = registered_function->RefCountIsOne();
964     if (is_first_ref) {
965       TF_RETURN_IF_ERROR(func_lib_def_.AddFunctionDef(fdef, stack_traces));
966       TF_RETURN_IF_ERROR(func_lib_def_.AddLibrary(library));
967     }
968   }
969   if (is_first_ref && !add_to_local_only) {
970     return MaybeRegisterFunctionRemotely(fdef);
971   }
972   return OkStatus();
973 }
974 
GetFunctionDef(const string & function_name)975 const FunctionDef* EagerContext::GetFunctionDef(const string& function_name) {
976   return func_lib_def_.Find(function_name);
977 }
978 
ListFunctionNames()979 std::vector<string> EagerContext::ListFunctionNames() {
980   return func_lib_def_.ListFunctionNames();
981 }
982 
RemoveFunction(const string & func)983 Status EagerContext::RemoveFunction(const string& func) {
984   // TODO(mdan): The context owns these functions. Why check refcount then?
985   mutex_lock l(cache_mu_);
986   auto* registered_function = gtl::FindPtrOrNull(registered_functions_, func);
987   if (registered_function == nullptr) {
988     return errors::InvalidArgument("Tried to remove non-existent function '",
989                                    func, "'.");
990   }
991   bool is_last_ref = registered_function->RefCountIsOne();
992   if (is_last_ref) {
993     for (auto& key : *registered_function->cached_kernel_keys) {
994       kernel_cache_.erase(key);
995     }
996     registered_functions_.erase(func);
997   }
998   registered_function->Unref();
999   if (is_last_ref) {
1000     // TODO(fishx): Remove remote function as well.
1001     return func_lib_def_.RemoveFunction(func);
1002   }
1003   return OkStatus();
1004 }
1005 
SyncExecutors()1006 Status EagerContext::SyncExecutors() {
1007   VLOG(6) << "Calling SyncExecutors";
1008   StatusGroup sg;
1009   // Synchronize on context default executor
1010   sg.Update(default_executor_.WaitForAllPendingNodes());
1011   default_executor_.ClearError();
1012 
1013   // Synchronize thread local executors on client
1014   std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
1015   {
1016     mutex_lock l(executor_map_mu_);
1017     executors_copy = thread_local_executor_;
1018   }
1019   for (const auto& entry : executors_copy) {
1020     sg.Update(entry.second->WaitForAllPendingNodes());
1021     entry.second->ClearError();
1022   }
1023 
1024 #if !defined(IS_MOBILE_PLATFORM)
1025   auto remote_contexts = GetRemoteContexts();
1026   // Synchronize executors on remote workers
1027   eager::EnqueueRequest request;
1028   request.set_context_id(GetContextId());
1029   request.add_queue()->mutable_sync_remote_executor_for_stream();
1030   BlockingCounter counter(static_cast<int>(remote_contexts.size()));
1031   std::vector<Status> statuses(remote_contexts.size());
1032 
1033   for (int i = 0; i < remote_contexts.size(); i++) {
1034     const auto& target = remote_contexts[i];
1035     core::RefCountPtr<eager::EagerClient> eager_client;
1036     TF_RETURN_IF_ERROR(GetClient(target, &eager_client));
1037 
1038     eager::EnqueueResponse* response = new eager::EnqueueResponse();
1039     eager_client->StreamingEnqueueAsync(
1040         this->Executor().StreamingEnqueue(),
1041         /*call_opts=*/nullptr, &request, response,
1042         [response, target, &counter, &s = statuses[i]](const Status& status) {
1043           s = status;
1044           delete response;
1045           counter.DecrementCount();
1046         });
1047   }
1048   counter.Wait();
1049   for (const Status& s : statuses) {
1050     sg.Update(s);
1051   }
1052 #endif  // !IS_MOBILE_PLATFORM
1053 
1054   // Reset the global rendezvous, which otherwise stores a failure state.
1055   ResetGlobalRendezvousForFunction();
1056 
1057   return sg.as_summary_status();
1058 }
1059 
GetCachedKernel(Fprint128 cache_key)1060 core::RefCountPtr<KernelAndDevice> EagerContext::GetCachedKernel(
1061     Fprint128 cache_key) {
1062   tf_shared_lock l(cache_mu_);
1063   auto iter = kernel_cache_.find(cache_key);
1064   if (iter == kernel_cache_.end()) {
1065     return nullptr;
1066   }
1067   core::RefCountPtr<KernelAndDevice> new_ref(iter->second.get());
1068   new_ref->Ref();
1069   return new_ref;
1070 }
1071 
GetCachedDevice(Fprint128 device_cache_key)1072 Device* EagerContext::GetCachedDevice(Fprint128 device_cache_key) {
1073   tf_shared_lock l(device_cache_mu_);
1074   auto iter = device_cache_.find(device_cache_key);
1075   if (iter == device_cache_.end()) return nullptr;
1076   return iter->second;
1077 }
1078 
AddKernelToCache(Fprint128 cache_key,KernelAndDevice * kernel)1079 void EagerContext::AddKernelToCache(Fprint128 cache_key,
1080                                     KernelAndDevice* kernel) {
1081   mutex_lock ml(cache_mu_);
1082   core::RefCountPtr<KernelAndDevice> new_ref(kernel);
1083   new_ref->Ref();
1084   kernel_cache_[cache_key] = std::move(new_ref);
1085   auto* registered_function =
1086       gtl::FindPtrOrNull(registered_functions_, kernel->name());
1087   // The kernel name can be either a primitive op or a function.
1088   if (registered_function != nullptr) {
1089     registered_function->cached_kernel_keys->emplace_back(cache_key);
1090   }
1091 }
1092 
AddDeviceToCache(Fprint128 device_cache_key,Device * device)1093 void EagerContext::AddDeviceToCache(Fprint128 device_cache_key,
1094                                     Device* device) {
1095   mutex_lock l(device_cache_mu_);
1096   device_cache_[device_cache_key] = device;
1097 }
1098 
ShouldStoreGraphs()1099 bool EagerContext::ShouldStoreGraphs() { return should_store_graphs_.load(); }
1100 
SetShouldStoreGraphs(bool value)1101 void EagerContext::SetShouldStoreGraphs(bool value) {
1102   mutex_lock ml(metadata_mu_);
1103   should_store_graphs_.store(value);
1104   if (!value) {
1105     run_metadata_.reset(new RunMetadata);
1106   }
1107 }
1108 
FindDeviceFromName(const char * device_name,Device ** device) const1109 Status EagerContext::FindDeviceFromName(const char* device_name,
1110                                         Device** device) const {
1111   *device = HostCPU();
1112   if (device_name == nullptr || strlen(device_name) == 0) {
1113     return OkStatus();
1114   }
1115 
1116   auto status = local_device_mgr()->LookupDevice(device_name, device);
1117   if (status.ok()) {
1118     return status;
1119   }
1120 
1121   if (remote_device_mgr() != nullptr) {
1122     return remote_device_mgr()->LookupDevice(device_name, device);
1123   }
1124 
1125   return status;
1126 }
1127 
FindCompositeDeviceFromName(StringPiece device_name,CompositeDevice ** device) const1128 Status EagerContext::FindCompositeDeviceFromName(
1129     StringPiece device_name, CompositeDevice** device) const {
1130   tf_shared_lock l(composite_devices_mu_);
1131   for (const auto& d : composite_devices_) {
1132     if (d.second->name() == device_name) {
1133       *device = d.second.get();
1134       return OkStatus();
1135     }
1136   }
1137   return errors::NotFound("Unknown composite device: ", device_name);
1138 }
1139 
RegisterCustomDevice(const string & device_name,std::unique_ptr<CustomDevice> device)1140 Status EagerContext::RegisterCustomDevice(
1141     const string& device_name, std::unique_ptr<CustomDevice> device) {
1142   Device* existing_physical_device = nullptr;
1143   if (FindDeviceFromName(device_name.c_str(), &existing_physical_device).ok()) {
1144     return errors::AlreadyExists(device_name,
1145                                  " already registered as a physical device.");
1146   }
1147   return custom_device_op_handler_.RegisterCustomDevice(device_name,
1148                                                         std::move(device));
1149 }
1150 
FindOrCreateCompositeDevice(const std::vector<string> & underlying_devices,const string & device_name,CompositeDevice ** composite_device)1151 Status EagerContext::FindOrCreateCompositeDevice(
1152     const std::vector<string>& underlying_devices, const string& device_name,
1153     CompositeDevice** composite_device) {
1154   if (!device_name.empty() &&
1155       FindCompositeDeviceFromName(device_name, composite_device).ok()) {
1156     return OkStatus();
1157   }
1158 
1159   const uint64 hash_key = Fingerprint64(absl::StrJoin(underlying_devices, ","));
1160 
1161   mutex_lock l(composite_devices_mu_);
1162   auto iter = composite_devices_.find(hash_key);
1163   if (iter != composite_devices_.end()) {
1164     *composite_device = iter->second.get();
1165     return OkStatus();
1166   }
1167 
1168   Status s;
1169   std::unique_ptr<CompositeDevice> device;
1170   if (device_name.empty()) {
1171     // Create a CompositeDevice on the same task as the host CPU, in order to
1172     // trigger packed TensorHandle copy from a client to a remote worker.
1173     device = CompositeDevice::MakeDevice(underlying_devices,
1174                                          composite_devices_.size(),
1175                                          HostCPU()->parsed_name(), &s);
1176   } else {
1177     device = CompositeDevice::MakeDevice(underlying_devices, device_name, &s);
1178   }
1179   TF_RETURN_IF_ERROR(s);
1180   *composite_device = device.get();
1181   pflr_->AddCompositeDevice(*composite_device);
1182   composite_devices_.emplace(hash_key, std::move(device));
1183   return OkStatus();
1184 }
1185 
OnSameTask(const Device * first,const Device * second) const1186 bool EagerContext::OnSameTask(const Device* first, const Device* second) const {
1187   if (first == nullptr) first = HostCPU();
1188   if (second == nullptr) second = HostCPU();
1189   return first->parsed_name().job == second->parsed_name().job &&
1190          first->parsed_name().replica == second->parsed_name().replica &&
1191          first->parsed_name().task == second->parsed_name().task;
1192 }
1193 
1194 // Gets the CPU device on the task of device.
CPUDeviceOnTask(const Device * device,Device ** cpu_device) const1195 Status EagerContext::CPUDeviceOnTask(const Device* device,
1196                                      Device** cpu_device) const {
1197   string cpu_device_name;
1198   TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
1199       device->name(), &cpu_device_name));
1200 
1201   return FindDeviceFromName(cpu_device_name.c_str(), cpu_device);
1202 }
1203 
ClearResourceContainer(const string & name)1204 void EagerContext::ClearResourceContainer(const string& name) {
1205   // TODO(b/139809335): This does not properly clean up remote resources
1206   auto local_devices = local_device_mgr()->ListDevices();
1207   for (Device* device : local_devices) {
1208     // Only ignore container not found errors.
1209     device->resource_manager()->Cleanup(name).IgnoreError();
1210   }
1211 }
1212 
GetGlobalRendezvousForFunctionLocalRendezvousStatus()1213 Status EagerContext::GetGlobalRendezvousForFunctionLocalRendezvousStatus() {
1214   mutex_lock l(global_rendezvous_mu_);
1215   IntraProcessRendezvous* rendezvous =
1216       local_rendezvous_table_->Find(kGlobalRendezvousId);
1217   if (rendezvous == nullptr) return OkStatus();
1218   Status s = rendezvous->GetLocalRendezvousStatus();
1219   rendezvous->Unref();
1220   return s;
1221 }
1222 
UpdateGlobalRendezvousDeviceManager(tensorflow::DeviceMgr * device_mgr)1223 void EagerContext::UpdateGlobalRendezvousDeviceManager(
1224     tensorflow::DeviceMgr* device_mgr) {
1225   mutex_lock l(global_rendezvous_mu_);
1226   IntraProcessRendezvous* rendezvous =
1227       local_rendezvous_table_->Find(kGlobalRendezvousId);
1228   if (rendezvous == nullptr) return;
1229   rendezvous->UpdateDeviceManager(device_mgr);
1230   rendezvous->Unref();
1231 }
1232 
1233 namespace {
GetTaskName(Device * d,string * task_name)1234 Status GetTaskName(Device* d, string* task_name) {
1235   string ignored;
1236   if (!DeviceNameUtils::SplitDeviceName(d->name(), task_name, &ignored)) {
1237     return errors::InvalidArgument("Unable to parse device name: ", d->name());
1238   }
1239 
1240   return OkStatus();
1241 }
1242 }  // namespace
1243 
1244 #if !defined(IS_MOBILE_PLATFORM)
GetClient(Device * device,core::RefCountPtr<eager::EagerClient> * client)1245 Status EagerContext::GetClient(Device* device,
1246                                core::RefCountPtr<eager::EagerClient>* client) {
1247   return GetClient(device->parsed_name(), client);
1248 }
1249 
GetClient(const DeviceNameUtils::ParsedName & device_name,core::RefCountPtr<eager::EagerClient> * client)1250 Status EagerContext::GetClient(const DeviceNameUtils::ParsedName& device_name,
1251                                core::RefCountPtr<eager::EagerClient>* client) {
1252   string device_task_name;
1253   if (!DeviceNameUtils::GetTaskName(device_name, &device_task_name)) {
1254     return errors::InvalidArgument(
1255         "Task is not fully specified in device name: ",
1256         DeviceNameUtils::ParsedNameToString(device_name));
1257   }
1258 
1259   {
1260     tf_shared_lock l(remote_state_mu_);
1261     if (remote_eager_workers_ == nullptr) {
1262       return errors::Internal(
1263           "Haven't set up remote eager worker in this eager context yet.");
1264     }
1265     TF_RETURN_IF_ERROR(
1266         remote_eager_workers_->GetClient(device_task_name, client));
1267 
1268     if (*client == nullptr) {
1269       return errors::InvalidArgument(
1270           "Unable to find eager client corresponding to device ",
1271           DeviceNameUtils::ParsedNameToString(device_name));
1272     }
1273     if (std::find(remote_contexts_.begin(), remote_contexts_.end(),
1274                   device_task_name) == remote_contexts_.end()) {
1275       return errors::Internal("Unable to find a context for handle on task: ",
1276                               device_task_name, ". This should not happen.");
1277     }
1278   }
1279 
1280   return OkStatus();
1281 }
1282 
GetClient(const string & remote_task,core::RefCountPtr<eager::EagerClient> * client)1283 Status EagerContext::GetClient(const string& remote_task,
1284                                core::RefCountPtr<eager::EagerClient>* client) {
1285   {
1286     tf_shared_lock l(remote_state_mu_);
1287     if (remote_eager_workers_ == nullptr) {
1288       return errors::Internal(
1289           "Haven't set up remote eager worker in this eager context yet.");
1290     }
1291     TF_RETURN_IF_ERROR(remote_eager_workers_->GetClient(remote_task, client));
1292   }
1293 
1294   if (*client == nullptr) {
1295     return errors::InvalidArgument(
1296         "Unable to find eager client corresponding to target ", remote_task);
1297   }
1298   return OkStatus();
1299 }
1300 
GetContextId() const1301 uint64 EagerContext::GetContextId() const {
1302   tf_shared_lock l(remote_state_mu_);
1303   return context_id_;
1304 }
1305 
GetContextViewId() const1306 uint64 EagerContext::GetContextViewId() const {
1307   tf_shared_lock l(remote_state_mu_);
1308   return context_view_id_;
1309 }
1310 
IncrementContextViewId()1311 void EagerContext::IncrementContextViewId() {
1312   mutex_lock l(remote_state_mu_);
1313   context_view_id_ += 1;
1314 }
1315 
EnableCollectiveOps(const ServerDef & server_def)1316 Status EagerContext::EnableCollectiveOps(const ServerDef& server_def) {
1317   return distributed_manager_->EnableCollectiveOps(server_def);
1318 }
1319 
1320 // Set collective ops related state in the context. Passing nullptr to
1321 // `new_server` will reuse the existing GRPC server in context.
StoreCollectiveOpsServer(std::unique_ptr<ServerInterface> new_server,DeviceMgr * device_mgr,CollectiveExecutorMgrInterface * rpc_collective_executor_mgr)1322 Status EagerContext::StoreCollectiveOpsServer(
1323     std::unique_ptr<ServerInterface> new_server, DeviceMgr* device_mgr,
1324     CollectiveExecutorMgrInterface* rpc_collective_executor_mgr) {
1325   collective_executor_mgr_.Reset(rpc_collective_executor_mgr);
1326 
1327   if (device_mgr != local_device_manager_.Get()) {
1328     if (local_device_manager_.Owned()) {
1329       old_local_device_managers_.push_back(
1330           std::move(local_device_manager_.owned_object));
1331     }
1332     local_device_manager_.Reset(device_mgr);
1333     UpdateGlobalRendezvousDeviceManager(local_device_manager_.Get());
1334     if (rendezvous_ != nullptr) rendezvous_->Unref();
1335     rendezvous_ = CreateRendezvous(-1);
1336   }
1337   host_cpu_device_ = local_device_manager_.Get()->HostCPU();
1338 
1339   InitPrioritizedDeviceTypeList();
1340   ClearCachesAndThreadExecutors();
1341   default_executor_.ClearError();
1342   {
1343     tensorflow::mutex_lock l(executor_map_mu_);
1344     for (auto& entry : thread_local_executor_) {
1345       entry.second->ClearError();
1346     }
1347   }
1348 
1349   const ConfigProto* config = pflr_ ? pflr_->config() : nullptr;
1350   ResetPFLR(
1351       local_device_manager_.Get(), env_, /*config=*/config,
1352       TF_GRAPH_DEF_VERSION, &func_lib_def_,
1353       /*optimizer_options=*/
1354       config ? config->graph_options().optimizer_options() : OptimizerOptions(),
1355       thread_pool_.get());
1356 
1357   if (new_server != nullptr) {
1358     // Memory leak!
1359     if (server_ != nullptr) {
1360       LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
1361                       "Servers don't support clean shutdown.";
1362       server_.release();
1363     }
1364     server_ = std::move(new_server);
1365   }
1366   DCHECK(server_ != nullptr);
1367 
1368   return OkStatus();
1369 }
1370 
SetRemoteDeviceFilters(const string & remote_worker,const std::vector<string> & device_filters)1371 Status EagerContext::SetRemoteDeviceFilters(
1372     const string& remote_worker, const std::vector<string>& device_filters) {
1373   // Get fully specified task name for remote worker
1374   string remote_worker_task_name;
1375   DeviceNameUtils::ParsedName pw;
1376   if (!DeviceNameUtils::ParseFullName(remote_worker, &pw)) {
1377     return tensorflow::errors::InvalidArgument(
1378         "Remote worker task name is invalid ", remote_worker);
1379   }
1380   // Force set a replica as the key in cluster device filters map. I.e., if the
1381   // remote worker is `/job:worker/task:0` it then becomes
1382   // `/job:worker/replica:0/task:0`.
1383   pw.has_replica = true;
1384   if (!DeviceNameUtils::GetTaskName(pw, &remote_worker_task_name)) {
1385     return tensorflow::errors::InvalidArgument(
1386         "Job name and task index must be specified for worker ", remote_worker);
1387   }
1388 
1389   std::vector<DeviceNameUtils::ParsedName> parsed_filters;
1390   for (auto& filter : device_filters) {
1391     DeviceNameUtils::ParsedName parsed_filter;
1392     if (DeviceNameUtils::ParseFullName(filter, &parsed_filter)) {
1393       parsed_filters.emplace_back(parsed_filter);
1394     } else {
1395       return tensorflow::errors::InvalidArgument("Invalid filter: ", filter);
1396     }
1397   }
1398 
1399   if (VLOG_IS_ON(1)) {
1400     VLOG(1) << "Setting device filters for " << remote_worker << ":";
1401     for (auto& filter : device_filters) {
1402       VLOG(1) << "  " << filter;
1403     }
1404   }
1405   mutex_lock l(remote_state_mu_);
1406   cluster_device_filters_.emplace(remote_worker_task_name, parsed_filters);
1407   return OkStatus();
1408 }
1409 
FilterDevicesForRemoteWorkers(const string & remote_worker,const protobuf::RepeatedPtrField<DeviceAttributes> & device_attrs,std::vector<bool> * filtered_device_mask)1410 void EagerContext::FilterDevicesForRemoteWorkers(
1411     const string& remote_worker,
1412     const protobuf::RepeatedPtrField<DeviceAttributes>& device_attrs,
1413     std::vector<bool>* filtered_device_mask) {
1414   filtered_device_mask->resize(device_attrs.size());
1415   std::fill(filtered_device_mask->begin(), filtered_device_mask->end(), false);
1416 
1417   tf_shared_lock l(remote_state_mu_);
1418   auto it = cluster_device_filters_.find(remote_worker);
1419   // If no filters were specified, all devices should be visible to the worker
1420   if (it == cluster_device_filters_.end() || it->second.empty()) {
1421     std::fill(filtered_device_mask->begin(), filtered_device_mask->end(), true);
1422     return;
1423   }
1424 
1425   const std::vector<DeviceNameUtils::ParsedName>& parsed_filters = it->second;
1426   DeviceNameUtils::ParsedName parsed_remote_worker;
1427   DeviceNameUtils::ParseFullName(remote_worker, &parsed_remote_worker);
1428   for (int i = 0; i < device_attrs.size(); i++) {
1429     DeviceNameUtils::ParsedName pn;
1430     DeviceNameUtils::ParseFullName(device_attrs[i].name(), &pn);
1431     if (DeviceNameUtils::IsSameAddressSpace(parsed_remote_worker, pn)) {
1432       // If this device is on the remote worker itself, it should be visible
1433       // regardless of device filters
1434       filtered_device_mask->at(i) = true;
1435       continue;
1436     }
1437     for (const auto& pf : parsed_filters) {
1438       if ((!pn.has_job || !pf.has_job || pn.job == pf.job) &&
1439           (!pn.has_replica || !pf.has_replica || pn.replica == pf.replica) &&
1440           (!pn.has_task || !pf.has_task || pn.task == pf.task) &&
1441           (!pn.has_type || !pf.has_type || pn.type == pf.type) &&
1442           (!pn.has_id || !pf.has_id || pn.id == pf.id)) {
1443         // Found a match, make it visible, stop processing more device filters
1444         filtered_device_mask->at(i) = true;
1445         break;
1446       }
1447     }
1448   }
1449 }
1450 
SetWorkerEnv(WorkerEnv * worker_env,std::shared_ptr<WorkerSession> worker_session)1451 void EagerContext::SetWorkerEnv(WorkerEnv* worker_env,
1452                                 std::shared_ptr<WorkerSession> worker_session) {
1453   worker_env_ = worker_env;
1454   worker_session_ = worker_session;
1455 }
1456 
InitializeRemoteMaster(std::unique_ptr<ServerInterface> server,WorkerEnv * worker_env,std::shared_ptr<WorkerSession> worker_session,std::unique_ptr<eager::EagerClientCache> remote_eager_workers,std::unique_ptr<DynamicDeviceMgr> remote_device_manager,const std::vector<string> & remote_contexts,uint64 context_id,Rendezvous * r,DeviceMgr * local_device_mgr,int keep_alive_secs,DistributedFunctionLibraryRuntime * cluster_flr,std::unique_ptr<eager::RemoteMgr,std::function<void (eager::RemoteMgr *)>> remote_mgr)1457 Status EagerContext::InitializeRemoteMaster(
1458     std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
1459     std::shared_ptr<WorkerSession> worker_session,
1460     std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
1461     std::unique_ptr<DynamicDeviceMgr> remote_device_manager,
1462     const std::vector<string>& remote_contexts, uint64 context_id,
1463     Rendezvous* r, DeviceMgr* local_device_mgr, int keep_alive_secs,
1464     DistributedFunctionLibraryRuntime* cluster_flr,
1465     std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
1466         remote_mgr) {
1467   if (context_id == kInvalidContextId) {
1468     return errors::InvalidArgument(
1469         "Failed to initialize remote for master context due to invalid ",
1470         "context id");
1471   }
1472 
1473   if (!IsRemoteContextsEmpty()) {
1474     CloseAndClearAllRemoteContexts();
1475   }
1476   {
1477     mutex_lock l(remote_state_mu_);
1478     remote_contexts_ = remote_contexts;
1479   }
1480 
1481   return SetMasterContextState(
1482       std::move(server), worker_env, std::move(worker_session),
1483       std::move(remote_eager_workers), std::move(remote_device_manager),
1484       context_id, 0, r, local_device_mgr, keep_alive_secs, cluster_flr,
1485       std::move(remote_mgr));
1486 }
1487 
UpdateRemoteMaster(uint64 context_id,std::unique_ptr<eager::EagerClientCache> remote_eager_workers,const std::vector<string> & add_remote_contexts,const std::vector<string> & remove_remote_contexts)1488 Status EagerContext::UpdateRemoteMaster(
1489     uint64 context_id,
1490     std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
1491     const std::vector<string>& add_remote_contexts,
1492     const std::vector<string>& remove_remote_contexts) {
1493   {
1494     tf_shared_lock l(remote_state_mu_);
1495     if (context_id != context_id_) {
1496       return errors::InvalidArgument(
1497           "Failed to update remote master context due to invalid context id. ",
1498           "Request id = ", context_id, " but current id = ", context_id_);
1499     }
1500   }
1501 
1502   if (!remove_remote_contexts.empty()) {
1503     // N.B. remove_remote_contexts include both removed and replaced workers.
1504     // In the case where a worker is replaced by one that resolves to the same
1505     // `hostname:port`, it is safe to close context with the current view id,
1506     // since the newly created context on the remote worker will be holding
1507     // a larger view id and ignores this request.
1508     CloseRemoteContexts(remove_remote_contexts, context_id, GetContextViewId());
1509     mutex_lock l(remote_state_mu_);
1510     for (const string& remote_context : remove_remote_contexts) {
1511       remote_contexts_.erase(
1512           std::remove(remote_contexts_.begin(), remote_contexts_.end(),
1513                       remote_context),
1514           remote_contexts_.end());
1515     }
1516   }
1517   if (!add_remote_contexts.empty()) {
1518     mutex_lock l(remote_state_mu_);
1519     remote_contexts_.insert(std::end(remote_contexts_),
1520                             std::begin(add_remote_contexts),
1521                             std::end(add_remote_contexts));
1522   }
1523 
1524   {
1525     mutex_lock l(remote_state_mu_);
1526     context_view_id_++;
1527 
1528     remote_eager_workers_ = std::move(remote_eager_workers);
1529     pflr_->InitializeDeviceAndFlr();
1530     InitPrioritizedDeviceTypeList();
1531 
1532     default_executor_.ClearError();
1533     {
1534       tensorflow::mutex_lock l(executor_map_mu_);
1535       for (auto& entry : thread_local_executor_) {
1536         entry.second->ClearError();
1537       }
1538     }
1539   }
1540 
1541   // Register existing functions to the newly added remote workers. Note that
1542   // this should happen only after updating `remote_contexts_` because new
1543   // functions might be registered while we update the context. When that
1544   // happens, this ordering ensures that `MaybeRegisterFunctionRemotely` will
1545   // register the new functions on all remote workers (including the newly added
1546   // ones), and `RegisterExistingFunctionsOnRemoteWorkers` will take care of
1547   // registering existing functions, where duplicate registrations will be
1548   // ignored by the remote workers.
1549   TF_RETURN_IF_ERROR(
1550       RegisterExistingFunctionsOnRemoteWorkers(add_remote_contexts));
1551   return OkStatus();
1552 }
1553 
1554 // Set distributed execution related state in the master context.
SetMasterContextState(std::unique_ptr<ServerInterface> server,WorkerEnv * worker_env,std::shared_ptr<WorkerSession> worker_session,std::unique_ptr<eager::EagerClientCache> remote_eager_workers,std::unique_ptr<DynamicDeviceMgr> remote_device_manager,uint64 context_id,uint64 context_view_id,Rendezvous * r,DeviceMgr * local_device_mgr,int keep_alive_secs,DistributedFunctionLibraryRuntime * cluster_flr,std::unique_ptr<eager::RemoteMgr,std::function<void (eager::RemoteMgr *)>> remote_mgr)1555 Status EagerContext::SetMasterContextState(
1556     std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
1557     std::shared_ptr<WorkerSession> worker_session,
1558     std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
1559     std::unique_ptr<DynamicDeviceMgr> remote_device_manager, uint64 context_id,
1560     uint64 context_view_id, Rendezvous* r, DeviceMgr* local_device_mgr,
1561     int keep_alive_secs, DistributedFunctionLibraryRuntime* cluster_flr,
1562     std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
1563         remote_mgr) {
1564   mutex_lock l(remote_state_mu_);
1565   is_master_ = true;
1566   context_id_ = context_id;
1567   context_view_id_ = context_view_id;
1568 
1569   use_send_tensor_rpc_ =
1570       ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", true);
1571 
1572   if (local_device_mgr != local_device_manager_.Get()) {
1573     if (local_device_manager_.Owned()) {
1574       old_local_device_managers_.push_back(
1575           std::move(local_device_manager_.owned_object));
1576     }
1577     local_device_manager_.Reset(local_device_mgr);
1578     UpdateGlobalRendezvousDeviceManager(local_device_manager_.Get());
1579   }
1580   host_cpu_device_ = local_device_manager_.Get()->HostCPU();
1581 
1582   if (rendezvous_ != nullptr) rendezvous_->Unref();
1583   rendezvous_ = r;
1584 
1585   // Memory leak!
1586   if (server_ != nullptr) {
1587     LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
1588                     "Servers don't support clean shutdown.";
1589     server_.release();
1590   }
1591   server_ = std::move(server);
1592 
1593   remote_mgr_ = std::move(remote_mgr);
1594   worker_env_ = worker_env;
1595   worker_session_ = std::move(worker_session);
1596   remote_eager_workers_ = std::move(remote_eager_workers);
1597 
1598   remote_device_manager_.Reset(std::move(remote_device_manager));
1599   ResetClusterFLR(cluster_flr);
1600 
1601   InitPrioritizedDeviceTypeList();
1602 
1603   ClearCachesAndThreadExecutors();
1604   default_executor_.ClearError();
1605   {
1606     tensorflow::mutex_lock l(executor_map_mu_);
1607     for (auto& entry : thread_local_executor_) {
1608       entry.second->ClearError();
1609     }
1610   }
1611   const auto* config = pflr_->config();
1612   ResetPFLR(local_device_manager_.Get(), env_, config, TF_GRAPH_DEF_VERSION,
1613             &func_lib_def_, config->graph_options().optimizer_options(),
1614             thread_pool_.get(), cluster_flr_.Get());
1615 
1616   keep_alive_secs_ = keep_alive_secs;
1617   sleep_for_secs_ = std::max(1, keep_alive_secs_ / 2);
1618   // Only schedule a single closure.
1619   if (keep_alive_thread_ == nullptr) {
1620     keep_alive_thread_.reset(
1621         env_->StartThread({}, "EagerKeepAliveThread", [this]() {
1622           while (true) {
1623             {
1624               {
1625                 mutex_lock l(keep_alive_thread_shutdown_mu_);
1626 
1627                 if (shutting_down_) {
1628                   return;
1629                 }
1630 
1631                 keep_alive_thread_cv_.wait_for(
1632                     l, std::chrono::seconds(sleep_for_secs_));
1633 
1634                 if (shutting_down_) {
1635                   return;
1636                 }
1637               }
1638               {
1639                 mutex_lock l(remote_state_mu_);
1640                 if (keep_alive_secs_ > 0) {
1641                   {
1642                     for (const auto& worker : remote_contexts_) {
1643                       core::RefCountPtr<eager::EagerClient> client;
1644                       Status s =
1645                           remote_eager_workers_->GetClient(worker, &client);
1646 
1647                       if (!s.ok()) {
1648                         LOG(WARNING) << "Keep-alive thread was unable to find "
1649                                         "a client for target "
1650                                      << worker << ". Got error: " << s;
1651                         continue;
1652                       }
1653 
1654                       eager::KeepAliveRequest* request =
1655                           new eager::KeepAliveRequest;
1656                       eager::KeepAliveResponse* response =
1657                           new eager::KeepAliveResponse;
1658 
1659                       request->set_context_id(context_id_);
1660                       client->KeepAliveAsync(
1661                           request, response,
1662                           [request, response](const Status& s) {
1663                             delete request;
1664                             delete response;
1665                           });
1666                     }
1667                   }
1668                 }
1669               }
1670             }
1671           }
1672         }));
1673   }
1674   return OkStatus();
1675 }
1676 
InitializeRemoteWorker(std::unique_ptr<eager::EagerClientCache> remote_eager_workers,DynamicDeviceMgr * remote_device_mgr,const std::vector<string> & remote_contexts,uint64 context_id,uint64 context_view_id,std::function<Rendezvous * (const int64_t)> rendezvous_creator,DistributedFunctionLibraryRuntime * cluster_flr,std::unique_ptr<eager::RemoteMgr,std::function<void (eager::RemoteMgr *)>> remote_mgr,std::function<void ()> resource_deallocator)1677 Status EagerContext::InitializeRemoteWorker(
1678     std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
1679     DynamicDeviceMgr* remote_device_mgr,
1680     const std::vector<string>& remote_contexts, uint64 context_id,
1681     uint64 context_view_id,
1682     std::function<Rendezvous*(const int64_t)> rendezvous_creator,
1683     DistributedFunctionLibraryRuntime* cluster_flr,
1684     std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
1685         remote_mgr,
1686     std::function<void()> resource_deallocator) {
1687   if (context_id == kInvalidContextId) {
1688     return errors::InvalidArgument(
1689         "Failed to initialize remote for worker context due to invalid ",
1690         "context id");
1691   }
1692   mutex_lock l(remote_state_mu_);
1693 
1694   if (remote_device_manager_.Owned() || server_ != nullptr ||
1695       keep_alive_thread_ != nullptr) {
1696     return errors::FailedPrecondition(
1697         "EagerContext::InitializeRemoteWorker Failed. ",
1698         "Already initialized remote as a master context.");
1699   }
1700   is_master_ = false;
1701 
1702   remote_contexts_ = remote_contexts;
1703   context_id_ = context_id;
1704   context_view_id_ = context_view_id;
1705 
1706   rendezvous_creator_ = std::move(rendezvous_creator);
1707   remote_eager_workers_ = std::move(remote_eager_workers);
1708   remote_mgr_ = std::move(remote_mgr);
1709   ResetClusterFLR(cluster_flr);
1710 
1711   remote_device_manager_.Reset(remote_device_mgr);
1712 
1713   const auto* config = pflr_->config();
1714   ResetPFLR(local_device_manager_.Get(), env_, config, TF_GRAPH_DEF_VERSION,
1715             &func_lib_def_, config->graph_options().optimizer_options(),
1716             thread_pool_.get(), cluster_flr_.Get());
1717   InitPrioritizedDeviceTypeList();
1718 
1719   ClearCachesAndThreadExecutors();
1720   default_executor_.ClearError();
1721   {
1722     tensorflow::mutex_lock l(executor_map_mu_);
1723     for (auto& entry : thread_local_executor_) {
1724       entry.second->ClearError();
1725     }
1726   }
1727 
1728   resource_deallocator_ = std::move(resource_deallocator);
1729 
1730   return OkStatus();
1731 }
1732 
UpdateRemoteWorker(std::unique_ptr<eager::EagerClientCache> remote_eager_workers,const std::vector<string> & remote_contexts,uint64 context_id)1733 Status EagerContext::UpdateRemoteWorker(
1734     std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
1735     const std::vector<string>& remote_contexts, uint64 context_id) {
1736   {
1737     mutex_lock l(remote_state_mu_);
1738     if (context_id != context_id_) {
1739       return errors::InvalidArgument(
1740           "Failed to update remote for worker context due to invalid ",
1741           "context id. Request id = ", context_id,
1742           " but current id = ", context_id_);
1743     }
1744     context_view_id_++;
1745 
1746     remote_contexts_ = remote_contexts;
1747     remote_eager_workers_ = std::move(remote_eager_workers);
1748     InitPrioritizedDeviceTypeList();
1749     pflr_->InitializeDeviceAndFlr();
1750   }
1751 
1752   // No need to update remote_device_manager_ since it's not owned for remote
1753   // worker context (owned by the corresponding worker session).
1754   if (remote_device_manager_.Owned()) {
1755     return errors::FailedPrecondition(
1756         "EagerContext::UpdateRemoteWorker failed because the context was "
1757         "initialized as a master context.");
1758   }
1759 
1760   ClearCachesAndThreadExecutors();
1761   default_executor_.ClearError();
1762   {
1763     tensorflow::mutex_lock l(executor_map_mu_);
1764     for (auto& entry : thread_local_executor_) {
1765       entry.second->ClearError();
1766     }
1767   }
1768   return OkStatus();
1769 }
1770 #endif  // !IS_MOBILE_PLATFORM
1771 
1772 }  // namespace tensorflow
1773