• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/eager/context.h"
17 
18 #include <memory>
19 #include <vector>
20 
21 // clang-format off
22 // Required for IS_MOBILE_PLATFORM
23 #include "tensorflow/c/eager/immediate_execution_context.h"
24 #include "tensorflow/core/common_runtime/function.h"
25 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
26 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
27 #include "tensorflow/core/lib/core/refcount.h"
28 #include "tensorflow/core/lib/gtl/map_util.h"
29 #include "tensorflow/core/nccl/collective_communicator.h"
30 #include "tensorflow/core/platform/errors.h"
31 #include "tensorflow/core/platform/mutex.h"
32 #include "tensorflow/core/platform/platform.h"
33 // clang-format on
34 
35 #include "tensorflow/c/tf_tensor.h"
36 #include "tensorflow/c/tf_tensor_internal.h"
37 #include "tensorflow/core/common_runtime/collective_executor_mgr.h"
38 #include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
39 #include "tensorflow/core/common_runtime/colocation_graph.h"
40 #include "tensorflow/core/common_runtime/device_resolver_local.h"
41 #include "tensorflow/core/common_runtime/device_set.h"
42 #include "tensorflow/core/common_runtime/process_util.h"
43 #include "tensorflow/core/framework/function.h"
44 #include "tensorflow/core/framework/graph_def_util.h"
45 #include "tensorflow/core/lib/core/errors.h"
46 #include "tensorflow/core/protobuf/config.pb.h"
47 #include "tensorflow/core/public/version.h"
48 #include "tensorflow/core/util/device_name_utils.h"
49 #if !defined(IS_MOBILE_PLATFORM)
50 #include "tensorflow/core/distributed_runtime/cluster_function_library_runtime.h"
51 #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
52 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
53 #include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
54 #endif  // !IS_MOBILE_PLATFORM
55 #include "tensorflow/core/framework/resource_mgr.h"
56 #include "tensorflow/core/lib/core/blocking_counter.h"
57 #include "tensorflow/core/lib/monitoring/gauge.h"
58 #include "tensorflow/core/util/env_var.h"
59 
60 namespace tensorflow {
61 namespace {
62 
ReadBoolFromEnvVar(StringPiece env_var_name,bool default_val)63 bool ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val) {
64   bool val;
65   if (tensorflow::ReadBoolFromEnvVar(env_var_name, default_val, &val).ok()) {
66     return val;
67   }
68   return default_val;
69 }
70 
71 auto* eager_context_created =
72     monitoring::Gauge<bool, 0>::New("/tensorflow/core/eager_context_created",
73                                     "True if an eager context was created.");
74 
75 }  // namespace
76 
// Builds an eager context configured from `opts`.
//
// `device_mgr` must be non-null: its HostCPU() is read in the initializer
// list. It is handed to local_device_manager_ together with
// `device_mgr_owned` (presumably: whether this context takes ownership).
// `rendezvous` is stored in rendezvous_, which the destructor Unref()s when
// non-null. `collective_executor_mgr` is stored unowned; when it is null, a
// manager from CreateProdLocalCollectiveExecutorMgr is installed instead.
EagerContext::EagerContext(
    const SessionOptions& opts,
    ContextDevicePlacementPolicy default_device_placement_policy, bool async,
    DeviceMgr* device_mgr, bool device_mgr_owned, Rendezvous* rendezvous,
    DistributedFunctionLibraryRuntime* cluster_flr,
    CollectiveExecutorMgrInterface* collective_executor_mgr,
    bool run_eager_op_as_function)
    : ImmediateExecutionContext(kEager),
      opts_(opts),
      default_device_placement_policy_(default_device_placement_policy),
      local_device_manager_(device_mgr, device_mgr_owned),
      host_cpu_device_(device_mgr->HostCPU()),
      rendezvous_(rendezvous),
      thread_pool_(NewThreadPoolFromSessionOptions(opts)),
      cluster_flr_(cluster_flr),
      log_device_placement_(opts.config.log_device_placement()),
      allow_soft_placement_(opts.config.allow_soft_placement()),
      num_active_steps_(0),
      // The step container's cleanup callback clears the named resource
      // container on this context.
      step_container_(std::make_unique<ScopedStepContainer>(
          0, [this](const string& name) { ClearResourceContainer(name); })),
      default_executor_(async),
      log_memory_(LogMemory::IsEnabled()),
      env_(opts.env),
      // Held unowned; replaced in the body below when null.
      collective_executor_mgr_(collective_executor_mgr, /*owned=*/false),
      use_send_tensor_rpc_(false),
      // Opt-in via environment variable; defaults to false.
      pin_small_ops_to_cpu_(ReadBoolFromEnvVar(
          "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING", false)),
      run_eager_op_as_function_(run_eager_op_as_function) {
  // Build the process function library runtime over the local devices, using
  // this context's function library and thread pool.
  ResetPFLR(device_mgr, opts.env, &opts.config, TF_GRAPH_DEF_VERSION,
            &func_lib_def_, opts.config.graph_options().optimizer_options(),
            thread_pool_.get(), cluster_flr);
  // Starts exporting metrics through a platform-specific monitoring API (if
  // provided). For builds using "tensorflow/core/platform/default", this is
  // currently a no-op.
  eager_context_created->GetCell()->Set(true);
  InitPrioritizedDeviceTypeList();
  // Closures handed to runner_ are scheduled on this context's thread pool.
  runner_ = [this](std::function<void()> closure) {
    this->thread_pool_->Schedule(std::move(closure));
  };

  run_metadata_ = std::make_unique<RunMetadata>();

#if !defined(IS_MOBILE_PLATFORM)
  // No remote context exists yet.
  context_id_ = kInvalidContextId;
  context_view_id_ = 0;
#endif  // !IS_MOBILE_PLATFORM

  // TODO(yuefengz): consider creating a new RpcCollectiveExecutorMgr each
  // time.
  if (collective_executor_mgr_.Get() == nullptr) {
    collective_executor_mgr_.Reset(CreateProdLocalCollectiveExecutorMgr(
        opts.config, local_device_mgr(),
        MaybeCreateNcclCommunicator(opts.config)));
  }
  // Rendezvous shared by functions, created with step id -1.
  global_rendezvous_for_functions_ =
      core::RefCountPtr<Rendezvous>(CreateRendezvous(-1));
}
134 
CreateInt64Scalar(int64_t value)135 AbstractTensorInterface* EagerContext::CreateInt64Scalar(int64_t value) {
136   return new TensorInterface(Tensor(value));
137 }
138 
CreateUint64Scalar(uint64 value)139 AbstractTensorInterface* EagerContext::CreateUint64Scalar(uint64 value) {
140   return new TensorInterface(Tensor(value));
141 }
142 
CreateInt32Scalar(int32_t value)143 AbstractTensorInterface* EagerContext::CreateInt32Scalar(int32_t value) {
144   return new TensorInterface(Tensor(value));
145 }
146 
CreateFloatScalar(float value)147 AbstractTensorInterface* EagerContext::CreateFloatScalar(float value) {
148   return new TensorInterface(Tensor(value));
149 }
150 
CreateDoubleScalar(double value)151 AbstractTensorInterface* EagerContext::CreateDoubleScalar(double value) {
152   return new TensorInterface(Tensor(value));
153 }
154 
CreateHalfScalar(Eigen::half value)155 AbstractTensorInterface* EagerContext::CreateHalfScalar(Eigen::half value) {
156   return new TensorInterface(Tensor(value));
157 }
158 
CreateStringScalar(tstring value)159 AbstractTensorInterface* EagerContext::CreateStringScalar(tstring value) {
160   return new TensorInterface(Tensor(value));
161 }
162 
CreateComplex128Scalar(complex128 value)163 AbstractTensorInterface* EagerContext::CreateComplex128Scalar(
164     complex128 value) {
165   return new TensorInterface(Tensor(value));
166 }
167 
CreateBoolScalar(bool value)168 AbstractTensorInterface* EagerContext::CreateBoolScalar(bool value) {
169   return new TensorInterface(Tensor(value));
170 }
171 
CreateTensor(DataType dtype,absl::Span<const int64> dim_sizes)172 AbstractTensorInterface* EagerContext::CreateTensor(
173     DataType dtype, absl::Span<const int64> dim_sizes) {
174   return new TensorInterface(Tensor(dtype, TensorShape(dim_sizes)));
175 }
176 
CreateTensor(DataType dtype,const int64_t * dims,int num_dims,void * data,size_t len,MemoryReleaser memory_releaser,void * memory_releaser_arg)177 AbstractTensorInterface* EagerContext::CreateTensor(
178     DataType dtype, const int64_t* dims, int num_dims, void* data, size_t len,
179     MemoryReleaser memory_releaser, void* memory_releaser_arg) {
180   TF_Tensor* tensor_wrapper =
181       TF_NewTensor(static_cast<TF_DataType>(dtype), dims, num_dims, data, len,
182                    memory_releaser, memory_releaser_arg);
183 
184   AbstractTensorInterface* result = nullptr;
185   std::swap(result, tensor_wrapper->tensor);
186   TF_DeleteTensor(tensor_wrapper);
187   return result;
188 }
189 
ResetPFLR(const DeviceMgr * device_mgr,Env * env,const ConfigProto * config,int graph_def_version,const FunctionLibraryDefinition * lib_def,const OptimizerOptions & optimizer_options,thread::ThreadPool * thread_pool,DistributedFunctionLibraryRuntime * cluster_flr)190 void EagerContext::ResetPFLR(const DeviceMgr* device_mgr, Env* env,
191                              const ConfigProto* config, int graph_def_version,
192                              const FunctionLibraryDefinition* lib_def,
193                              const OptimizerOptions& optimizer_options,
194                              thread::ThreadPool* thread_pool,
195                              DistributedFunctionLibraryRuntime* cluster_flr) {
196   Rendezvous::Factory rendezvous_factory{
197       [this](const int64_t step_id, const DeviceMgr*, Rendezvous** r) {
198         *r = CreateRendezvous(step_id);
199         return Status::OK();
200       }};
201   pflr_.reset(new ProcessFunctionLibraryRuntime(
202       device_mgr, env, config, graph_def_version, lib_def, optimizer_options,
203       thread_pool, cluster_flr,
204       /*session_metadata=*/nullptr, std::move(rendezvous_factory)));
205 }
206 
InitPrioritizedDeviceTypeList()207 void EagerContext::InitPrioritizedDeviceTypeList() {
208   DeviceSet ds;
209   for (Device* d : local_device_mgr()->ListDevices()) {
210     ds.AddDevice(d);
211   }
212   auto remote_device_manager = remote_device_mgr();
213   if (remote_device_manager != nullptr) {
214     for (Device* d : remote_device_manager->ListDevices()) {
215       ds.AddDevice(d);
216     }
217   }
218   mutex_lock l(device_type_list_mu_);
219   prioritized_device_type_list_ =
220       std::make_shared<std::vector<DeviceType>>(ds.PrioritizedDeviceTypeList());
221 }
222 
223 namespace {
224 // Using absl::StrJoin with lambda does not work in tf-lite builds.
225 // TODO(b/148160441): Replace with absl::StrJoin once DeviceBase has operator<<.
DevicesToString(const PrioritizedDeviceVector & devices)226 std::vector<string> DevicesToString(const PrioritizedDeviceVector& devices) {
227   std::vector<string> v;
228   v.reserve(devices.size());
229   for (const auto& p : devices) {
230     v.push_back(p.first->name());
231   }
232   return v;
233 }
234 
DeviceTypesToString(const PrioritizedDeviceTypeVector & types)235 std::vector<string> DeviceTypesToString(
236     const PrioritizedDeviceTypeVector& types) {
237   std::vector<string> v;
238   v.reserve(types.size());
239   for (const auto& p : types) {
240     v.push_back(p.first.type_string());
241   }
242   return v;
243 }
244 
245 // Selects the "best" device that both exists and is supported.
246 //
247 // The `existing` argument specifies the available devices in the system, in
248 // priority order. The `supported` argument specifies the supported device types
249 // and their priorities, lower index types having higher priority.
250 // Currently the type priority defined by the `supported` parameter takes
251 // precedence over system device priorities from `existing`.
252 //
253 // TODO(b/148213212): Allow setting default device in eager context.
SelectBestMatchingDevice(const DeviceNameUtils::ParsedName & pattern,const PrioritizedDeviceVector & existing,const PrioritizedDeviceTypeVector & supported)254 Device* SelectBestMatchingDevice(const DeviceNameUtils::ParsedName& pattern,
255                                  const PrioritizedDeviceVector& existing,
256                                  const PrioritizedDeviceTypeVector& supported) {
257   for (const std::pair<DeviceType, int32>& prioritized_type : supported) {
258     for (const std::pair<Device*, int32>& prioritized_device : existing) {
259       Device* dev = prioritized_device.first;
260       if (DeviceType(dev->attributes().device_type()) ==
261               prioritized_type.first &&
262           DeviceNameUtils::IsCompleteSpecification(pattern,
263                                                    dev->parsed_name())) {
264         return dev;
265       }
266     }
267   }
268   return nullptr;
269 }
270 
271 }  // namespace
272 
SelectDevice(DeviceNameUtils::ParsedName preferred,const NodeDef & ndef,Device ** out) const273 Status EagerContext::SelectDevice(DeviceNameUtils::ParsedName preferred,
274                                   const NodeDef& ndef, Device** out) const {
275   DCHECK(out != nullptr);
276 
277   PrioritizedDeviceTypeVector supported_devs;
278   auto device_type_list = prioritized_device_type_list();
279   TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
280       *device_type_list, ndef, &supported_devs, &HostCPU()->parsed_name()));
281   if (supported_devs.empty()) {
282     return errors::NotFound("Could not find device for node: ",
283                             errors::FormatNodeNameForError(ndef.name()), " = ",
284                             ndef.op(), "[", SummarizeAttrs(ndef), "]",
285                             "\nAll kernels registered for op ", ndef.op(),
286                             ":\n", KernelsRegisteredForOp(ndef.op()));
287   }
288 
289   // Select the first matching registered device from the supported device
290   // list. If nothing matches and soft placement is enabled, pick a suitable
291   // device from the available ones.
292   const auto pflr_device_set = pflr()->device_set();
293   const PrioritizedDeviceVector& existing =
294       pflr_device_set->prioritized_devices();
295   *out = SelectBestMatchingDevice(preferred, existing, supported_devs);
296   if (*out != nullptr) {
297     return Status::OK();
298   }
299 
300   if (AllowSoftPlacement()) {
301     DeviceNameUtils::ParsedName soft_device_name = preferred;
302     soft_device_name.type.clear();
303     soft_device_name.has_type = false;
304     soft_device_name.has_id = false;
305     // TODO(b/148213746): Soft placement logic picks up another task if the
306     // requested does not exist.
307     *out = SelectBestMatchingDevice(soft_device_name, existing, supported_devs);
308     if (*out != nullptr) {
309       return Status::OK();
310     }
311   }
312 
313   if (DeviceNameUtils::HasSomeDetails(preferred)) {
314     return errors::InvalidArgument(
315         "Could not satisfy device specification '", preferred,
316         "'. enable_soft_placement=", AllowSoftPlacement(),
317         ". Supported device types [",
318         absl::StrJoin(DeviceTypesToString(supported_devs), ", "),
319         "]. All available devices [",
320         absl::StrJoin(DevicesToString(existing), ", "), "].");
321   }
322   return errors::InvalidArgument(
323       "No supported device found in available devices [",
324       absl::StrJoin(DevicesToString(existing), ", "),
325       "]. enable_soft_placement=", AllowSoftPlacement(),
326       ". Supported devices types [",
327       absl::StrJoin(DeviceTypesToString(supported_devs), ", "), "].");
328 }
329 
ResetClusterFLR(DistributedFunctionLibraryRuntime * cluster_flr)330 void EagerContext::ResetClusterFLR(
331     DistributedFunctionLibraryRuntime* cluster_flr) {
332   cluster_flr_.Reset(cluster_flr, /*owned=*/true);
333 }
334 
Executor()335 EagerExecutor& EagerContext::Executor() {
336   tf_shared_lock l(executor_map_mu_);
337   return *gtl::FindWithDefault(thread_local_executor_,
338                                std::this_thread::get_id(), &default_executor_);
339 }
340 
SetExecutorForThread(EagerExecutor * executor)341 void EagerContext::SetExecutorForThread(EagerExecutor* executor) {
342   tensorflow::mutex_lock l(executor_map_mu_);
343   if (executor == &default_executor_) {
344     thread_local_executor_.erase(std::this_thread::get_id());
345   } else {
346     auto thread_id = std::this_thread::get_id();
347     thread_local_executor_[thread_id] = executor;
348     auto& executors_with_cleanups = has_cleanup_[thread_id];
349     if (executors_with_cleanups.find(executor) ==
350         executors_with_cleanups.end()) {
351       executors_with_cleanups.insert(executor);
352       // If the executor is deleted before this context, we need to remove it
353       // from the map to avoid attempting to sync it in our destructor.
354       std::function<void()> cleanup([this, thread_id, executor]() {
355         {
356           tensorflow::mutex_lock l(executor_map_mu_);
357           auto existing = thread_local_executor_.find(thread_id);
358           if (existing != thread_local_executor_.end() &&
359               existing->second == executor) {
360             thread_local_executor_.erase(thread_id);
361           }
362           has_cleanup_[thread_id].erase(executor);
363         }
364       });
365       executor->AddCleanup(reinterpret_cast<intptr_t>(this),
366                            std::move(cleanup));
367     }
368   }
369 }
370 
ClearCachesAndThreadExecutors()371 void EagerContext::ClearCachesAndThreadExecutors() {
372   std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
373   {
374     mutex_lock l(executor_map_mu_);
375     executors_copy = thread_local_executor_;
376   }
377   for (const auto& entry : executors_copy) {
378     entry.second->WaitForAllPendingNodes().IgnoreError();
379   }
380   ClearCachesAndDefaultExecutor();
381 }
382 
ClearCachesAndDefaultExecutor()383 void EagerContext::ClearCachesAndDefaultExecutor() {
384   // The executor stores pointers to kernels, so we need to make sure that no
385   // async eager ops are still executing. We lock the cache during this time
386   // as well.
387   mutex_lock ml(cache_mu_);
388   default_executor_.WaitForAllPendingNodes().IgnoreError();
389   kernel_cache_.clear();
390   for (auto& entry : registered_functions_) {
391     entry.second->cached_kernel_keys->clear();
392   }
393   {
394     mutex_lock ml(metadata_mu_);
395     step_container_.reset(new ScopedStepContainer(
396         0, [this](const string& name) { ClearResourceContainer(name); }));
397   }
398 }
399 
SetThreadLocalDevicePlacementPolicy(ContextDevicePlacementPolicy policy)400 void EagerContext::SetThreadLocalDevicePlacementPolicy(
401     ContextDevicePlacementPolicy policy) {
402   mutex_lock ml(policy_map_mu_);
403   device_placement_policy_[std::this_thread::get_id()] = policy;
404 }
405 
GetDevicePlacementPolicy() const406 ContextDevicePlacementPolicy EagerContext::GetDevicePlacementPolicy() const {
407   tf_shared_lock l(policy_map_mu_);
408   auto policy_map_it =
409       device_placement_policy_.find(std::this_thread::get_id());
410   if (policy_map_it != device_placement_policy_.end()) {
411     return policy_map_it->second;
412   }
413   return default_device_placement_policy_;
414 }
415 
416 #if !defined(IS_MOBILE_PLATFORM)
GetRemoteContexts()417 std::vector<string> EagerContext::GetRemoteContexts() {
418   tf_shared_lock l(remote_state_mu_);
419   return remote_contexts_;
420 }
421 
IsRemoteContextsEmpty()422 bool EagerContext::IsRemoteContextsEmpty() {
423   tf_shared_lock l(remote_state_mu_);
424   return remote_contexts_.empty();
425 }
426 
CloseAndClearAllRemoteContexts()427 void EagerContext::CloseAndClearAllRemoteContexts() {
428   uint64 context_id;
429   uint64 context_view_id;
430   std::vector<string> remote_contexts_copy;
431   {
432     mutex_lock l(remote_state_mu_);
433     if (!is_master_) return;
434     context_id = context_id_;
435     context_view_id = context_view_id_;
436     context_id_ = kInvalidContextId;
437     // Forget the current view id and reset to the starting value 0.
438     context_view_id_ = 0;
439 
440     // Make a copy of remote targets to avoid holding the lock when sending
441     // close context requests.
442     remote_contexts_copy = remote_contexts_;
443     remote_contexts_.clear();
444   }
445   CloseRemoteContexts(remote_contexts_copy, context_id, context_view_id);
446 }
447 
CloseRemoteContexts(const std::vector<string> & remote_contexts,uint64 context_id,uint64 context_view_id)448 void EagerContext::CloseRemoteContexts(
449     const std::vector<string>& remote_contexts, uint64 context_id,
450     uint64 context_view_id) {
451   // Close all remote contexts.
452   eager::CloseContextRequest request;
453   request.set_context_id(context_id);
454   request.set_context_view_id(context_view_id);
455   // Setting context_id to a new value can avoid us issuing DestroyTensorHandle
456   // request to closed remote workers.
457   std::vector<eager::CloseContextResponse> responses(remote_contexts.size());
458   BlockingCounter counter(static_cast<int>(remote_contexts.size()));
459 
460   int i = 0;
461   for (const auto& worker : remote_contexts) {
462     core::RefCountPtr<eager::EagerClient> client;
463     Status s = GetClient(worker, &client);
464 
465     client->CloseContextAsync(
466         &request, &responses[i],
467         [&worker, &counter, context_id](const Status& s) {
468           if (!s.ok()) {
469             LOG(ERROR) << "Unable to close remote context with ID "
470                        << context_id << " for worker: " << worker << " due to "
471                        << s.error_message();
472           }
473           counter.DecrementCount();
474         });
475     i++;
476   }
477 
478   counter.Wait();
479 }
480 
481 #endif  // !IS_MOBILE_PLATFORM
482 
// Drains pending work and tears down remote state.
//
// Always calls ClearCachesAndThreadExecutors(). On non-mobile builds it then
// does the following, in order:
//   1. Signals and releases the keep-alive thread.
//   2. Closes any recorded remote contexts.
//   3. Shuts down the default and per-thread executors.
//   4. Drops remote_eager_workers_.
void EagerContext::WaitForAndCloseRemoteContexts() {
  ClearCachesAndThreadExecutors();

#if !defined(IS_MOBILE_PLATFORM)
  {
    // Flag shutdown and wake anything waiting on keep_alive_thread_cv_
    // (presumably the keep-alive thread's loop).
    mutex_lock l(keep_alive_thread_shutdown_mu_);
    shutting_down_ = true;
    keep_alive_thread_cv_.notify_all();
  }
  // NOTE(review): assumes destroying keep_alive_thread_ joins the thread —
  // confirm.
  keep_alive_thread_.reset();

  if (!IsRemoteContextsEmpty()) {
    CloseAndClearAllRemoteContexts();
  }

  {
    mutex_lock l(remote_state_mu_);

    // Shut down the default executor, then every thread-local executor. The
    // map is copied so executor_map_mu_ is held only for the copy.
    default_executor_.ShutDown().IgnoreError();
    std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
    {
      // This `l` shadows the remote_state_mu_ lock above; both mutexes are
      // held inside this scope.
      mutex_lock l(executor_map_mu_);
      executors_copy = thread_local_executor_;
    }
    for (const auto& it : executors_copy) {
      it.second->ShutDown().IgnoreError();
    }

    // This shuts down the completion queue and joins the thread polling it.
    // The thread exits only after the completion queue has been drained of all
    // the events. These events' completion should invoke all remaining RPC
    // callbacks.
    // This also deletes all EagerClient instances. There should not be any
    // references to EagerClients left after all RPCs and async ops have been
    // finished.
    remote_eager_workers_ = nullptr;
  }
#endif  // !IS_MOBILE_PLATFORM
}
522 
// Tears down the context. This can block: it drains executors and waits for
// remote close RPCs (see the TODO below). The order of the steps below is
// deliberate; do not reorder them without care.
EagerContext::~EagerContext() {
  // TODO(iga): Add a separate API method to shutdown EagerContext so that we
  // don't send RPCs and block in destructor.
  WaitForAndCloseRemoteContexts();

  // Custom devices may have obtained references to various context components
  // (executors, thread pool). It's safer to run their destructors early.
  custom_device_op_handler_.Clear();

  // NOTE(review): WaitForAndCloseRemoteContexts() above already called
  // ClearCachesAndThreadExecutors(); this second call is presumably a safety
  // net after custom devices were cleared — confirm before removing.
  ClearCachesAndThreadExecutors();
  std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
  {
    mutex_lock l(executor_map_mu_);
    executors_copy = thread_local_executor_;
  }
  for (const auto& entry : executors_copy) {
    // Let the executor know that its cleanup closure is no longer valid.
    entry.second->RemoveCleanups(reinterpret_cast<intptr_t>(this));
  }
  // Drop every outstanding reference on each registered function. The loop
  // assumes Unref() returns true once the final reference is released
  // (core::RefCounted semantics).
  for (auto& entry : registered_functions_) {
    while (!entry.second->Unref()) {
      // remove all references.
    }
  }
  registered_functions_.clear();

#if !defined(IS_MOBILE_PLATFORM)
  if (server_) {
    // TODO(b/136478427): Fix this.
    // release() gives up ownership without destroying the server, so it is
    // intentionally leaked.
    LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
                    "Servers don't support clean shutdown.";
    server_.release();
  }

  // NOTE(review): the keep-alive shutdown and remote-context close below
  // repeat steps WaitForAndCloseRemoteContexts() already performed. Unlike
  // IsRemoteContextsEmpty(), remote_contexts_ is also read here without
  // holding remote_state_mu_.
  {
    mutex_lock l(keep_alive_thread_shutdown_mu_);
    shutting_down_ = true;
    keep_alive_thread_cv_.notify_all();
  }
  keep_alive_thread_.reset();
  if (!remote_contexts_.empty()) {
    CloseAndClearAllRemoteContexts();
  }
#endif  // !IS_MOBILE_PLATFORM

  // Drop this context's reference on rendezvous_ (initialized from the
  // constructor's `rendezvous` argument).
  if (rendezvous_) {
    rendezvous_->Unref();
  }
  // Invoke the resource deallocator last, if one was set.
  if (resource_deallocator_ != nullptr) {
    resource_deallocator_();
  }
}
575 
FindFunctionByName(const string & name) const576 bool EagerContext::FindFunctionByName(const string& name) const {
577   return func_lib_def_.Find(name) != nullptr;
578 }
579 
FindFunctionOpData(const string & name,const tensorflow::OpRegistrationData ** op_data)580 Status EagerContext::FindFunctionOpData(
581     const string& name, const tensorflow::OpRegistrationData** op_data) {
582   return func_lib_def_.LookUp(name, op_data);
583 }
584 
FindFunctionDef(const string & name) const585 const FunctionDef* EagerContext::FindFunctionDef(const string& name) const {
586   return func_lib_def_.Find(name);
587 }
588 
ExportRunMetadata()589 std::unique_ptr<RunMetadata> EagerContext::ExportRunMetadata() {
590   mutex_lock ml(metadata_mu_);
591   auto result = std::make_unique<RunMetadata>();
592   run_metadata_.swap(result);
593   return result;
594 }
595 
UsesTFRT()596 bool EagerContext::UsesTFRT() { return false; }
597 
RunEagerOpAsFunction() const598 bool EagerContext::RunEagerOpAsFunction() const {
599   return run_eager_op_as_function_;
600 }
601 
ListDevices(std::vector<tensorflow::DeviceAttributes> * devices)602 void EagerContext::ListDevices(
603     std::vector<tensorflow::DeviceAttributes>* devices) {
604   local_device_mgr()->ListDeviceAttributes(devices);
605   if (remote_device_mgr()) {
606     remote_device_mgr()->ListDeviceAttributes(devices);
607   }
608 }
609 
AddDevices(std::vector<std::unique_ptr<Device>> devices)610 Status EagerContext::AddDevices(std::vector<std::unique_ptr<Device>> devices) {
611   std::vector<std::unique_ptr<Device>> local_devices, remote_devices;
612   while (!devices.empty()) {
613     if (devices.front()->IsLocal()) {
614       local_devices.push_back(std::move(devices.front()));
615     } else {
616       remote_devices.push_back(std::move(devices.front()));
617     }
618     devices.erase(devices.begin());
619   }
620   TF_RETURN_IF_ERROR(
621       reinterpret_cast<DynamicDeviceMgr*>(local_device_manager_.Get())
622           ->AddDevices(std::move(local_devices)));
623 
624   if (!remote_devices.empty()) {
625     if (!remote_device_mgr()) {
626       remote_device_manager_.Reset(
627           std::make_unique<tensorflow::DynamicDeviceMgr>());
628     }
629     TF_RETURN_IF_ERROR(
630         reinterpret_cast<DynamicDeviceMgr*>(remote_device_manager_.Get())
631             ->AddDevices(std::move(remote_devices)));
632   }
633 
634   // Add the devices to pflr's device set.
635   pflr_->InitializeDeviceAndFlr();
636   InitPrioritizedDeviceTypeList();
637   return Status::OK();
638 }
639 
StartStep()640 void EagerContext::StartStep() {
641   mutex_lock ml(metadata_mu_);
642   num_active_steps_++;
643 }
644 
EndStep()645 void EagerContext::EndStep() {
646   mutex_lock ml(metadata_mu_);
647   num_active_steps_--;
648   if (num_active_steps_ == 0) {
649     // TODO(b/139809335): This does not properly clean up remote resources
650     // Clean up the previous step container and create a new one.
651     step_container_.reset(new ScopedStepContainer(
652         0, [this](const string& name) { ClearResourceContainer(name); }));
653   }
654 }
655 
StepContainer()656 ScopedStepContainer* EagerContext::StepContainer() {
657   mutex_lock ml(metadata_mu_);
658   return step_container_.get();
659 }
660 
MaybeRegisterFunctionRemotely(const FunctionDef & fdef)661 Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) {
662   // Only client context can register function on remote worker context.
663   if (!remote_device_manager_.Owned()) return Status::OK();
664 #if !defined(IS_MOBILE_PLATFORM)
665   std::shared_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
666   request->set_context_id(GetContextId());
667 
668   eager::RegisterFunctionOp* register_function =
669       request->add_queue()->mutable_register_function();
670   *register_function->mutable_function_def() = fdef;
671   StripDefaultAttributes(
672       *OpRegistry::Global(),
673       register_function->mutable_function_def()->mutable_node_def());
674 
675   auto remote_contexts = GetRemoteContexts();
676   for (const auto& target : remote_contexts) {
677     core::RefCountPtr<eager::EagerClient> eager_client;
678     TF_RETURN_IF_ERROR(GetClient(target, &eager_client));
679 
680     eager::EnqueueResponse* response = new eager::EnqueueResponse();
681     eager_client->StreamingEnqueueAsync(
682         /*call_opts=*/nullptr, request.get(), response,
683         [request, response](const Status& status) {
684           if (!status.ok()) {
685             LOG(ERROR) << "Failed to register function remotely due to "
686                        << status.error_message()
687                        << "\nThis could happen if the remote target has been "
688                           "disconnected from the client.";
689           }
690           delete response;
691         });
692   }
693 #endif  // !IS_MOBILE_PLATFORM
694   return Status::OK();
695 }
696 
RegisterExistingFunctionsOnRemoteWorkers(const std::vector<string> & remote_workers)697 Status EagerContext::RegisterExistingFunctionsOnRemoteWorkers(
698     const std::vector<string>& remote_workers) {
699 #if !defined(IS_MOBILE_PLATFORM)
700   // Register multiple functions on selected remote workers.
701   uint64 context_id = GetContextId();
702   FunctionDefLibrary function_defs = func_lib_def_.ToProto();
703   std::vector<std::shared_ptr<eager::EnqueueRequest>> requests(
704       function_defs.function_size());
705   for (int i = 0; i < function_defs.function_size(); i++) {
706     requests[i] = std::make_shared<eager::EnqueueRequest>();
707     requests[i]->set_context_id(context_id);
708     eager::RegisterFunctionOp* register_function =
709         requests[i]->add_queue()->mutable_register_function();
710     *register_function->mutable_function_def() =
711         std::move(*function_defs.mutable_function(i));
712     StripDefaultAttributes(
713         *OpRegistry::Global(),
714         register_function->mutable_function_def()->mutable_node_def());
715   }
716 
717   for (auto& remote_worker : remote_workers) {
718     core::RefCountPtr<eager::EagerClient> eager_client;
719     Status s = GetClient(remote_worker, &eager_client);
720     if (!s.ok()) {
721       continue;
722     }
723     for (int i = 0; i < requests.size(); i++) {
724       auto response = std::make_shared<eager::EnqueueResponse>();
725       eager_client->StreamingEnqueueAsync(
726           /*call_opts=*/nullptr, requests[i].get(), response.get(),
727           [request = requests[i], response](const Status& s) {
728             if (!s.ok()) {
729               LOG(ERROR) << "Failed to register function remotely due to "
730                          << s.error_message()
731                          << "\nThis could happen if the remote target has been "
732                             "disconnected from the client.";
733             }
734           });
735     }
736   }
737 #endif  // !IS_MOBILE_PLATFORM
738   return Status::OK();
739 }
740 
AddFunctionDefWithStackTraces(const FunctionDef & fdef,const StackTracesMap & stack_traces)741 Status EagerContext::AddFunctionDefWithStackTraces(
742     const FunctionDef& fdef, const StackTracesMap& stack_traces) {
743   return AddFunctionDef(fdef, FunctionDefLibrary(),
744                         /* add_to_local_only=*/false, stack_traces);
745 }
746 
AddFunctionDef(const FunctionDef & fdef)747 Status EagerContext::AddFunctionDef(const FunctionDef& fdef) {
748   return AddFunctionDef(fdef, FunctionDefLibrary(),
749                         /* add_to_local_only=*/false);
750 }
751 
AddFunctionDef(const FunctionDef & fdef,const FunctionDefLibrary & library,const bool add_to_local_only,const StackTracesMap & stack_traces)752 Status EagerContext::AddFunctionDef(const FunctionDef& fdef,
753                                     const FunctionDefLibrary& library,
754                                     const bool add_to_local_only,
755                                     const StackTracesMap& stack_traces) {
756   bool is_first_ref = false;
757   {
758     mutex_lock l(cache_mu_);
759     auto* registered_function =
760         gtl::FindPtrOrNull(registered_functions_, fdef.signature().name());
761     if (registered_function == nullptr) {
762       registered_function = new RegisteredFunction;
763       registered_function->cached_kernel_keys =
764           absl::make_unique<std::vector<Fprint128>>();
765       gtl::InsertOrUpdate(&registered_functions_, fdef.signature().name(),
766                           registered_function);
767     } else {
768       // The function has been registered before. If the function is the same,
769       // then we take a Ref() otherwise we error out.
770       const FunctionDef* prev_fdef =
771           func_lib_def_.Find(fdef.signature().name());
772       if (prev_fdef == nullptr) {
773         return errors::Internal("Function: ", fdef.signature().name(),
774                                 " is in the cache but not in the library");
775       }
776       if (!FunctionDefsEqual(fdef, *prev_fdef)) {
777         return errors::InvalidArgument(
778             "Attempting to add a duplicate function with name: ",
779             fdef.signature().name(), " where the previous and current ",
780             "definitions differ. Previous definition: ",
781             prev_fdef->DebugString(),
782             " and current definition: ", fdef.DebugString());
783       }
784       registered_function->Ref();
785     }
786     is_first_ref = registered_function->RefCountIsOne();
787   }
788   if (is_first_ref) {
789     TF_RETURN_IF_ERROR(func_lib_def_.AddFunctionDef(fdef, stack_traces));
790     TF_RETURN_IF_ERROR(func_lib_def_.AddLibrary(library));
791     if (!add_to_local_only) {
792       return MaybeRegisterFunctionRemotely(fdef);
793     }
794   }
795   return Status::OK();
796 }
797 
GetFunctionDef(const string & function_name)798 const FunctionDef* EagerContext::GetFunctionDef(const string& function_name) {
799   return func_lib_def_.Find(function_name);
800 }
801 
ListFunctionNames()802 std::vector<string> EagerContext::ListFunctionNames() {
803   return func_lib_def_.ListFunctionNames();
804 }
805 
RemoveFunction(const string & func)806 Status EagerContext::RemoveFunction(const string& func) {
807   bool is_last_ref = false;
808   {
809     mutex_lock l(cache_mu_);
810     auto* registered_function = gtl::FindPtrOrNull(registered_functions_, func);
811     if (registered_function == nullptr) {
812       return errors::InvalidArgument("Tried to remove non-existent function '",
813                                      func, "'.");
814     }
815     is_last_ref = registered_function->RefCountIsOne();
816     if (is_last_ref) {
817       for (auto& key : *registered_function->cached_kernel_keys) {
818         kernel_cache_.erase(key);
819       }
820       registered_functions_.erase(func);
821     }
822     registered_function->Unref();
823   }
824   if (is_last_ref) {
825     // TODO(fishx): Remove remote function as well.
826     return func_lib_def_.RemoveFunction(func);
827   }
828   return Status::OK();
829 }
830 
SyncExecutors()831 Status EagerContext::SyncExecutors() {
832   StatusGroup sg;
833   // Synchronize on context default executor
834   sg.Update(default_executor_.WaitForAllPendingNodes());
835   default_executor_.ClearError();
836 
837   // Synchronize thread local executors on client
838   std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
839   {
840     mutex_lock l(executor_map_mu_);
841     executors_copy = thread_local_executor_;
842   }
843   for (const auto& entry : executors_copy) {
844     sg.Update(entry.second->WaitForAllPendingNodes());
845     entry.second->ClearError();
846   }
847 
848 #if !defined(IS_MOBILE_PLATFORM)
849   auto remote_contexts = GetRemoteContexts();
850   // Synchronize executors on remote workers
851   eager::EnqueueRequest request;
852   request.set_context_id(GetContextId());
853   request.add_queue()->mutable_sync_remote_executor_for_stream();
854   BlockingCounter counter(static_cast<int>(remote_contexts.size()));
855   std::vector<Status> statuses(remote_contexts.size());
856 
857   for (int i = 0; i < remote_contexts.size(); i++) {
858     const auto& target = remote_contexts[i];
859     core::RefCountPtr<eager::EagerClient> eager_client;
860     TF_RETURN_IF_ERROR(GetClient(target, &eager_client));
861 
862     eager::EnqueueResponse* response = new eager::EnqueueResponse();
863     eager_client->StreamingEnqueueAsync(
864         /*call_opts=*/nullptr, &request, response,
865         [response, target, &counter, &s = statuses[i]](const Status& status) {
866           s = status;
867           delete response;
868           counter.DecrementCount();
869         });
870   }
871   counter.Wait();
872   for (const Status& s : statuses) {
873     sg.Update(s);
874   }
875 #endif  // !IS_MOBILE_PLATFORM
876   {
877     // Reset the global function rendezvous, which otherwise stores a failure
878     // state.
879     mutex_lock l(global_rendezvous_mu_);
880     global_rendezvous_for_functions_ =
881         core::RefCountPtr<Rendezvous>(CreateRendezvous(-1));
882   }
883   return sg.as_summary_status();
884 }
885 
GetCachedKernel(Fprint128 cache_key)886 core::RefCountPtr<KernelAndDevice> EagerContext::GetCachedKernel(
887     Fprint128 cache_key) {
888   tf_shared_lock l(cache_mu_);
889   auto iter = kernel_cache_.find(cache_key);
890   if (iter == kernel_cache_.end()) {
891     return nullptr;
892   }
893   core::RefCountPtr<KernelAndDevice> new_ref(iter->second.get());
894   new_ref->Ref();
895   return new_ref;
896 }
897 
AddKernelToCache(Fprint128 cache_key,KernelAndDevice * kernel)898 void EagerContext::AddKernelToCache(Fprint128 cache_key,
899                                     KernelAndDevice* kernel) {
900   mutex_lock ml(cache_mu_);
901   core::RefCountPtr<KernelAndDevice> new_ref(kernel);
902   new_ref->Ref();
903   kernel_cache_[cache_key] = std::move(new_ref);
904   auto* registered_function =
905       gtl::FindPtrOrNull(registered_functions_, kernel->name());
906   // The kernel name can be either a primitive op or a function.
907   if (registered_function != nullptr) {
908     registered_function->cached_kernel_keys->emplace_back(cache_key);
909   }
910 }
911 
ShouldStoreGraphs()912 bool EagerContext::ShouldStoreGraphs() { return should_store_graphs_.load(); }
913 
SetShouldStoreGraphs(bool value)914 void EagerContext::SetShouldStoreGraphs(bool value) {
915   mutex_lock ml(metadata_mu_);
916   should_store_graphs_.store(value);
917   if (!value) {
918     run_metadata_.reset(new RunMetadata);
919   }
920 }
921 
FindDeviceFromName(const char * device_name,Device ** device) const922 Status EagerContext::FindDeviceFromName(const char* device_name,
923                                         Device** device) const {
924   *device = HostCPU();
925   if (device_name == nullptr || strlen(device_name) == 0) {
926     return Status::OK();
927   }
928 
929   auto status = local_device_mgr()->LookupDevice(device_name, device);
930   if (status.ok()) {
931     return status;
932   }
933 
934   if (remote_device_mgr() != nullptr) {
935     return remote_device_mgr()->LookupDevice(device_name, device);
936   }
937 
938   return status;
939 }
940 
FindCompositeDeviceFromName(StringPiece device_name,CompositeDevice ** device) const941 Status EagerContext::FindCompositeDeviceFromName(
942     StringPiece device_name, CompositeDevice** device) const {
943   tf_shared_lock l(composite_devices_mu_);
944   for (const auto& d : composite_devices_) {
945     if (d.second->name() == device_name) {
946       *device = d.second.get();
947       return Status::OK();
948     }
949   }
950   return errors::NotFound("Unknown composite device: ", device_name);
951 }
952 
RegisterCustomDevice(const string & device_name,std::unique_ptr<CustomDevice> device)953 Status EagerContext::RegisterCustomDevice(
954     const string& device_name, std::unique_ptr<CustomDevice> device) {
955   Device* existing_physical_device = nullptr;
956   if (FindDeviceFromName(device_name.c_str(), &existing_physical_device).ok()) {
957     return errors::AlreadyExists(device_name,
958                                  " already registered as a physical device.");
959   }
960   return custom_device_op_handler_.RegisterCustomDevice(device_name,
961                                                         std::move(device));
962 }
963 
FindOrCreateCompositeDevice(const std::vector<string> & underlying_devices,const string & device_name,CompositeDevice ** composite_device)964 Status EagerContext::FindOrCreateCompositeDevice(
965     const std::vector<string>& underlying_devices, const string& device_name,
966     CompositeDevice** composite_device) {
967   if (!device_name.empty() &&
968       FindCompositeDeviceFromName(device_name, composite_device).ok()) {
969     return Status::OK();
970   }
971 
972   const uint64 hash_key = Fingerprint64(absl::StrJoin(underlying_devices, ","));
973 
974   mutex_lock l(composite_devices_mu_);
975   auto iter = composite_devices_.find(hash_key);
976   if (iter != composite_devices_.end()) {
977     *composite_device = iter->second.get();
978     return Status::OK();
979   }
980 
981   Status s;
982   std::unique_ptr<CompositeDevice> device;
983   if (device_name.empty()) {
984     // Create a CompositeDevice on the same task as the host CPU, in order to
985     // trigger packed TensorHandle copy from a client to a remote worker.
986     device = CompositeDevice::MakeDevice(underlying_devices,
987                                          composite_devices_.size(),
988                                          HostCPU()->parsed_name(), &s);
989   } else {
990     device = CompositeDevice::MakeDevice(underlying_devices, device_name, &s);
991   }
992   TF_RETURN_IF_ERROR(s);
993   *composite_device = device.get();
994   pflr_->AddCompositeDevice(*composite_device);
995   composite_devices_.emplace(hash_key, std::move(device));
996   return Status::OK();
997 }
998 
OnSameTask(const Device * first,const Device * second) const999 bool EagerContext::OnSameTask(const Device* first, const Device* second) const {
1000   if (first == nullptr) first = HostCPU();
1001   if (second == nullptr) second = HostCPU();
1002   return first->parsed_name().job == second->parsed_name().job &&
1003          first->parsed_name().replica == second->parsed_name().replica &&
1004          first->parsed_name().task == second->parsed_name().task;
1005 }
1006 
1007 // Gets the CPU device on the task of device.
CPUDeviceOnTask(const Device * device,Device ** cpu_device) const1008 Status EagerContext::CPUDeviceOnTask(const Device* device,
1009                                      Device** cpu_device) const {
1010   string cpu_device_name;
1011   TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
1012       device->name(), &cpu_device_name));
1013 
1014   return FindDeviceFromName(cpu_device_name.c_str(), cpu_device);
1015 }
1016 
ClearResourceContainer(const string & name)1017 void EagerContext::ClearResourceContainer(const string& name) {
1018   // TODO(b/139809335): This does not properly clean up remote resources
1019   auto local_devices = local_device_mgr()->ListDevices();
1020   for (Device* device : local_devices) {
1021     // Only ignore container not found errors.
1022     device->resource_manager()->Cleanup(name).IgnoreError();
1023   }
1024 }
1025 
1026 namespace {
GetTaskName(Device * d,string * task_name)1027 Status GetTaskName(Device* d, string* task_name) {
1028   string ignored;
1029   if (!DeviceNameUtils::SplitDeviceName(d->name(), task_name, &ignored)) {
1030     return errors::InvalidArgument("Unable to parse device name: ", d->name());
1031   }
1032 
1033   return Status::OK();
1034 }
1035 }  // namespace
1036 
1037 #if !defined(IS_MOBILE_PLATFORM)
GetClient(Device * device,core::RefCountPtr<eager::EagerClient> * client)1038 Status EagerContext::GetClient(Device* device,
1039                                core::RefCountPtr<eager::EagerClient>* client) {
1040   return GetClient(device->parsed_name(), client);
1041 }
1042 
GetClient(const DeviceNameUtils::ParsedName & device_name,core::RefCountPtr<eager::EagerClient> * client)1043 Status EagerContext::GetClient(const DeviceNameUtils::ParsedName& device_name,
1044                                core::RefCountPtr<eager::EagerClient>* client) {
1045   string device_task_name;
1046   if (!DeviceNameUtils::GetTaskName(device_name, &device_task_name)) {
1047     return errors::InvalidArgument(
1048         "Task is not fully specified in device name: ",
1049         DeviceNameUtils::ParsedNameToString(device_name));
1050   }
1051 
1052   {
1053     tf_shared_lock l(remote_state_mu_);
1054     if (remote_eager_workers_ == nullptr) {
1055       return errors::Internal(
1056           "Haven't set up remote eager worker in this eager context yet.");
1057     }
1058     TF_RETURN_IF_ERROR(
1059         remote_eager_workers_->GetClient(device_task_name, client));
1060 
1061     if (*client == nullptr) {
1062       return errors::InvalidArgument(
1063           "Unable to find eager client corresponding to device ",
1064           DeviceNameUtils::ParsedNameToString(device_name));
1065     }
1066     if (std::find(remote_contexts_.begin(), remote_contexts_.end(),
1067                   device_task_name) == remote_contexts_.end()) {
1068       return errors::Internal("Unable to find a context for handle on task: ",
1069                               device_task_name, ". This should not happen.");
1070     }
1071   }
1072 
1073   return Status::OK();
1074 }
1075 
GetClient(const string & remote_task,core::RefCountPtr<eager::EagerClient> * client)1076 Status EagerContext::GetClient(const string& remote_task,
1077                                core::RefCountPtr<eager::EagerClient>* client) {
1078   {
1079     tf_shared_lock l(remote_state_mu_);
1080     if (remote_eager_workers_ == nullptr) {
1081       return errors::Internal(
1082           "Haven't set up remote eager worker in this eager context yet.");
1083     }
1084     TF_RETURN_IF_ERROR(remote_eager_workers_->GetClient(remote_task, client));
1085   }
1086 
1087   if (*client == nullptr) {
1088     return errors::InvalidArgument(
1089         "Unable to find eager client corresponding to target ", remote_task);
1090   }
1091   return Status::OK();
1092 }
1093 
GetContextId() const1094 uint64 EagerContext::GetContextId() const {
1095   tf_shared_lock l(remote_state_mu_);
1096   return context_id_;
1097 }
1098 
GetContextViewId() const1099 uint64 EagerContext::GetContextViewId() const {
1100   tf_shared_lock l(remote_state_mu_);
1101   return context_view_id_;
1102 }
1103 
IncrementContextViewId()1104 void EagerContext::IncrementContextViewId() {
1105   mutex_lock l(remote_state_mu_);
1106   context_view_id_ += 1;
1107 }
1108 
EnableCollectiveOps(const ServerDef & server_def)1109 Status EagerContext::EnableCollectiveOps(const ServerDef& server_def) {
1110   return distributed_manager_->EnableCollectiveOps(server_def);
1111 }
1112 
1113 // Set collective ops related state in the context. Passing nullptr to
1114 // `new_server` will reuse the existing GRPC server in context.
StoreCollectiveOpsServer(std::unique_ptr<ServerInterface> new_server,DeviceMgr * device_mgr,CollectiveExecutorMgrInterface * rpc_collective_executor_mgr)1115 Status EagerContext::StoreCollectiveOpsServer(
1116     std::unique_ptr<ServerInterface> new_server, DeviceMgr* device_mgr,
1117     CollectiveExecutorMgrInterface* rpc_collective_executor_mgr) {
1118   collective_executor_mgr_.Reset(rpc_collective_executor_mgr);
1119 
1120   if (device_mgr != local_device_manager_.Get()) {
1121     if (local_device_manager_.Owned()) {
1122       old_local_device_managers_.push_back(
1123           std::move(local_device_manager_.owned_object));
1124     }
1125     local_device_manager_.Reset(device_mgr);
1126     if (rendezvous_ != nullptr) rendezvous_->Unref();
1127     rendezvous_ = CreateRendezvous(-1);
1128   }
1129   host_cpu_device_ = local_device_manager_.Get()->HostCPU();
1130 
1131   InitPrioritizedDeviceTypeList();
1132   ClearCachesAndThreadExecutors();
1133   default_executor_.ClearError();
1134   {
1135     tensorflow::mutex_lock l(executor_map_mu_);
1136     for (auto& entry : thread_local_executor_) {
1137       entry.second->ClearError();
1138     }
1139   }
1140 
1141   const ConfigProto* config = pflr_ ? pflr_->config() : nullptr;
1142   ResetPFLR(
1143       local_device_manager_.Get(), env_, /*config=*/config,
1144       TF_GRAPH_DEF_VERSION, &func_lib_def_,
1145       /*optimizer_options=*/
1146       config ? config->graph_options().optimizer_options() : OptimizerOptions(),
1147       thread_pool_.get());
1148 
1149   if (new_server != nullptr) {
1150     // Memory leak!
1151     if (server_ != nullptr) {
1152       LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
1153                       "Servers don't support clean shutdown.";
1154       server_.release();
1155     }
1156     server_ = std::move(new_server);
1157   }
1158   DCHECK(server_ != nullptr);
1159 
1160   return Status::OK();
1161 }
1162 
SetRemoteDeviceFilters(const string & remote_worker,const std::vector<string> & device_filters)1163 Status EagerContext::SetRemoteDeviceFilters(
1164     const string& remote_worker, const std::vector<string>& device_filters) {
1165   // Get fully specified task name for remote worker
1166   string remote_worker_task_name;
1167   DeviceNameUtils::ParsedName pw;
1168   if (!DeviceNameUtils::ParseFullName(remote_worker, &pw)) {
1169     return tensorflow::errors::InvalidArgument(
1170         "Remote worker task name is invalid ", remote_worker);
1171   }
1172   // Force set a replica as the key in cluster device filters map. I.e., if the
1173   // remote worker is `/job:worker/task:0` it then becomes
1174   // `/job:worker/replica:0/task:0`.
1175   pw.has_replica = true;
1176   if (!DeviceNameUtils::GetTaskName(pw, &remote_worker_task_name)) {
1177     return tensorflow::errors::InvalidArgument(
1178         "Job name and task index must be specified for worker ", remote_worker);
1179   }
1180 
1181   std::vector<DeviceNameUtils::ParsedName> parsed_filters;
1182   for (auto& filter : device_filters) {
1183     DeviceNameUtils::ParsedName parsed_filter;
1184     if (DeviceNameUtils::ParseFullName(filter, &parsed_filter)) {
1185       parsed_filters.emplace_back(parsed_filter);
1186     } else {
1187       return tensorflow::errors::InvalidArgument("Invalid filter: ", filter);
1188     }
1189   }
1190 
1191   if (VLOG_IS_ON(1)) {
1192     VLOG(1) << "Setting device filters for " << remote_worker << ":";
1193     for (auto& filter : device_filters) {
1194       VLOG(1) << "  " << filter;
1195     }
1196   }
1197   mutex_lock l(remote_state_mu_);
1198   cluster_device_filters_.emplace(remote_worker_task_name, parsed_filters);
1199   return Status::OK();
1200 }
1201 
FilterDevicesForRemoteWorkers(const string & remote_worker,const protobuf::RepeatedPtrField<DeviceAttributes> & device_attrs,std::vector<bool> * filtered_device_mask)1202 void EagerContext::FilterDevicesForRemoteWorkers(
1203     const string& remote_worker,
1204     const protobuf::RepeatedPtrField<DeviceAttributes>& device_attrs,
1205     std::vector<bool>* filtered_device_mask) {
1206   filtered_device_mask->resize(device_attrs.size());
1207   std::fill(filtered_device_mask->begin(), filtered_device_mask->end(), false);
1208 
1209   tf_shared_lock l(remote_state_mu_);
1210   auto it = cluster_device_filters_.find(remote_worker);
1211   // If no filters were specified, all devices should be visible to the worker
1212   if (it == cluster_device_filters_.end() || it->second.empty()) {
1213     std::fill(filtered_device_mask->begin(), filtered_device_mask->end(), true);
1214     return;
1215   }
1216 
1217   const std::vector<DeviceNameUtils::ParsedName>& parsed_filters = it->second;
1218   DeviceNameUtils::ParsedName parsed_remote_worker;
1219   DeviceNameUtils::ParseFullName(remote_worker, &parsed_remote_worker);
1220   for (int i = 0; i < device_attrs.size(); i++) {
1221     DeviceNameUtils::ParsedName pn;
1222     DeviceNameUtils::ParseFullName(device_attrs[i].name(), &pn);
1223     if (DeviceNameUtils::IsSameAddressSpace(parsed_remote_worker, pn)) {
1224       // If this device is on the remote worker itself, it should be visible
1225       // regardless of device filters
1226       filtered_device_mask->at(i) = true;
1227       continue;
1228     }
1229     for (const auto& pf : parsed_filters) {
1230       if ((!pn.has_job || !pf.has_job || pn.job == pf.job) &&
1231           (!pn.has_replica || !pf.has_replica || pn.replica == pf.replica) &&
1232           (!pn.has_task || !pf.has_task || pn.task == pf.task) &&
1233           (!pn.has_type || !pf.has_type || pn.type == pf.type) &&
1234           (!pn.has_id || !pf.has_id || pn.id == pf.id)) {
1235         // Found a match, make it visible, stop processing more device filters
1236         filtered_device_mask->at(i) = true;
1237         break;
1238       }
1239     }
1240   }
1241 }
1242 
SetWorkerEnv(WorkerEnv * worker_env,std::shared_ptr<WorkerSession> worker_session)1243 void EagerContext::SetWorkerEnv(WorkerEnv* worker_env,
1244                                 std::shared_ptr<WorkerSession> worker_session) {
1245   worker_env_ = worker_env;
1246   worker_session_ = worker_session;
1247 }
1248 
InitializeRemoteMaster(std::unique_ptr<ServerInterface> server,WorkerEnv * worker_env,std::shared_ptr<WorkerSession> worker_session,std::unique_ptr<eager::EagerClientCache> remote_eager_workers,std::unique_ptr<DynamicDeviceMgr> remote_device_manager,const std::vector<string> & remote_contexts,uint64 context_id,Rendezvous * r,DeviceMgr * local_device_mgr,int keep_alive_secs,DistributedFunctionLibraryRuntime * cluster_flr,std::unique_ptr<eager::RemoteMgr,std::function<void (eager::RemoteMgr *)>> remote_mgr)1249 Status EagerContext::InitializeRemoteMaster(
1250     std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
1251     std::shared_ptr<WorkerSession> worker_session,
1252     std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
1253     std::unique_ptr<DynamicDeviceMgr> remote_device_manager,
1254     const std::vector<string>& remote_contexts, uint64 context_id,
1255     Rendezvous* r, DeviceMgr* local_device_mgr, int keep_alive_secs,
1256     DistributedFunctionLibraryRuntime* cluster_flr,
1257     std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
1258         remote_mgr) {
1259   if (context_id == kInvalidContextId) {
1260     return errors::InvalidArgument(
1261         "Failed to initialize remote for master context due to invalid ",
1262         "context id");
1263   }
1264 
1265   if (!IsRemoteContextsEmpty()) {
1266     CloseAndClearAllRemoteContexts();
1267   }
1268   {
1269     mutex_lock l(remote_state_mu_);
1270     remote_contexts_ = remote_contexts;
1271   }
1272 
1273   return SetMasterContextState(
1274       std::move(server), worker_env, std::move(worker_session),
1275       std::move(remote_eager_workers), std::move(remote_device_manager),
1276       context_id, 0, r, local_device_mgr, keep_alive_secs, cluster_flr,
1277       std::move(remote_mgr));
1278 }
1279 
UpdateRemoteMaster(uint64 context_id,std::unique_ptr<eager::EagerClientCache> remote_eager_workers,const std::vector<string> & add_remote_contexts,const std::vector<string> & remove_remote_contexts)1280 Status EagerContext::UpdateRemoteMaster(
1281     uint64 context_id,
1282     std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
1283     const std::vector<string>& add_remote_contexts,
1284     const std::vector<string>& remove_remote_contexts) {
1285   {
1286     tf_shared_lock l(remote_state_mu_);
1287     if (context_id != context_id_) {
1288       return errors::InvalidArgument(
1289           "Failed to update remote master context due to invalid context id. ",
1290           "Request id = ", context_id, " but current id = ", context_id_);
1291     }
1292   }
1293 
1294   if (!remove_remote_contexts.empty()) {
1295     // N.B. remove_remote_contexts include both removed and replaced workers.
1296     // In the case where a worker is replaced by one that resolves to the same
1297     // `hostname:port`, it is safe to close context with the current view id,
1298     // since the newly created context on the remote worker will be holding
1299     // a larger view id and ignores this request.
1300     CloseRemoteContexts(remove_remote_contexts, context_id, GetContextViewId());
1301     mutex_lock l(remote_state_mu_);
1302     for (const string& remote_context : remove_remote_contexts) {
1303       remote_contexts_.erase(
1304           std::remove(remote_contexts_.begin(), remote_contexts_.end(),
1305                       remote_context),
1306           remote_contexts_.end());
1307     }
1308   }
1309   if (!add_remote_contexts.empty()) {
1310     mutex_lock l(remote_state_mu_);
1311     remote_contexts_.insert(std::end(remote_contexts_),
1312                             std::begin(add_remote_contexts),
1313                             std::end(add_remote_contexts));
1314   }
1315 
1316   {
1317     mutex_lock l(remote_state_mu_);
1318     context_view_id_++;
1319 
1320     remote_eager_workers_ = std::move(remote_eager_workers);
1321     pflr_->InitializeDeviceAndFlr();
1322     InitPrioritizedDeviceTypeList();
1323 
1324     default_executor_.ClearError();
1325     {
1326       tensorflow::mutex_lock l(executor_map_mu_);
1327       for (auto& entry : thread_local_executor_) {
1328         entry.second->ClearError();
1329       }
1330     }
1331   }
1332 
1333   // Register existing functions to the newly added remote workers. Note that
1334   // this should happen only after updating `remote_contexts_` because new
1335   // functions might be registered while we update the context. When that
1336   // happens, this ordering ensures that `MaybeRegisterFunctionRemotely` will
1337   // register the new functions on all remote workers (including the newly added
1338   // ones), and `RegisterExistingFunctionsOnRemoteWorkers` will take care of
1339   // registering existing functions, where duplicate registrations will be
1340   // ignored by the remote workers.
1341   TF_RETURN_IF_ERROR(
1342       RegisterExistingFunctionsOnRemoteWorkers(add_remote_contexts));
1343   return Status::OK();
1344 }
1345 
1346 // Set distributed execution related state in the master context.
SetMasterContextState(std::unique_ptr<ServerInterface> server,WorkerEnv * worker_env,std::shared_ptr<WorkerSession> worker_session,std::unique_ptr<eager::EagerClientCache> remote_eager_workers,std::unique_ptr<DynamicDeviceMgr> remote_device_manager,uint64 context_id,uint64 context_view_id,Rendezvous * r,DeviceMgr * local_device_mgr,int keep_alive_secs,DistributedFunctionLibraryRuntime * cluster_flr,std::unique_ptr<eager::RemoteMgr,std::function<void (eager::RemoteMgr *)>> remote_mgr)1347 Status EagerContext::SetMasterContextState(
1348     std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
1349     std::shared_ptr<WorkerSession> worker_session,
1350     std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
1351     std::unique_ptr<DynamicDeviceMgr> remote_device_manager, uint64 context_id,
1352     uint64 context_view_id, Rendezvous* r, DeviceMgr* local_device_mgr,
1353     int keep_alive_secs, DistributedFunctionLibraryRuntime* cluster_flr,
1354     std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
1355         remote_mgr) {
1356   mutex_lock l(remote_state_mu_);
1357   is_master_ = true;
1358   context_id_ = context_id;
1359   context_view_id_ = context_view_id;
1360 
1361   use_send_tensor_rpc_ =
1362       ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", true);
1363 
1364   if (local_device_mgr != local_device_manager_.Get()) {
1365     if (local_device_manager_.Owned()) {
1366       old_local_device_managers_.push_back(
1367           std::move(local_device_manager_.owned_object));
1368     }
1369     local_device_manager_.Reset(local_device_mgr);
1370   }
1371   host_cpu_device_ = local_device_manager_.Get()->HostCPU();
1372 
1373   if (rendezvous_ != nullptr) rendezvous_->Unref();
1374   rendezvous_ = r;
1375 
1376   // Memory leak!
1377   if (server_ != nullptr) {
1378     LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
1379                     "Servers don't support clean shutdown.";
1380     server_.release();
1381   }
1382   server_ = std::move(server);
1383 
1384   remote_mgr_ = std::move(remote_mgr);
1385   worker_env_ = worker_env;
1386   worker_session_ = std::move(worker_session);
1387   remote_eager_workers_ = std::move(remote_eager_workers);
1388 
1389   remote_device_manager_.Reset(std::move(remote_device_manager));
1390   ResetClusterFLR(cluster_flr);
1391 
1392   InitPrioritizedDeviceTypeList();
1393 
1394   ClearCachesAndThreadExecutors();
1395   default_executor_.ClearError();
1396   {
1397     tensorflow::mutex_lock l(executor_map_mu_);
1398     for (auto& entry : thread_local_executor_) {
1399       entry.second->ClearError();
1400     }
1401   }
1402   const auto* config = pflr_->config();
1403   ResetPFLR(local_device_manager_.Get(), env_, config, TF_GRAPH_DEF_VERSION,
1404             &func_lib_def_, config->graph_options().optimizer_options(),
1405             thread_pool_.get(), cluster_flr_.Get());
1406 
1407   keep_alive_secs_ = keep_alive_secs;
1408   sleep_for_secs_ = std::max(1, keep_alive_secs_ / 2);
1409   // Only schedule a single closure.
1410   if (keep_alive_thread_ == nullptr) {
1411     keep_alive_thread_.reset(
1412         env_->StartThread({}, "EagerKeepAliveThread", [this]() {
1413           while (true) {
1414             {
1415               {
1416                 mutex_lock l(keep_alive_thread_shutdown_mu_);
1417 
1418                 if (shutting_down_) {
1419                   return;
1420                 }
1421 
1422                 keep_alive_thread_cv_.wait_for(
1423                     l, std::chrono::seconds(sleep_for_secs_));
1424 
1425                 if (shutting_down_) {
1426                   return;
1427                 }
1428               }
1429               {
1430                 mutex_lock l(remote_state_mu_);
1431                 if (keep_alive_secs_ > 0) {
1432                   {
1433                     for (const auto& worker : remote_contexts_) {
1434                       core::RefCountPtr<eager::EagerClient> client;
1435                       Status s =
1436                           remote_eager_workers_->GetClient(worker, &client);
1437 
1438                       if (!s.ok()) {
1439                         LOG(WARNING) << "Keep-alive thread was unable to find "
1440                                         "a client for target "
1441                                      << worker << ". Got error: " << s;
1442                         continue;
1443                       }
1444 
1445                       eager::KeepAliveRequest* request =
1446                           new eager::KeepAliveRequest;
1447                       eager::KeepAliveResponse* response =
1448                           new eager::KeepAliveResponse;
1449 
1450                       request->set_context_id(context_id_);
1451                       client->KeepAliveAsync(
1452                           request, response,
1453                           [request, response](const Status& s) {
1454                             delete request;
1455                             delete response;
1456                           });
1457                     }
1458                   }
1459                 }
1460               }
1461             }
1462           }
1463         }));
1464   }
1465   return Status::OK();
1466 }
1467 
InitializeRemoteWorker(std::unique_ptr<eager::EagerClientCache> remote_eager_workers,DynamicDeviceMgr * remote_device_mgr,const std::vector<string> & remote_contexts,uint64 context_id,uint64 context_view_id,std::function<Rendezvous * (const int64_t)> rendezvous_creator,DistributedFunctionLibraryRuntime * cluster_flr,std::unique_ptr<eager::RemoteMgr,std::function<void (eager::RemoteMgr *)>> remote_mgr,std::function<void ()> resource_deallocator)1468 Status EagerContext::InitializeRemoteWorker(
1469     std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
1470     DynamicDeviceMgr* remote_device_mgr,
1471     const std::vector<string>& remote_contexts, uint64 context_id,
1472     uint64 context_view_id,
1473     std::function<Rendezvous*(const int64_t)> rendezvous_creator,
1474     DistributedFunctionLibraryRuntime* cluster_flr,
1475     std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
1476         remote_mgr,
1477     std::function<void()> resource_deallocator) {
1478   if (context_id == kInvalidContextId) {
1479     return errors::InvalidArgument(
1480         "Failed to initialize remote for worker context due to invalid ",
1481         "context id");
1482   }
1483   mutex_lock l(remote_state_mu_);
1484 
1485   if (remote_device_manager_.Owned() || server_ != nullptr ||
1486       keep_alive_thread_ != nullptr) {
1487     return errors::FailedPrecondition(
1488         "EagerContext::InitializeRemoteWorker Failed. ",
1489         "Already initialized remote as a master context.");
1490   }
1491   is_master_ = false;
1492 
1493   remote_contexts_ = remote_contexts;
1494   context_id_ = context_id;
1495   context_view_id_ = context_view_id;
1496 
1497   rendezvous_creator_ = std::move(rendezvous_creator);
1498   remote_eager_workers_ = std::move(remote_eager_workers);
1499   remote_mgr_ = std::move(remote_mgr);
1500   ResetClusterFLR(cluster_flr);
1501 
1502   remote_device_manager_.Reset(remote_device_mgr);
1503 
1504   const auto* config = pflr_->config();
1505   ResetPFLR(local_device_manager_.Get(), env_, config, TF_GRAPH_DEF_VERSION,
1506             &func_lib_def_, config->graph_options().optimizer_options(),
1507             thread_pool_.get(), cluster_flr_.Get());
1508   InitPrioritizedDeviceTypeList();
1509 
1510   ClearCachesAndThreadExecutors();
1511   default_executor_.ClearError();
1512   {
1513     tensorflow::mutex_lock l(executor_map_mu_);
1514     for (auto& entry : thread_local_executor_) {
1515       entry.second->ClearError();
1516     }
1517   }
1518 
1519   resource_deallocator_ = std::move(resource_deallocator);
1520 
1521   return Status::OK();
1522 }
1523 
UpdateRemoteWorker(std::unique_ptr<eager::EagerClientCache> remote_eager_workers,const std::vector<string> & remote_contexts,uint64 context_id)1524 Status EagerContext::UpdateRemoteWorker(
1525     std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
1526     const std::vector<string>& remote_contexts, uint64 context_id) {
1527   {
1528     mutex_lock l(remote_state_mu_);
1529     if (context_id != context_id_) {
1530       return errors::InvalidArgument(
1531           "Failed to update remote for worker context due to invalid ",
1532           "context id. Request id = ", context_id,
1533           " but current id = ", context_id_);
1534     }
1535     context_view_id_++;
1536 
1537     remote_contexts_ = remote_contexts;
1538     remote_eager_workers_ = std::move(remote_eager_workers);
1539     InitPrioritizedDeviceTypeList();
1540     pflr_->InitializeDeviceAndFlr();
1541   }
1542 
1543   // No need to update remote_device_manager_ since it's not owned for remote
1544   // worker context (owned by the corresponding worker session).
1545   if (remote_device_manager_.Owned()) {
1546     return errors::FailedPrecondition(
1547         "EagerContext::UpdateRemoteWorker failed because the context was "
1548         "initialized as a master context.");
1549   }
1550 
1551   ClearCachesAndThreadExecutors();
1552   default_executor_.ClearError();
1553   {
1554     tensorflow::mutex_lock l(executor_map_mu_);
1555     for (auto& entry : thread_local_executor_) {
1556       entry.second->ClearError();
1557     }
1558   }
1559   return Status::OK();
1560 }
1561 #endif  // !IS_MOBILE_PLATFORM
1562 
1563 }  // namespace tensorflow
1564