1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/eager/context.h"
17 
18 #include <memory>
19 #include <vector>
20 
21 // clang-format off
22 // Required for IS_MOBILE_PLATFORM
23 #include "tensorflow/core/common_runtime/function.h"
24 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
25 #include "tensorflow/core/lib/core/refcount.h"
26 #include "tensorflow/core/lib/gtl/map_util.h"
27 #include "tensorflow/core/platform/mutex.h"
28 #include "tensorflow/core/platform/platform.h"
29 // clang-format on
30 
31 #include "tensorflow/core/common_runtime/collective_executor_mgr.h"
32 #include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
33 #include "tensorflow/core/common_runtime/colocation_graph.h"
34 #include "tensorflow/core/common_runtime/device_resolver_local.h"
35 #include "tensorflow/core/common_runtime/device_set.h"
36 #include "tensorflow/core/common_runtime/eager/process_function_library_runtime.h"
37 #include "tensorflow/core/common_runtime/process_util.h"
38 #include "tensorflow/core/framework/function.h"
39 #include "tensorflow/core/framework/graph_def_util.h"
40 #include "tensorflow/core/lib/core/errors.h"
41 #include "tensorflow/core/public/version.h"
42 #include "tensorflow/core/util/device_name_utils.h"
43 #if !defined(IS_MOBILE_PLATFORM)
44 #include "tensorflow/core/distributed_runtime/cluster_function_library_runtime.h"
45 #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
46 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
47 #include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
48 #endif  // !IS_MOBILE_PLATFORM
49 #include "tensorflow/core/framework/resource_mgr.h"
50 #include "tensorflow/core/lib/core/blocking_counter.h"
51 #include "tensorflow/core/lib/monitoring/gauge.h"
52 #include "tensorflow/core/platform/monitoring.h"
53 #include "tensorflow/core/util/env_var.h"
54 
55 namespace tensorflow {
56 namespace {
57 
58 bool ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val) {
59   bool val;
60   if (tensorflow::ReadBoolFromEnvVar(env_var_name, default_val, &val).ok()) {
61     return val;
62   }
63   return default_val;
64 }
65 
66 auto* eager_context_created =
67     monitoring::Gauge<bool, 0>::New("/tensorflow/core/eager_context_created",
68                                     "True if an eager context was created.");
69 
70 }  // namespace
71 
72 EagerContext::EagerContext(
73     const SessionOptions& opts,
74     ContextDevicePlacementPolicy default_device_placement_policy,
75     ContextMirroringPolicy default_mirroring_policy, bool async,
76     const bool lazy_copy_function_remote_inputs, const DeviceMgr* device_mgr,
77     bool device_mgr_owned, Rendezvous* rendezvous,
78     const CustomKernelCreator* custom_kernel_creator,
79     DistributedFunctionLibraryRuntime* cluster_flr)
80     : default_device_placement_policy_(default_device_placement_policy),
81       default_mirroring_policy_(default_mirroring_policy),
82       local_device_manager_(device_mgr, device_mgr_owned),
83       host_cpu_device_(device_mgr->HostCPU()),
84       rendezvous_(rendezvous),
85       thread_pool_(NewThreadPoolFromSessionOptions(opts)),
86       custom_kernel_creator_(custom_kernel_creator),
87       cluster_flr_(cluster_flr),
88       log_device_placement_(opts.config.log_device_placement()),
89       allow_soft_placement_(opts.config.allow_soft_placement()),
90       num_active_steps_(0),
91       default_executor_(async),
92       log_memory_(LogMemory::IsEnabled()),
93       env_(opts.env),
94       lazy_copy_function_remote_inputs_(lazy_copy_function_remote_inputs),
95       use_send_tensor_rpc_(false),
96       pin_small_ops_to_cpu_(ReadBoolFromEnvVar(
97           "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING", false)) {
98   ResetPFLR(device_mgr, opts.env, &opts.config, TF_GRAPH_DEF_VERSION,
99             &func_lib_def_, opts.config.graph_options().optimizer_options(),
100             thread_pool_.get(), cluster_flr, custom_kernel_creator_);
101   // Starts exporting metrics through a platform-specific monitoring API (if
102   // provided). For builds using "tensorflow/core/platform/default", this is
103   // currently a no-op.
104   eager_context_created->GetCell()->Set(true);
105   monitoring::StartExporter();
106   InitPrioritizedDeviceTypeList();
107   runner_ = [this](std::function<void()> closure) {
108     this->thread_pool_->Schedule(std::move(closure));
109   };
110 
111 #if !defined(IS_MOBILE_PLATFORM)
112   context_id_ = kInvalidContextId;
113 #endif  // IS_MOBILE_PLATFORM
114 
115   std::unique_ptr<DeviceResolverInterface> drl(
116       new DeviceResolverLocal(local_device_mgr()));
117   std::unique_ptr<ParamResolverInterface> cprl(new CollectiveParamResolverLocal(
118       opts.config, local_device_mgr(), drl.get(),
119       "/job:localhost/replica:0/task:0"));
120   collective_executor_mgr_.Reset(
121       new CollectiveExecutorMgr(opts.config, local_device_mgr(), std::move(drl),
122                                 std::move(cprl)),
123       /*owned=*/true);
124 }
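// Illustrative sketch (not part of the original file): constructing a purely
// local EagerContext with the constructor above. Helper names such as
// DeviceFactory::AddDevices, StaticDeviceMgr, and IntraProcessRendezvous are
// assumptions about the surrounding runtime and may vary between TensorFlow
// versions.
//
//   SessionOptions opts;
//   std::vector<std::unique_ptr<Device>> devices;
//   TF_CHECK_OK(DeviceFactory::AddDevices(
//       opts, "/job:localhost/replica:0/task:0", &devices));
//   auto device_mgr = absl::make_unique<StaticDeviceMgr>(std::move(devices));
//   Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr.get());
//   auto* ctx = new EagerContext(
//       opts, DEVICE_PLACEMENT_SILENT, MIRRORING_NONE, /*async=*/false,
//       /*lazy_copy_function_remote_inputs=*/false, device_mgr.get(),
//       /*device_mgr_owned=*/false, rendezvous,
//       /*custom_kernel_creator=*/nullptr, /*cluster_flr=*/nullptr);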
125 
126 void EagerContext::ResetPFLR(const DeviceMgr* device_mgr, Env* env,
127                              const ConfigProto* config, int graph_def_version,
128                              const FunctionLibraryDefinition* lib_def,
129                              const OptimizerOptions& optimizer_options,
130                              thread::ThreadPool* thread_pool,
131                              DistributedFunctionLibraryRuntime* cluster_flr,
132                              const CustomKernelCreator* custom_kernel_creator) {
133   if (lazy_copy_function_remote_inputs_) {
134     pflr_.reset(new eager::EagerProcessFunctionLibraryRuntime(
135         device_mgr, env, config, graph_def_version, lib_def, optimizer_options,
136         thread_pool, cluster_flr, custom_kernel_creator));
137   } else {
138     pflr_.reset(new ProcessFunctionLibraryRuntime(
139         device_mgr, env, config, graph_def_version, lib_def, optimizer_options,
140         thread_pool, cluster_flr, custom_kernel_creator));
141   }
142 }
143 
144 void EagerContext::InitPrioritizedDeviceTypeList() {
145   DeviceSet ds;
146   for (Device* d : local_device_mgr()->ListDevices()) {
147     ds.AddDevice(d);
148   }
149   auto remote_device_manager = remote_device_mgr();
150   if (remote_device_manager != nullptr) {
151     for (Device* d : remote_device_manager->ListDevices()) {
152       ds.AddDevice(d);
153     }
154   }
155   prioritized_device_type_list_ = ds.PrioritizedDeviceTypeList();
156 }
157 
158 namespace {
159 // Using absl::StrJoin with lambda does not work in tf-lite builds.
160 // TODO(b/148160441): Replace with absl::StrJoin once DeviceBase has operator<<.
161 std::vector<string> DevicesToString(const std::vector<Device*>& devices) {
162   std::vector<string> v;
163   v.reserve(devices.size());
164   for (Device* d : devices) {
165     v.push_back(d->name());
166   }
167   return v;
168 }
169 }  // namespace
170 
171 Status EagerContext::SelectDevice(DeviceNameUtils::ParsedName preferred,
172                                   const PrioritizedDeviceTypeVector& supported,
173                                   const DataType dtype, Device** device) const {
174   std::vector<Device*> selected;
175   const DeviceSet& pflr_devices = *pflr()->device_set();
176 
177   // We always place string tensors on the CPU device if we're allowed to.
178   if (dtype == DT_STRING && AllowSoftPlacement()) {
179     preferred = HostCPU()->parsed_name();
180   }
181 
182   // If there are no preferred devices, select the first registered device from
183   // the supported device list.
184   if (!DeviceNameUtils::HasSomeDetails(preferred)) {
185     // TODO(b/148213212): Allow setting default device in eager context.
186     selected = ColocationGraph::FilterSupportedDevices(
187         pflr_devices.devices(), supported, /*default_local_device=*/nullptr);
188     if (selected.empty()) {
189       return errors::InvalidArgument(
190           "No supported device found in available devices [",
191           absl::StrJoin(DevicesToString(pflr_devices.devices()), ", "), "].");
192     }
193     *device = selected[0];
194     return Status::OK();
195   }
196 
197   // If the caller specified a preferred device, select the first matching
198   // registered device from the supported device list. If nothing matches and
199   // soft placement is enabled, pick a suitable device from the available ones.
200   pflr_devices.FindMatchingDevices(preferred, &selected);
201 
202   if (!selected.empty()) {
203     selected = ColocationGraph::FilterSupportedDevices(
204         selected, supported, /*default_local_device=*/nullptr);
205   }
206 
207   if (selected.empty() && AllowSoftPlacement()) {
208     DeviceNameUtils::ParsedName soft_device_name = preferred;
209     soft_device_name.type.clear();
210     soft_device_name.has_type = false;
211     soft_device_name.has_id = false;
212     // TODO(b/148213746): Soft placement logic picks up another task if the
213     // requested task does not exist.
214     pflr_devices.FindMatchingDevices(soft_device_name, &selected);
215     if (!selected.empty()) {
216       selected = ColocationGraph::FilterSupportedDevices(
217           selected, supported, /*default_local_device=*/nullptr);
218     }
219   }
220 
221   if (selected.empty()) {
222     return errors::InvalidArgument(
223         "Could not satisfy device specification '", preferred,
224         "'. All available devices [",
225         absl::StrJoin(DevicesToString(pflr_devices.devices()), ", "), "].");
226   }
227 
228   *device = selected[0];
229   return Status::OK();
230 }
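// Usage sketch (illustrative): selecting a device for a float op whose kernels
// are registered for GPU and CPU. The PrioritizedDeviceTypeVector element
// layout (DeviceType plus priority) is an assumption about device_set.h.
//
//   Device* device = nullptr;
//   DeviceNameUtils::ParsedName preferred;
//   DeviceNameUtils::ParseFullName(
//       "/job:localhost/replica:0/task:0/device:GPU:0", &preferred);
//   PrioritizedDeviceTypeVector supported = {{DeviceType(DEVICE_GPU), 0},
//                                            {DeviceType(DEVICE_CPU), 0}};
//   Status s = ctx->SelectDevice(preferred, supported, DT_FLOAT, &device);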
231 
232 void EagerContext::ResetClusterFLR(
233     DistributedFunctionLibraryRuntime* cluster_flr) {
234   cluster_flr_.Reset(cluster_flr, lazy_copy_function_remote_inputs_);
235 }
236 
237 EagerExecutor& EagerContext::Executor() {
238   tf_shared_lock l(executor_map_mu_);
239   return *gtl::FindWithDefault(thread_local_executor_,
240                                std::this_thread::get_id(), &default_executor_);
241 }
242 
243 void EagerContext::SetExecutorForThread(EagerExecutor* executor) {
244   tensorflow::mutex_lock l(executor_map_mu_);
245   if (executor == &default_executor_) {
246     thread_local_executor_.erase(std::this_thread::get_id());
247   } else {
248     thread_local_executor_[std::this_thread::get_id()] = executor;
249   }
250 }
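// Per-thread executor sketch (illustrative): a caller may give the current
// thread its own executor (constructed with an `async` flag, as
// default_executor_ is above); passing the default executor's address simply
// erases the thread-local entry.
//
//   EagerExecutor sync_executor(/*async=*/false);
//   ctx->SetExecutorForThread(&sync_executor);
//   // Ops dispatched from this thread now use sync_executor; threads without
//   // an entry fall back to the context-wide default via Executor().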
251 
252 void EagerContext::ClearCachesAndThreadExecutors() {
253   std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
254   {
255     mutex_lock l(executor_map_mu_);
256     executors_copy = thread_local_executor_;
257   }
258   for (const auto& entry : executors_copy) {
259     entry.second->WaitForAllPendingNodes().IgnoreError();
260   }
261   ClearCachesAndDefaultExecutor();
262 }
263 
264 void EagerContext::ClearCachesAndDefaultExecutor() {
265   // The executor stores pointers to kernels, so we need to make sure that no
266   // async eager ops are still executing. We lock the cache during this time
267   // as well.
268   mutex_lock ml(cache_mu_);
269   default_executor_.WaitForAllPendingNodes().IgnoreError();
270   kernel_cache_.clear();
271   for (auto& entry : registered_functions_) {
272     entry.second->cached_kernel_keys->clear();
273   }
274 }
275 
276 void EagerContext::SetThreadLocalDevicePlacementPolicy(
277     ContextDevicePlacementPolicy policy) {
278   mutex_lock ml(policy_map_mu_);
279   device_placement_policy_[std::this_thread::get_id()] = policy;
280 }
281 
282 ContextDevicePlacementPolicy EagerContext::GetDevicePlacementPolicy() const {
283   tf_shared_lock l(policy_map_mu_);
284   auto policy_map_it =
285       device_placement_policy_.find(std::this_thread::get_id());
286   if (policy_map_it != device_placement_policy_.end()) {
287     return policy_map_it->second;
288   }
289   return default_device_placement_policy_;
290 }
291 
292 void EagerContext::SetThreadLocalMirroringPolicy(
293     ContextMirroringPolicy policy) {
294   mutex_lock ml(policy_map_mu_);
295   mirroring_policy_[std::this_thread::get_id()] = policy;
296 }
297 
298 ContextMirroringPolicy EagerContext::GetMirroringPolicy() const {
299   tf_shared_lock l(policy_map_mu_);
300   auto policy_map_it = mirroring_policy_.find(std::this_thread::get_id());
301   if (policy_map_it != mirroring_policy_.end()) {
302     return policy_map_it->second;
303   }
304   return default_mirroring_policy_;
305 }
306 
307 bool EagerContext::MirrorTensors() const {
308   return GetMirroringPolicy() == MIRRORING_ALL;
309 }
310 
311 bool EagerContext::LazyCopyFunctionRemoteInputs() const {
312   return lazy_copy_function_remote_inputs_;
313 }
314 
315 #if !defined(IS_MOBILE_PLATFORM)
316 void EagerContext::CloseAndClearAllRemoteContexts() {
317   uint64 context_id;
318   uint64 context_view_id;
319   {
320     mutex_lock l(remote_state_mu_);
321     if (!is_master_) return;
322     context_id = context_id_;
323     context_view_id = context_view_id_;
324     context_id_ = kInvalidContextId;
325     // Forget the current view id and reset to the starting value 0.
326     context_view_id_ = 0;
327   }
328   CloseRemoteContexts(remote_contexts_, context_id, context_view_id);
329   remote_contexts_.clear();
330 }
331 
332 void EagerContext::CloseRemoteContexts(
333     const std::vector<string>& remote_contexts, uint64 context_id,
334     uint64 context_view_id) {
335   // Close all remote contexts.
336   eager::CloseContextRequest request;
337   request.set_context_id(context_id);
338   request.set_context_view_id(context_view_id);
339   // Setting context_id to a new value prevents us from issuing
340   // DestroyTensorHandle requests to already-closed remote workers.
341   std::vector<eager::CloseContextResponse> responses(remote_contexts.size());
342   BlockingCounter counter(static_cast<int>(remote_contexts.size()));
343 
344   int i = 0;
345   for (const auto& worker : remote_contexts) {
346     core::RefCountPtr<eager::EagerClient> client;
347     Status s = remote_eager_workers_->GetClient(worker, &client);
348 
349     client->CloseContextAsync(
350         &request, &responses[i],
351         [&worker, &counter, context_id](const Status& s) {
352           if (!s.ok()) {
353             LOG(ERROR) << "Unable to close remote context with ID "
354                        << context_id << " for worker: " << worker << " due to "
355                        << s.error_message();
356           }
357           counter.DecrementCount();
358         });
359     i++;
360   }
361 
362   counter.Wait();
363 }
364 
365 #endif  // !IS_MOBILE_PLATFORM
366 
367 void EagerContext::WaitForAndCloseRemoteContexts() {
368   ClearCachesAndThreadExecutors();
369 
370 #if !defined(IS_MOBILE_PLATFORM)
371   {
372     mutex_lock l(keep_alive_thread_shutdown_mu_);
373     shutting_down_ = true;
374     keep_alive_thread_cv_.notify_all();
375   }
376   keep_alive_thread_.reset();
377 
378   if (!remote_contexts_.empty()) {
379     CloseAndClearAllRemoteContexts();
380   }
381 
382   {
383     mutex_lock l(remote_state_mu_);
384 
385     default_executor_.ShutDown().IgnoreError();
386     std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
387     {
388       mutex_lock l(executor_map_mu_);
389       executors_copy = thread_local_executor_;
390     }
391     for (const auto& it : executors_copy) {
392       it.second->ShutDown().IgnoreError();
393     }
394   }
395 
396   // This shuts down the completion queue and joins the thread polling it.
397   // The thread exits only after the completion queue has been drained of all
398   // the events. These events' completion should invoke all remaining RPC
399   // callbacks.
400   // This also deletes all EagerClient instances. There should not be any
401   // references to EagerClients left after all RPCs and async ops have been
402   // finished.
403   remote_eager_workers_ = nullptr;
404 #endif  // !IS_MOBILE_PLATFORM
405 }
406 
407 EagerContext::~EagerContext() {
408   // TODO(iga): Add a separate API method to shutdown EagerContext so that we
409   // don't send RPCs and block in destructor.
410   WaitForAndCloseRemoteContexts();
411 
412   ClearCachesAndThreadExecutors();
413   for (auto& entry : registered_functions_) {
414     while (!entry.second->Unref()) {
415       // remove all references.
416     }
417   }
418   registered_functions_.clear();
419 
420 #if !defined(IS_MOBILE_PLATFORM)
421   if (server_) {
422     // TODO(b/136478427): Fix this.
423     LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
424                     "Servers don't support clean shutdown.";
425     server_.release();
426   }
427 
428   {
429     mutex_lock l(keep_alive_thread_shutdown_mu_);
430     shutting_down_ = true;
431     keep_alive_thread_cv_.notify_all();
432   }
433   keep_alive_thread_.reset();
434   if (!remote_contexts_.empty()) {
435     CloseAndClearAllRemoteContexts();
436   }
437 #endif  // !IS_MOBILE_PLATFORM
438 
439   if (rendezvous_) {
440     rendezvous_->Unref();
441   }
442   if (resource_deallocator_ != nullptr) {
443     resource_deallocator_();
444   }
445 }
446 
447 bool EagerContext::FindFunctionByName(const string& name) const {
448   return func_lib_def_.Find(name) != nullptr;
449 }
450 
451 Status EagerContext::FindFunctionOpData(
452     const string& name, const tensorflow::OpRegistrationData** op_data) {
453   return func_lib_def_.LookUp(name, op_data);
454 }
455 
456 const FunctionDef* EagerContext::FindFunctionDef(const string& name) {
457   return func_lib_def_.Find(name);
458 }
459 
460 std::vector<const FunctionDef*> EagerContext::ListRegisteredFunctions() {
461   std::vector<const FunctionDef*> result;
462   std::vector<string> function_names = func_lib_def_.ListFunctionNames();
463   result.reserve(function_names.size());
464   for (const string& fn : function_names) {
465     result.emplace_back(func_lib_def_.Find(fn));
466   }
467   return result;
468 }
469 
470 void EagerContext::ClearRunMetadata() { run_metadata_.Clear(); }
471 
472 void EagerContext::ListDevices(
473     std::vector<tensorflow::DeviceAttributes>* devices) {
474   local_device_mgr()->ListDeviceAttributes(devices);
475   if (remote_device_mgr()) {
476     remote_device_mgr()->ListDeviceAttributes(devices);
477   }
478 }
479 
480 void EagerContext::StartStep() {
481   mutex_lock ml(metadata_mu_);
482   num_active_steps_++;
483   if (step_container_ == nullptr) {
484     step_container_.reset(
485         new ScopedStepContainer(0, [this](const string& name) {
486           auto local_devices = local_device_mgr()->ListDevices();
487           for (Device* device : local_devices) {
488             device->resource_manager()->Cleanup(name).IgnoreError();
489           }
490         }));
491   }
492 }
493 
494 void EagerContext::EndStep() {
495   mutex_lock ml(metadata_mu_);
496   num_active_steps_--;
497   if (num_active_steps_ == 0) {
498     step_container_.reset();
499   }
500 }
501 
502 ScopedStepContainer* EagerContext::StepContainer() {
503   if (num_active_steps_.load() == 0) {
504     return nullptr;
505   }
506   mutex_lock ml(metadata_mu_);
507   return step_container_.get();
508 }
509 
510 Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) {
511   // Only client context can register function on remote worker context.
512   if (!remote_device_manager_.Owned()) return Status::OK();
513 #if !defined(IS_MOBILE_PLATFORM)
514   std::shared_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
515   request->set_context_id(GetContextId());
516 
517   eager::RegisterFunctionOp* register_function =
518       request->add_queue()->mutable_register_function();
519   *register_function->mutable_function_def() = fdef;
520   StripDefaultAttributes(
521       *OpRegistry::Global(),
522       register_function->mutable_function_def()->mutable_node_def());
523 
524   for (const auto& target : remote_contexts_) {
525     core::RefCountPtr<eager::EagerClient> eager_client;
526     TF_RETURN_IF_ERROR(remote_eager_workers_->GetClient(target, &eager_client));
527 
528     eager::EnqueueResponse* response = new eager::EnqueueResponse();
529     eager_client->StreamingEnqueueAsync(
530         request.get(), response, [request, response](const Status& status) {
531           if (!status.ok()) {
532             LOG(ERROR) << "Failed to register function remotely due to "
533                        << status.error_message()
534                        << "\nThis shouldn't happen; please file a bug with "
535                           "the TensorFlow team.";
536           }
537           delete response;
538         });
539   }
540 #endif  // !IS_MOBILE_PLATFORM
541   return Status::OK();
542 }
543 
544 Status EagerContext::RegisterExistingFunctionsOnRemoteWorkers(
545     const std::vector<const FunctionDef*>& function_defs,
546     const std::vector<string>& remote_workers) {
547 #if !defined(IS_MOBILE_PLATFORM)
548   // Register multiple functions on selected remote workers.
549   uint64 context_id = GetContextId();
550   for (int i = 0; i < remote_workers.size(); i++) {
551     core::RefCountPtr<eager::EagerClient> eager_client;
552     Status s =
553         remote_eager_workers_->GetClient(remote_workers[i], &eager_client);
554     if (!s.ok()) {
555       continue;
556     }
557     for (int j = 0; j < function_defs.size(); j++) {
558       auto* request = new eager::EnqueueRequest;
559       request->set_context_id(context_id);
560       eager::RegisterFunctionOp* register_function =
561           request->add_queue()->mutable_register_function();
562       *register_function->mutable_function_def() = *function_defs[j];
563       auto* response = new eager::EnqueueResponse;
564       eager_client->StreamingEnqueueAsync(
565           request, response, [request, response](const Status& s) {
566             if (!s.ok()) {
567               LOG(ERROR) << "Failed to register function remotely due to "
568                          << s.error_message()
569                          << "\nThis shouldn't happen; please file a bug with "
570                             "the TensorFlow team.";
571             }
572             delete request;
573             delete response;
574           });
575     }
576   }
577 #endif  // !IS_MOBILE_PLATFORM
578   return Status::OK();
579 }
580 
581 Status EagerContext::AddFunctionDef(const FunctionDef& fdef) {
582   return AddFunctionDef(fdef, FunctionDefLibrary(),
583                         /* add_to_local_only=*/false);
584 }
585 
586 Status EagerContext::AddFunctionDef(const FunctionDef& fdef,
587                                     const FunctionDefLibrary& library,
588                                     const bool add_to_local_only) {
589   bool is_first_ref = false;
590   {
591     mutex_lock l(cache_mu_);
592     auto* registered_function =
593         gtl::FindPtrOrNull(registered_functions_, fdef.signature().name());
594     if (registered_function == nullptr) {
595       registered_function = new RegisteredFunction;
596       registered_function->cached_kernel_keys =
597           absl::make_unique<std::vector<Fprint128>>();
598       gtl::InsertOrUpdate(&registered_functions_, fdef.signature().name(),
599                           registered_function);
600     } else {
601       registered_function->Ref();
602     }
603     is_first_ref = registered_function->RefCountIsOne();
604   }
605   if (is_first_ref) {
606     TF_RETURN_IF_ERROR(func_lib_def_.AddFunctionDef(fdef));
607     TF_RETURN_IF_ERROR(func_lib_def_.AddLibrary(library));
608     if (!add_to_local_only) {
609       return MaybeRegisterFunctionRemotely(fdef);
610     }
611   }
612   return Status::OK();
613 }
614 
615 Status EagerContext::RemoveFunction(const string& func) {
616   bool is_last_ref = false;
617   {
618     mutex_lock l(cache_mu_);
619     auto* registered_function = gtl::FindPtrOrNull(registered_functions_, func);
620     if (registered_function == nullptr) {
621       return errors::InvalidArgument("Tried to remove non-existent function '",
622                                      func, "'.");
623     }
624     is_last_ref = registered_function->RefCountIsOne();
625     if (is_last_ref) {
626       for (auto& key : *registered_function->cached_kernel_keys) {
627         kernel_cache_.erase(key);
628       }
629       registered_functions_.erase(func);
630     }
631     registered_function->Unref();
632   }
633   if (is_last_ref) {
634     // TODO(fishx): Remove remote function as well.
635     return func_lib_def_.RemoveFunction(func);
636   }
637   return Status::OK();
638 }
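// Function registration is reference counted: each AddFunctionDef() for the
// same signature name Refs the RegisteredFunction, and RemoveFunction() only
// drops the FunctionDef (plus its cached kernels) once the last reference is
// released. Sketch, where `name` is fdef.signature().name():
//
//   TF_RETURN_IF_ERROR(ctx->AddFunctionDef(fdef));  // first ref: registered
//   TF_RETURN_IF_ERROR(ctx->AddFunctionDef(fdef));  // second ref: no-op
//   TF_RETURN_IF_ERROR(ctx->RemoveFunction(name));  // still one ref: kept
//   TF_RETURN_IF_ERROR(ctx->RemoveFunction(name));  // last ref: removed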
639 
640 core::RefCountPtr<KernelAndDevice> EagerContext::GetCachedKernel(
641     Fprint128 cache_key) {
642   tf_shared_lock l(cache_mu_);
643   auto iter = kernel_cache_.find(cache_key);
644   if (iter == kernel_cache_.end()) {
645     return nullptr;
646   }
647   core::RefCountPtr<KernelAndDevice> new_ref(iter->second.get());
648   new_ref->Ref();
649   return new_ref;
650 }
651 
652 void EagerContext::AddKernelToCache(Fprint128 cache_key,
653                                     KernelAndDevice* kernel) {
654   mutex_lock ml(cache_mu_);
655   core::RefCountPtr<KernelAndDevice> new_ref(kernel);
656   new_ref->Ref();
657   kernel_cache_[cache_key] = std::move(new_ref);
658   auto* registered_function =
659       gtl::FindPtrOrNull(registered_functions_, kernel->name());
660   // The kernel name can be either a primitive op or a function.
661   if (registered_function != nullptr) {
662     registered_function->cached_kernel_keys->emplace_back(cache_key);
663   }
664 }
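// Note on ownership: both GetCachedKernel() and AddKernelToCache() take an
// extra Ref() on the KernelAndDevice, so the cache holds its own reference
// while callers release theirs through core::RefCountPtr; cached entries are
// dropped by ClearCachesAndDefaultExecutor() or when the owning function is
// removed.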
665 
666 bool EagerContext::ShouldStoreGraphs() { return should_store_graphs_.load(); }
667 
668 void EagerContext::SetShouldStoreGraphs(bool value) {
669   mutex_lock ml(metadata_mu_);
670   should_store_graphs_.store(value);
671   if (!value) {
672     run_metadata_.Clear();
673   }
674 }
675 
676 Status EagerContext::FindDeviceFromName(const char* device_name,
677                                         Device** device) const {
678   *device = HostCPU();
679   if (device_name == nullptr || strlen(device_name) == 0) {
680     return Status::OK();
681   }
682 
683   auto status = local_device_mgr()->LookupDevice(device_name, device);
684   if (status.ok()) {
685     return status;
686   }
687 
688   if (remote_device_mgr() != nullptr) {
689     return remote_device_mgr()->LookupDevice(device_name, device);
690   }
691 
692   return status;
693 }
694 
695 bool EagerContext::OnSameTask(const Device* first, const Device* second) const {
696   if (first == nullptr) first = HostCPU();
697   if (second == nullptr) second = HostCPU();
698   return first->parsed_name().job == second->parsed_name().job &&
699          first->parsed_name().replica == second->parsed_name().replica &&
700          first->parsed_name().task == second->parsed_name().task;
701 }
702 
703 // Gets the CPU device on the task of device.
704 Status EagerContext::CPUDeviceOnTask(const Device* device,
705                                      Device** cpu_device) const {
706   string cpu_device_name;
707   TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
708       device->name(), &cpu_device_name));
709 
710   return FindDeviceFromName(cpu_device_name.c_str(), cpu_device);
711 }
712 
713 namespace {
714 Status GetTaskName(Device* d, string* task_name) {
715   string ignored;
716   if (!DeviceNameUtils::SplitDeviceName(d->name(), task_name, &ignored)) {
717     return errors::InvalidArgument("Unable to parse device name: ", d->name());
718   }
719 
720   return Status::OK();
721 }
722 }  // namespace
723 
724 #if !defined(IS_MOBILE_PLATFORM)
725 Status EagerContext::GetClient(Device* device,
726                                core::RefCountPtr<eager::EagerClient>* client) {
727   return GetClient(device->parsed_name(), client);
728 }
729 
730 Status EagerContext::GetClient(const DeviceNameUtils::ParsedName& device_name,
731                                core::RefCountPtr<eager::EagerClient>* client) {
732   if (remote_eager_workers_ == nullptr) {
733     return errors::Internal(
734         "Haven't set up remote eager worker in this eager context yet.");
735   }
736   string device_task_name;
737   if (!DeviceNameUtils::GetTaskName(device_name, &device_task_name)) {
738     return errors::InvalidArgument(
739         "Task is not fully specified in device name: ",
740         DeviceNameUtils::ParsedNameToString(device_name));
741   }
742 
743   TF_RETURN_IF_ERROR(
744       remote_eager_workers_->GetClient(device_task_name, client));
745 
746   if (*client == nullptr) {
747     return errors::InvalidArgument(
748         "Unable to find eager client corresponding to device ",
749         DeviceNameUtils::ParsedNameToString(device_name));
750   }
751 
752   if (std::find(remote_contexts_.begin(), remote_contexts_.end(),
753                 device_task_name) == remote_contexts_.end()) {
754     return errors::Internal("Unable to find a context for handle on task: ",
755                             device_task_name, ". This should not be possible");
756   }
757 
758   return Status::OK();
759 }
760 
761 Status EagerContext::GetClient(const string& remote_task,
762                                core::RefCountPtr<eager::EagerClient>* client) {
763   if (remote_eager_workers_ == nullptr) {
764     return errors::Internal(
765         "Haven't set up remote eager worker in this eager context yet.");
766   }
767   TF_RETURN_IF_ERROR(remote_eager_workers_->GetClient(remote_task, client));
768 
769   if (*client == nullptr) {
770     return errors::InvalidArgument(
771         "Unable to find eager client corresponding to target ", remote_task);
772   }
773   return Status::OK();
774 }
775 
776 uint64 EagerContext::GetContextId() {
777   tf_shared_lock l(remote_state_mu_);
778   return context_id_;
779 }
780 
781 uint64 EagerContext::GetContextViewId() {
782   tf_shared_lock l(remote_state_mu_);
783   return context_view_id_;
784 }
785 
786 void EagerContext::IncrementContextViewId() {
787   mutex_lock l(remote_state_mu_);
788   context_view_id_ += 1;
789 }
790 
791 // Sets collective-ops-related state in the context. Passing nullptr as
792 // `new_server` reuses the existing gRPC server in the context.
793 Status EagerContext::StoreCollectiveOpsServer(
794     std::unique_ptr<ServerInterface> new_server, DeviceMgr* device_mgr,
795     CollectiveExecutorMgrInterface* rpc_collective_executor_mgr) {
796   collective_executor_mgr_.Reset(rpc_collective_executor_mgr);
797 
798   local_device_manager_.Reset(device_mgr);
799   host_cpu_device_ = local_device_manager_.Get()->HostCPU();
800 
801   InitPrioritizedDeviceTypeList();
802   ClearCachesAndThreadExecutors();
803   default_executor_.ClearError();
804   {
805     tensorflow::mutex_lock l(executor_map_mu_);
806     for (auto& entry : thread_local_executor_) {
807       entry.second->ClearError();
808     }
809   }
810 
811   const ConfigProto* config = pflr_ ? pflr_->config() : nullptr;
812   ResetPFLR(
813       local_device_manager_.Get(), env_, /*config=*/config,
814       TF_GRAPH_DEF_VERSION, &func_lib_def_,
815       /*optimizer_options=*/
816       config ? config->graph_options().optimizer_options() : OptimizerOptions(),
817       thread_pool_.get());
818 
819   if (new_server != nullptr) {
820     // Memory leak!
821     if (server_ != nullptr) {
822       LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
823                       "Servers don't support clean shutdown.";
824       server_.release();
825     }
826     server_ = std::move(new_server);
827   }
828   DCHECK(server_ != nullptr);
829 
830   return Status::OK();
831 }
832 
833 Status EagerContext::SetRemoteDeviceFilters(
834     const string& remote_worker, const std::vector<string>& device_filters) {
835   // Get fully specified task name for remote worker
836   string remote_worker_task_name;
837   DeviceNameUtils::ParsedName pw;
838   if (!DeviceNameUtils::ParseFullName(remote_worker, &pw)) {
839     return tensorflow::errors::InvalidArgument(
840         "Remote worker task name is invalid ", remote_worker);
841   }
842   // Force set a replica as the key in cluster device filters map. I.e., if the
843   // remote worker is `/job:worker/task:0` it then becomes
844   // `/job:worker/replica:0/task:0`.
845   pw.has_replica = true;
846   if (!DeviceNameUtils::GetTaskName(pw, &remote_worker_task_name)) {
847     return tensorflow::errors::InvalidArgument(
848         "Job name and task index must be specified for worker ", remote_worker);
849   }
850 
851   std::vector<DeviceNameUtils::ParsedName> parsed_filters;
852   for (auto& filter : device_filters) {
853     DeviceNameUtils::ParsedName parsed_filter;
854     if (DeviceNameUtils::ParseFullName(filter, &parsed_filter)) {
855       parsed_filters.emplace_back(parsed_filter);
856     } else {
857       return tensorflow::errors::InvalidArgument("Invalid filter: ", filter);
858     }
859   }
860 
861   if (VLOG_IS_ON(1)) {
862     VLOG(1) << "Setting device filters for " << remote_worker << ":";
863     for (auto& filter : device_filters) {
864       VLOG(1) << "  " << filter;
865     }
866   }
867   mutex_lock l(remote_state_mu_);
868   cluster_device_filters_.emplace(remote_worker_task_name, parsed_filters);
869   return Status::OK();
870 }
871 
872 void EagerContext::FilterDevicesForRemoteWorkers(
873     const string& remote_worker,
874     const protobuf::RepeatedPtrField<DeviceAttributes>& device_attrs,
875     std::vector<bool>* filtered_device_mask) {
876   filtered_device_mask->resize(device_attrs.size());
877   std::fill(filtered_device_mask->begin(), filtered_device_mask->end(), false);
878 
879   tf_shared_lock l(remote_state_mu_);
880   auto it = cluster_device_filters_.find(remote_worker);
881   // If no filters were specified, all devices should be visible to the worker
882   if (it == cluster_device_filters_.end() || it->second.empty()) {
883     std::fill(filtered_device_mask->begin(), filtered_device_mask->end(), true);
884     return;
885   }
886 
887   const std::vector<DeviceNameUtils::ParsedName>& parsed_filters = it->second;
888   DeviceNameUtils::ParsedName parsed_remote_worker;
889   DeviceNameUtils::ParseFullName(remote_worker, &parsed_remote_worker);
890   for (int i = 0; i < device_attrs.size(); i++) {
891     DeviceNameUtils::ParsedName pn;
892     DeviceNameUtils::ParseFullName(device_attrs[i].name(), &pn);
893     if (DeviceNameUtils::IsSameAddressSpace(parsed_remote_worker, pn)) {
894       // If this device is on the remote worker itself, it should be visible
895       // regardless of device filters
896       filtered_device_mask->at(i) = true;
897       continue;
898     }
899     for (const auto& pf : parsed_filters) {
900       if ((!pn.has_job || !pf.has_job || pn.job == pf.job) &&
901           (!pn.has_replica || !pf.has_replica || pn.replica == pf.replica) &&
902           (!pn.has_task || !pf.has_task || pn.task == pf.task) &&
903           (!pn.has_type || !pf.has_type || pn.type == pf.type) &&
904           (!pn.has_id || !pf.has_id || pn.id == pf.id)) {
905         // Found a match, make it visible, stop processing more device filters
906         filtered_device_mask->at(i) = true;
907         break;
908       }
909     }
910   }
911 }
912 
913 Status EagerContext::InitializeRemoteMaster(
914     std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
915     std::shared_ptr<WorkerSession> worker_session,
916     std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
917     std::unique_ptr<DynamicDeviceMgr> remote_device_manager,
918     const std::vector<string>& remote_contexts, uint64 context_id,
919     Rendezvous* r, DeviceMgr* local_device_mgr, int keep_alive_secs,
920     DistributedFunctionLibraryRuntime* cluster_flr,
921     std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
922         remote_mgr) {
923   if (context_id == kInvalidContextId) {
924     return errors::InvalidArgument(
925         "Failed to initialize remote for master context due to invalid ",
926         "context id");
927   }
928 
929   if (!remote_contexts_.empty()) {
930     CloseAndClearAllRemoteContexts();
931   }
932   remote_contexts_ = remote_contexts;
933 
934   return SetMasterContextState(
935       std::move(server), worker_env, std::move(worker_session),
936       std::move(remote_eager_workers), std::move(remote_device_manager),
937       context_id, 0, r, local_device_mgr, keep_alive_secs, cluster_flr,
938       std::move(remote_mgr));
939 }
940 
941 Status EagerContext::UpdateRemoteMaster(
942     WorkerEnv* worker_env,
943     std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
944     const std::vector<string>& add_remote_contexts,
945     const std::vector<string>& remove_remote_contexts, uint64 context_id,
946     Rendezvous* r, DeviceMgr* local_device_mgr, int keep_alive_secs,
947     DistributedFunctionLibraryRuntime* cluster_flr) {
948   {
949     tf_shared_lock l(remote_state_mu_);
950     if (context_id != context_id_) {
951       return errors::InvalidArgument(
952           "Failed to update remote for master context due to invalid ",
953           "context id. Request id = ", context_id,
954           " but current id = ", context_id_);
955     }
956   }
957 
958   if (!remove_remote_contexts.empty()) {
959     // N.B. remove_remote_contexts include both removed and replaced workers.
960     // In the case where a worker is replaced by one that resolves to the same
961     // `hostname:port`, it is safe to close context with the current view id,
962     // since the newly created context on the remote worker will be holding
963     // a larger view id and ignores this request.
964     CloseRemoteContexts(remove_remote_contexts, context_id, GetContextViewId());
965     for (const string& remote_context : remove_remote_contexts) {
966       remote_contexts_.erase(
967           std::remove(remote_contexts_.begin(), remote_contexts_.end(),
968                       remote_context),
969           remote_contexts_.end());
970     }
971   }
972   if (!add_remote_contexts.empty()) {
973     remote_contexts_.insert(std::end(remote_contexts_),
974                             std::begin(add_remote_contexts),
975                             std::end(add_remote_contexts));
976   }
977   std::vector<const FunctionDef*> function_defs = ListRegisteredFunctions();
978 
979   {
980     mutex_lock l(remote_state_mu_);
981     context_view_id_++;
982 
983     worker_env_ = worker_env;
984     if (rendezvous_ != nullptr) rendezvous_->Unref();
985     rendezvous_ = r;
986     remote_eager_workers_ = std::move(remote_eager_workers);
987     ResetClusterFLR(cluster_flr);
988     InitPrioritizedDeviceTypeList();
989 
990     default_executor_.ClearError();
991     {
992       tensorflow::mutex_lock l(executor_map_mu_);
993       for (auto& entry : thread_local_executor_) {
994         entry.second->ClearError();
995       }
996     }
997     const auto* config = pflr_->config();
998     ResetPFLR(local_device_manager_.Get(), env_, config, TF_GRAPH_DEF_VERSION,
999               &func_lib_def_, config->graph_options().optimizer_options(),
1000               thread_pool_.get(), cluster_flr_.Get(), custom_kernel_creator_);
1001   }
1002 
1003   // Register existing functions to the newly added remote workers. Note that
1004   // this should happen only after updating `remote_contexts_` because new
1005   // functions might be registered while we update the context. When that
1006   // happens, this ordering ensures that `MaybeRegisterFunctionRemotely` will
1007   // register the new functions on all remote workers (including the newly added
1008   // ones), and `RegisterExistingFunctionsOnRemoteWorkers` will take care of
1009   // registering existing functions, where duplicate registrations will be
1010   // ignored by the remote workers.
1011   TF_RETURN_IF_ERROR(RegisterExistingFunctionsOnRemoteWorkers(
1012       function_defs, add_remote_contexts));
1013   return Status::OK();
1014 }
1015 
1016 // Set distributed execution related state in the master context.
1017 Status EagerContext::SetMasterContextState(
1018     std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
1019     std::shared_ptr<WorkerSession> worker_session,
1020     std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
1021     std::unique_ptr<DynamicDeviceMgr> remote_device_manager, uint64 context_id,
1022     uint64 context_view_id, Rendezvous* r, DeviceMgr* local_device_mgr,
1023     int keep_alive_secs, DistributedFunctionLibraryRuntime* cluster_flr,
1024     std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
1025         remote_mgr) {
1026   mutex_lock l(remote_state_mu_);
1027   is_master_ = true;
1028   context_id_ = context_id;
1029   context_view_id_ = context_view_id;
1030 
1031   use_send_tensor_rpc_ =
1032       ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", true);
1033 
1034   local_device_manager_.Reset(local_device_mgr);
1035   host_cpu_device_ = local_device_manager_.Get()->HostCPU();
1036 
1037   if (rendezvous_ != nullptr) rendezvous_->Unref();
1038   rendezvous_ = r;
1039 
1040   // Memory leak!
1041   if (server_ != nullptr) {
1042     LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
1043                     "Servers don't support clean shutdown.";
1044     server_.release();
1045   }
1046   server_ = std::move(server);
1047 
1048   remote_mgr_ = std::move(remote_mgr);
1049   worker_env_ = worker_env;
1050   worker_session_ = std::move(worker_session);
1051   remote_eager_workers_ = std::move(remote_eager_workers);
1052 
1053   remote_device_manager_.Reset(std::move(remote_device_manager));
1054   ResetClusterFLR(cluster_flr);
1055 
1056   InitPrioritizedDeviceTypeList();
1057 
1058   ClearCachesAndThreadExecutors();
1059   default_executor_.ClearError();
1060   {
1061     tensorflow::mutex_lock l(executor_map_mu_);
1062     for (auto& entry : thread_local_executor_) {
1063       entry.second->ClearError();
1064     }
1065   }
1066   const auto* config = pflr_->config();
1067   ResetPFLR(local_device_manager_.Get(), env_, config, TF_GRAPH_DEF_VERSION,
1068             &func_lib_def_, config->graph_options().optimizer_options(),
1069             thread_pool_.get(), cluster_flr_.Get(), custom_kernel_creator_);
1070 
1071   keep_alive_secs_ = keep_alive_secs;
1072   sleep_for_secs_ = std::max(1, keep_alive_secs_ / 2);
1073   // Only schedule a single closure.
1074   if (keep_alive_thread_ == nullptr) {
1075     keep_alive_thread_.reset(
1076         env_->StartThread({}, "EagerKeepAliveThread", [this]() {
1077           while (true) {
1078             {
1079               {
1080                 mutex_lock l(keep_alive_thread_shutdown_mu_);
1081 
1082                 if (shutting_down_) {
1083                   return;
1084                 }
1085 
1086                 keep_alive_thread_cv_.wait_for(
1087                     l, std::chrono::seconds(sleep_for_secs_));
1088 
1089                 if (shutting_down_) {
1090                   return;
1091                 }
1092               }
1093               {
1094                 mutex_lock l(remote_state_mu_);
1095                 if (keep_alive_secs_ > 0) {
1096                   {
1097                     for (const auto& worker : remote_contexts_) {
1098                       core::RefCountPtr<eager::EagerClient> client;
1099                       Status s =
1100                           remote_eager_workers_->GetClient(worker, &client);
1101 
1102                       if (!s.ok()) {
1103                         LOG(WARNING) << "Keep-alive thread was unable to find "
1104                                         "a client for target "
1105                                      << worker << ". Got error: " << s;
1106                         continue;
1107                       }
1108 
1109                       eager::KeepAliveRequest* request =
1110                           new eager::KeepAliveRequest;
1111                       eager::KeepAliveResponse* response =
1112                           new eager::KeepAliveResponse;
1113 
1114                       request->set_context_id(context_id_);
1115                       client->KeepAliveAsync(
1116                           request, response,
1117                           [request, response](const Status& s) {
1118                             delete request;
1119                             delete response;
1120                           });
1121                     }
1122                   }
1123                 }
1124               }
1125             }
1126           }
1127         }));
1128   }
1129   return Status::OK();
1130 }
1131 
1132 Status EagerContext::InitializeRemoteWorker(
1133     std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
1134     DynamicDeviceMgr* remote_device_mgr,
1135     const std::vector<string>& remote_contexts, uint64 context_id,
1136     uint64 context_view_id,
1137     std::function<Rendezvous*(const int64)> rendezvous_creator,
1138     DistributedFunctionLibraryRuntime* cluster_flr,
1139     std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
1140         remote_mgr,
1141     std::function<void()> resource_deallocator) {
1142   if (context_id == kInvalidContextId) {
1143     return errors::InvalidArgument(
1144         "Failed to initialize remote for worker context due to invalid ",
1145         "context id");
1146   }
1147   mutex_lock l(remote_state_mu_);
1148 
1149   if (remote_device_manager_.Owned() || server_ != nullptr ||
1150       keep_alive_thread_ != nullptr) {
1151     return errors::FailedPrecondition(
1152         "EagerContext::InitializeRemoteWorker Failed. ",
1153         "Already initialized remote as a master context.");
1154   }
1155   is_master_ = false;
1156 
1157   remote_contexts_ = remote_contexts;
1158   context_id_ = context_id;
1159   context_view_id_ = context_view_id;
1160 
1161   rendezvous_creator_ = std::move(rendezvous_creator);
1162   remote_eager_workers_ = std::move(remote_eager_workers);
1163   remote_mgr_ = std::move(remote_mgr);
1164   ResetClusterFLR(cluster_flr);
1165 
1166   remote_device_manager_.Reset(remote_device_mgr);
1167 
1168   const auto* config = pflr_->config();
1169   ResetPFLR(local_device_manager_.Get(), env_, config, TF_GRAPH_DEF_VERSION,
1170             &func_lib_def_, config->graph_options().optimizer_options(),
1171             thread_pool_.get(), cluster_flr_.Get(), custom_kernel_creator_);
1172   InitPrioritizedDeviceTypeList();
1173 
1174   ClearCachesAndThreadExecutors();
1175   default_executor_.ClearError();
1176   {
1177     tensorflow::mutex_lock l(executor_map_mu_);
1178     for (auto& entry : thread_local_executor_) {
1179       entry.second->ClearError();
1180     }
1181   }
1182 
1183   resource_deallocator_ = std::move(resource_deallocator);
1184 
1185   return Status::OK();
1186 }
1187 
1188 Status EagerContext::UpdateRemoteWorker(
1189     const DeviceMgr* worker_session_device_mgr,
1190     std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
1191     DynamicDeviceMgr* remote_device_mgr,
1192     const std::vector<string>& remote_contexts, uint64 context_id,
1193     DistributedFunctionLibraryRuntime* cluster_flr) {
1194   {
1195     mutex_lock l(remote_state_mu_);
1196     if (context_id != context_id_) {
1197       return errors::InvalidArgument(
1198           "Failed to update remote for worker context due to invalid ",
1199           "context id. Request id = ", context_id,
1200           " but current id = ", context_id_);
1201     }
1202     context_view_id_++;
1203   }
1204 
1205   remote_contexts_ = remote_contexts;
1206 
1207   remote_eager_workers_ = std::move(remote_eager_workers);
1208   ResetClusterFLR(cluster_flr);
1209 
1210   remote_device_manager_.Reset(remote_device_mgr);
1211   InitPrioritizedDeviceTypeList();
1212 
1213   ClearCachesAndThreadExecutors();
1214   default_executor_.ClearError();
1215   {
1216     tensorflow::mutex_lock l(executor_map_mu_);
1217     for (auto& entry : thread_local_executor_) {
1218       entry.second->ClearError();
1219     }
1220   }
1221 
1222   SessionOptions options = SessionOptions();
1223   const auto* config = pflr_->config();
1224   ResetPFLR(worker_session_device_mgr, options.env, config,
1225             TF_GRAPH_DEF_VERSION, FuncLibDef(),
1226             config->graph_options().optimizer_options(), thread_pool_.get(),
1227             cluster_flr_.Get(), custom_kernel_creator_);
1228   return Status::OK();
1229 }
1230 #endif  // !IS_MOBILE_PLATFORM
1231 
1232 }  // namespace tensorflow
1233