/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/eager/context.h"

#include <algorithm>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

// clang-format off
// Required for IS_MOBILE_PLATFORM
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/nccl/collective_communicator.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/platform.h"
// clang-format on

#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
#include "tensorflow/core/common_runtime/colocation_graph.h"
#include "tensorflow/core/common_runtime/device_resolver_local.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/device_name_utils.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/distributed_runtime/cluster_function_library_runtime.h"
#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
#endif  // !IS_MOBILE_PLATFORM
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/util/env_var.h"

namespace tensorflow {

namespace {
// This object tracks the EagerContext owned by global_py_eager_context in
// pywrap_tfe_src.cc. Since the vast majority of the Python API is dependent on
// that global_py_eager_context (including memory management), the Py object
// owns the C object, so this pointer is non-owning.
EagerContext* global_c_eager_context = nullptr;

}  // namespace

void SetCEagerContext(EagerContext* ctx) { global_c_eager_context = ctx; }

EagerContext* GetCEagerContext() { return global_c_eager_context; }

namespace {

bool ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val) {
  bool val;
  if (tensorflow::ReadBoolFromEnvVar(env_var_name, default_val, &val).ok()) {
    return val;
  }
  return default_val;
}

auto* eager_context_created =
    monitoring::Gauge<bool, 0>::New("/tensorflow/core/eager_context_created",
                                    "True if an eager context was created.");

}  // namespace

const int64_t EagerContext::kGlobalRendezvousId = -1;

// Find the rendezvous instance corresponding to the step id, or create a
// new instance if one does not already exist.
IntraProcessRendezvous* EagerContext::LocalRendezvousTable::FindOrCreate(
    int64_t step_id, DeviceMgr* device_mgr) {
  mutex_lock l(table_lock_);
  auto iter = table_.find(step_id);
  if (iter == table_.end()) {
    iter =
        table_.insert({step_id, new IntraProcessRendezvous(device_mgr)}).first;
    // Global rendezvous: ref-count should be 1 upon creation.
    if (step_id == EagerContext::kGlobalRendezvousId) {
      return iter->second;
    }
  }
  iter->second->Ref();
  return iter->second;
}

IntraProcessRendezvous* EagerContext::LocalRendezvousTable::Find(
    int64_t step_id) {
  mutex_lock l(table_lock_);
  auto iter = table_.find(step_id);
  if (iter == table_.end()) return nullptr;
  iter->second->Ref();
  return iter->second;
}

void EagerContext::LocalRendezvousTable::Remove(int64_t step_id) {
  mutex_lock l(table_lock_);
  auto iter = table_.find(step_id);
  if (iter != table_.end()) {
    table_.erase(iter);
  }
}

void EagerContext::LocalRendezvousTable::CleanUpAll() {
  mutex_lock l(table_lock_);
  for (auto iter = table_.begin(); iter != table_.end(); iter++) {
    // Unref all rendezvous instances, except for the global rendezvous,
    // which is cleaned up elsewhere when necessary.
    if (iter->first == -1) {
      continue;
    }
    iter->second->Unref();
  }
}

EagerContext::LocalRendezvousTable::~LocalRendezvousTable() { CleanUpAll(); }

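// Constructs an EagerContext: initializes the process function library
// runtime, the thread pool and runner, the collective executor manager (if
// none was provided), and the local rendezvous table plus the global
// rendezvous used for functions.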
EagerContext::EagerContext(
    const SessionOptions& opts,
    ContextDevicePlacementPolicy default_device_placement_policy, bool async,
    DeviceMgr* device_mgr, bool device_mgr_owned, Rendezvous* rendezvous,
    DistributedFunctionLibraryRuntime* cluster_flr,
    CollectiveExecutorMgrInterface* collective_executor_mgr,
    bool run_eager_op_as_function, bool jit_compile_rewrite)
    : ImmediateExecutionContext(kEager),
      opts_(opts),
      default_device_placement_policy_(default_device_placement_policy),
      local_device_manager_(device_mgr, device_mgr_owned),
      host_cpu_device_(device_mgr->HostCPU()),
      rendezvous_(rendezvous),
      thread_pool_(NewThreadPoolFromSessionOptions(opts)),
      cluster_flr_(cluster_flr),
      log_device_placement_(opts.config.log_device_placement()),
      allow_soft_placement_(opts.config.allow_soft_placement()),
      num_active_steps_(0),
      step_container_(std::make_unique<ScopedStepContainer>(
          0, [this](const string& name) { ClearResourceContainer(name); })),
      default_executor_(async, /*enable_streaming_enqueue=*/true),
      log_memory_(LogMemory::IsEnabled()),
      env_(opts.env),
      collective_executor_mgr_(collective_executor_mgr, /*owned=*/false),
      use_send_tensor_rpc_(false),
      pin_small_ops_to_cpu_(ReadBoolFromEnvVar(
          "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING", false)),
      run_eager_op_as_function_(run_eager_op_as_function),
      jit_compile_rewrite_(jit_compile_rewrite) {
  ResetPFLR(device_mgr, opts.env, &opts.config, TF_GRAPH_DEF_VERSION,
            &func_lib_def_, opts.config.graph_options().optimizer_options(),
            thread_pool_.get(), cluster_flr);
  // Starts exporting metrics through a platform-specific monitoring API (if
  // provided). For builds using "tensorflow/tsl/platform/default", this is
  // currently a no-op.
  eager_context_created->GetCell()->Set(true);
  InitPrioritizedDeviceTypeList();
  runner_ = [this](std::function<void()> closure) {
    this->thread_pool_->Schedule(std::move(closure));
  };

  run_metadata_ = std::make_unique<RunMetadata>();

#if !defined(IS_MOBILE_PLATFORM)
  context_id_ = kInvalidContextId;
  context_view_id_ = 0;
#endif  // IS_MOBILE_PLATFORM

  // TODO(yuefengz): consider creating a new RpcCollectiveExecutorMgr each
  // time.
  if (collective_executor_mgr_.Get() == nullptr) {
    collective_executor_mgr_.Reset(CreateProdLocalCollectiveExecutorMgr(
        opts.config, local_device_mgr(),
        MaybeCreateNcclCommunicator(opts.config)));
  }

  // Initialization of local_rendezvous_table_ needs to happen before the
  // initialization of global_rendezvous_for_functions_ because the latter
  // depends on the former.
  local_rendezvous_table_ = std::make_unique<LocalRendezvousTable>();
  global_rendezvous_for_functions_ =
      core::RefCountPtr<Rendezvous>(CreateRendezvous(-1));
}

AbstractTensorInterface* EagerContext::CreateInt64Scalar(int64_t value) {
  return new TensorInterface(Tensor(value));
}

AbstractTensorInterface* EagerContext::CreateUint64Scalar(uint64 value) {
  return new TensorInterface(Tensor(value));
}

AbstractTensorInterface* EagerContext::CreateInt32Scalar(int32_t value) {
  return new TensorInterface(Tensor(value));
}

AbstractTensorInterface* EagerContext::CreateFloatScalar(float value) {
  return new TensorInterface(Tensor(value));
}

AbstractTensorInterface* EagerContext::CreateDoubleScalar(double value) {
  return new TensorInterface(Tensor(value));
}

AbstractTensorInterface* EagerContext::CreateHalfScalar(Eigen::half value) {
  return new TensorInterface(Tensor(value));
}

AbstractTensorInterface* EagerContext::CreateStringScalar(tstring value) {
  return new TensorInterface(Tensor(value));
}

AbstractTensorInterface* EagerContext::CreateComplex128Scalar(
    complex128 value) {
  return new TensorInterface(Tensor(value));
}

AbstractTensorInterface* EagerContext::CreateBoolScalar(bool value) {
  return new TensorInterface(Tensor(value));
}

AbstractTensorInterface* EagerContext::CreateTensor(
    DataType dtype, absl::Span<const int64_t> dim_sizes) {
  return new TensorInterface(Tensor(dtype, TensorShape(dim_sizes)));
}

AbstractTensorInterface* EagerContext::CreateTensor(
    DataType dtype, const int64_t* dims, int num_dims, void* data, size_t len,
    MemoryReleaser memory_releaser, void* memory_releaser_arg) {
  TF_Tensor* tensor_wrapper =
      TF_NewTensor(static_cast<TF_DataType>(dtype), dims, num_dims, data, len,
                   memory_releaser, memory_releaser_arg);

  AbstractTensorInterface* result = nullptr;
  std::swap(result, tensor_wrapper->tensor);
  TF_DeleteTensor(tensor_wrapper);
  return result;
}

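// Re-creates the ProcessFunctionLibraryRuntime with a rendezvous factory that
// creates per-step rendezvous instances via CreateRendezvous.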
void EagerContext::ResetPFLR(const DeviceMgr* device_mgr, Env* env,
                             const ConfigProto* config, int graph_def_version,
                             const FunctionLibraryDefinition* lib_def,
                             const OptimizerOptions& optimizer_options,
                             thread::ThreadPool* thread_pool,
                             DistributedFunctionLibraryRuntime* cluster_flr) {
  Rendezvous::Factory rendezvous_factory{
      [this](const int64_t step_id, const DeviceMgr*, Rendezvous** r) {
        *r = CreateRendezvous(step_id);
        return OkStatus();
      }};
  pflr_.reset(new ProcessFunctionLibraryRuntime(
      device_mgr, env, config, graph_def_version, lib_def, optimizer_options,
      thread_pool, cluster_flr,
      /*session_metadata=*/nullptr, std::move(rendezvous_factory)));
}

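// Rebuilds the prioritized device type list from the local and (if present)
// remote device managers.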
void EagerContext::InitPrioritizedDeviceTypeList() {
  DeviceSet ds;
  for (Device* d : local_device_mgr()->ListDevices()) {
    ds.AddDevice(d);
  }
  auto remote_device_manager = remote_device_mgr();
  if (remote_device_manager != nullptr) {
    for (Device* d : remote_device_manager->ListDevices()) {
      ds.AddDevice(d);
    }
  }
  mutex_lock l(device_type_list_mu_);
  prioritized_device_type_list_ =
      std::make_shared<std::vector<DeviceType>>(ds.PrioritizedDeviceTypeList());
}

namespace {
// Using absl::StrJoin with lambda does not work in tf-lite builds.
// TODO(b/148160441): Replace with absl::StrJoin once DeviceBase has operator<<.
std::vector<string> DevicesToString(const PrioritizedDeviceVector& devices) {
  std::vector<string> v;
  v.reserve(devices.size());
  for (const auto& p : devices) {
    v.push_back(p.first->name());
  }
  return v;
}

std::vector<string> DeviceTypesToString(
    const PrioritizedDeviceTypeVector& types) {
  std::vector<string> v;
  v.reserve(types.size());
  for (const auto& p : types) {
    v.push_back(p.first.type_string());
  }
  return v;
}

// Selects the "best" device that both exists and is supported.
//
// The `existing` argument specifies the available devices in the system, in
// priority order. The `supported` argument specifies the supported device
// types and their priorities, lower index types having higher priority.
// Currently the type priority defined by the `supported` parameter takes
// precedence over system device priorities from `existing`.
//
// TODO(b/148213212): Allow setting default device in eager context.
Device* SelectBestMatchingDevice(const DeviceNameUtils::ParsedName& pattern,
                                 const PrioritizedDeviceVector& existing,
                                 const PrioritizedDeviceTypeVector& supported) {
  for (const std::pair<DeviceType, int32>& prioritized_type : supported) {
    for (const std::pair<Device*, int32>& prioritized_device : existing) {
      Device* dev = prioritized_device.first;
      if (DeviceType(dev->attributes().device_type()) ==
              prioritized_type.first &&
          DeviceNameUtils::IsCompleteSpecification(pattern,
                                                   dev->parsed_name())) {
        return dev;
      }
    }
  }
  return nullptr;
}

}  // namespace

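// Selects a device for `ndef` that has a registered kernel and matches the
// `preferred` device specification. If nothing matches exactly and soft
// placement is enabled, falls back to a supported device of any type.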
Status EagerContext::SelectDevice(DeviceNameUtils::ParsedName preferred,
                                  const NodeDef& ndef, Device** out) const {
  DCHECK(out != nullptr);

  PrioritizedDeviceTypeVector supported_devs;
  auto device_type_list = prioritized_device_type_list();
  TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
      *device_type_list, ndef, &supported_devs, &HostCPU()->parsed_name()));
  if (supported_devs.empty()) {
    return errors::NotFound("Could not find device for node: ",
                            errors::FormatNodeNameForError(ndef.name()), " = ",
                            ndef.op(), "[", SummarizeAttrs(ndef), "]",
                            "\nAll kernels registered for op ", ndef.op(),
                            ":\n", KernelsRegisteredForOp(ndef.op()));
  }

  // Select the first matching registered device from the supported device
  // list. If nothing matches and soft placement is enabled, pick a suitable
  // device from the available ones.
  const auto pflr_device_set = pflr()->device_set();
  const PrioritizedDeviceVector& existing =
      pflr_device_set->prioritized_devices();
  *out = SelectBestMatchingDevice(preferred, existing, supported_devs);
  if (*out != nullptr) {
    return OkStatus();
  }

  if (AllowSoftPlacement()) {
    DeviceNameUtils::ParsedName soft_device_name = preferred;
    soft_device_name.type.clear();
    soft_device_name.has_type = false;
    soft_device_name.has_id = false;
    // TODO(b/148213746): Soft placement logic picks up another task if the
    // requested one does not exist.
    *out = SelectBestMatchingDevice(soft_device_name, existing, supported_devs);
    if (*out != nullptr) {
      return OkStatus();
    }
  }

  if (DeviceNameUtils::HasSomeDetails(preferred)) {
    return errors::InvalidArgument(
        "Could not satisfy device specification '", preferred,
        "'. enable_soft_placement=", AllowSoftPlacement(),
        ". Supported device types [",
        absl::StrJoin(DeviceTypesToString(supported_devs), ", "),
        "]. All available devices [",
        absl::StrJoin(DevicesToString(existing), ", "), "].");
  }
  return errors::InvalidArgument(
      "No supported device found in available devices [",
      absl::StrJoin(DevicesToString(existing), ", "),
      "]. enable_soft_placement=", AllowSoftPlacement(),
      ". Supported device types [",
      absl::StrJoin(DeviceTypesToString(supported_devs), ", "), "].");
}

void EagerContext::ResetClusterFLR(
    DistributedFunctionLibraryRuntime* cluster_flr) {
  cluster_flr_.Reset(cluster_flr, /*owned=*/true);
}

void EagerContext::UpdateClusterFLRAndInitDevices(
    DistributedFunctionLibraryRuntime* cluster_flr) {
  ResetClusterFLR(cluster_flr);

  const ConfigProto* config = pflr_ ? pflr_->config() : nullptr;
  ResetPFLR(
      local_device_manager_.Get(), env_, /*config=*/config,
      TF_GRAPH_DEF_VERSION, &func_lib_def_,
      /*optimizer_options=*/
      config ? config->graph_options().optimizer_options() : OptimizerOptions(),
      thread_pool_.get(), cluster_flr_.Get());
}

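// Returns the executor registered for the calling thread, or the context's
// default executor if none has been set via SetExecutorForThread.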
EagerExecutor& EagerContext::Executor() {
  tf_shared_lock l(executor_map_mu_);
  return *gtl::FindWithDefault(thread_local_executor_,
                               std::this_thread::get_id(), &default_executor_);
}

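// Associates `executor` with the calling thread. Passing the default executor
// removes any previously registered thread-local executor. A cleanup callback
// is installed on the executor so that it is removed from the map if it is
// destroyed before this context.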
void EagerContext::SetExecutorForThread(EagerExecutor* executor) {
  tensorflow::mutex_lock l(executor_map_mu_);
  if (executor == &default_executor_) {
    thread_local_executor_.erase(std::this_thread::get_id());
  } else {
    auto thread_id = std::this_thread::get_id();
    thread_local_executor_[thread_id] = executor;
    auto& executors_with_cleanups = has_cleanup_[thread_id];
    if (executors_with_cleanups.find(executor) ==
        executors_with_cleanups.end()) {
      executors_with_cleanups.insert(executor);
      // If the executor is deleted before this context, we need to remove it
      // from the map to avoid attempting to sync it in our destructor.
      std::function<void()> cleanup([this, thread_id, executor]() {
        {
          tensorflow::mutex_lock l(executor_map_mu_);
          auto existing = thread_local_executor_.find(thread_id);
          if (existing != thread_local_executor_.end() &&
              existing->second == executor) {
            thread_local_executor_.erase(thread_id);
          }
          has_cleanup_[thread_id].erase(executor);
          // Clears the global rendezvous after cleaning up the executor. This
          // is needed when running in eager op as function mode because it
          // re-uses the EagerContext's global_rendezvous_for_functions. The
          // global rendezvous can end up in a bad state if any op ends in a
          // bad state after execution.
          if (!GetGlobalRendezvousForFunctionLocalRendezvousStatus().ok()) {
            VLOG(6) << "global_rendezvous_for_functions_ is in bad state. "
                       "Resetting.";
            ResetGlobalRendezvousForFunction();
          }
        }
      });
      executor->AddCleanup(reinterpret_cast<intptr_t>(this),
                           std::move(cleanup));
    }
  }
}

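// Waits for pending nodes on every thread-local executor and the default
// executor, then clears the kernel and device caches and resets the step
// container.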
void EagerContext::ClearCachesAndThreadExecutors() {
  std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
  {
    mutex_lock l(executor_map_mu_);
    executors_copy = thread_local_executor_;
  }
  for (const auto& entry : executors_copy) {
    entry.second->WaitForAllPendingNodes().IgnoreError();
  }
  ClearCachesAndDefaultExecutor();
}

void EagerContext::ClearCachesAndDefaultExecutor() {
  // The executor stores pointers to kernels, so we need to make sure that no
  // async eager ops are still executing. We lock the cache during this time
  // as well.
  mutex_lock ml(cache_mu_);
  default_executor_.WaitForAllPendingNodes().IgnoreError();
  kernel_cache_.clear();
  for (auto& entry : registered_functions_) {
    entry.second->cached_kernel_keys->clear();
  }
  {
    mutex_lock dl(device_cache_mu_);
    device_cache_.clear();
  }
  {
    mutex_lock ml(metadata_mu_);
    step_container_.reset(new ScopedStepContainer(
        0, [this](const string& name) { ClearResourceContainer(name); }));
  }
}

void EagerContext::SetThreadLocalDevicePlacementPolicy(
    ContextDevicePlacementPolicy policy) {
  mutex_lock ml(policy_map_mu_);
  VLOG(6) << "Setting device placement policy to: " << policy;
  device_placement_policy_[std::this_thread::get_id()] = policy;
}

ContextDevicePlacementPolicy EagerContext::GetDevicePlacementPolicy() const {
  tf_shared_lock l(policy_map_mu_);
  auto policy_map_it =
      device_placement_policy_.find(std::this_thread::get_id());
  if (policy_map_it != device_placement_policy_.end()) {
    VLOG(6) << "ContextDevicePlacementPolicy: " << policy_map_it->second;
    return policy_map_it->second;
  }
  VLOG(6) << "ContextDevicePlacementPolicy not found; returning default.";
  return default_device_placement_policy_;
}

#if !defined(IS_MOBILE_PLATFORM)
std::vector<string> EagerContext::GetRemoteContexts() {
  tf_shared_lock l(remote_state_mu_);
  return remote_contexts_;
}

bool EagerContext::IsRemoteContextsEmpty() {
  tf_shared_lock l(remote_state_mu_);
  return remote_contexts_.empty();
}

void EagerContext::CloseAndClearAllRemoteContexts() {
  uint64 context_id;
  uint64 context_view_id;
  std::vector<string> remote_contexts_copy;
  {
    mutex_lock l(remote_state_mu_);
    if (!is_master_) return;
    context_id = context_id_;
    context_view_id = context_view_id_;
    context_id_ = kInvalidContextId;
    // Forget the current view id and reset to the starting value 0.
    context_view_id_ = 0;

    // Make a copy of remote targets to avoid holding the lock when sending
    // close context requests.
    remote_contexts_copy = remote_contexts_;
    remote_contexts_.clear();
  }
  CloseRemoteContexts(remote_contexts_copy, context_id, context_view_id);
}

void EagerContext::CloseRemoteContexts(
    const std::vector<string>& remote_contexts, uint64 context_id,
    uint64 context_view_id) {
  // Close all remote contexts.
  eager::CloseContextRequest request;
  request.set_context_id(context_id);
  request.set_context_view_id(context_view_id);
  // Setting context_id to a new value avoids issuing DestroyTensorHandle
  // requests to already-closed remote workers.
  std::vector<eager::CloseContextResponse> responses(remote_contexts.size());
  BlockingCounter counter(static_cast<int>(remote_contexts.size()));

  int i = 0;
  for (const auto& worker : remote_contexts) {
    core::RefCountPtr<eager::EagerClient> client;
    Status s = GetClient(worker, &client);

    client->CloseContextAsync(
        &request, &responses[i],
        [&worker, &counter, context_id](const Status& s) {
          if (!s.ok()) {
            LOG(ERROR) << "Unable to close remote context with ID "
                       << context_id << " for worker: " << worker << " due to "
                       << s.error_message();
          }
          counter.DecrementCount();
        });
    i++;
  }

  counter.Wait();
}

#endif  // !IS_MOBILE_PLATFORM

void EagerContext::WaitForAndCloseRemoteContexts() {
  ClearCachesAndThreadExecutors();

#if !defined(IS_MOBILE_PLATFORM)
  {
    mutex_lock l(keep_alive_thread_shutdown_mu_);
    shutting_down_ = true;
    keep_alive_thread_cv_.notify_all();
  }
  keep_alive_thread_.reset();

  if (!IsRemoteContextsEmpty()) {
    CloseAndClearAllRemoteContexts();
  }

  {
    mutex_lock l(remote_state_mu_);

    default_executor_.ShutDown().IgnoreError();
    std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
    {
      mutex_lock l(executor_map_mu_);
      executors_copy = thread_local_executor_;
    }
    for (const auto& it : executors_copy) {
      it.second->ShutDown().IgnoreError();
    }

    // This shuts down the completion queue and joins the thread polling it.
    // The thread exits only after the completion queue has been drained of all
    // the events. These events' completion should invoke all remaining RPC
    // callbacks.
    // This also deletes all EagerClient instances. There should not be any
    // references to EagerClients left after all RPCs and async ops have been
    // finished.
    remote_eager_workers_ = nullptr;
  }
#endif  // !IS_MOBILE_PLATFORM
}

EagerContext::~EagerContext() {
  // TODO(iga): Add a separate API method to shutdown EagerContext so that we
  // don't send RPCs and block in destructor.
  WaitForAndCloseRemoteContexts();

  // Custom devices may have obtained references to various context components
  // (executors, thread pool). It's safer to run their destructors early.
  custom_device_op_handler_.Clear();

  ClearCachesAndThreadExecutors();
  std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
  {
    mutex_lock l(executor_map_mu_);
    executors_copy = thread_local_executor_;
  }
  for (const auto& entry : executors_copy) {
    // Let the executor know that its cleanup closure is no longer valid.
    entry.second->RemoveCleanups(reinterpret_cast<intptr_t>(this));
  }
  for (auto& entry : registered_functions_) {
    while (!entry.second->Unref()) {
      // remove all references.
    }
  }
  registered_functions_.clear();

#if !defined(IS_MOBILE_PLATFORM)
  if (server_) {
    // TODO(b/136478427): Fix this.
    LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
                    "Servers don't support clean shutdown.";
    // TODO(hanyangtay): Remove this teardown logic once gRPC server clean
    // shutdown is supported.
    if (server_->worker_env()->session_mgr != nullptr) {
      // Tear down coordination service.
      Status s = server_->StopCoordinationService();
      if (!s.ok()) {
        LOG(ERROR) << "Failed to stop coordination service: " << s;
      }
    }
    server_.release();
  }

  {
    mutex_lock l(keep_alive_thread_shutdown_mu_);
    shutting_down_ = true;
    keep_alive_thread_cv_.notify_all();
  }
  keep_alive_thread_.reset();
  if (!remote_contexts_.empty()) {
    CloseAndClearAllRemoteContexts();
  }

  // Clean up all the rendezvous instances created via EagerContext.
  // Currently there are 3 cases in which a rendezvous instance is created:
  // (1). Created through a rendezvous_creator passed to EagerContext.
  // (2). Created through rendezvous_mgr.
  // (3). Created within EagerContext using LocalRendezvousTable.
  //
  // Currently case-(3) is taken care of automatically when an EagerContext
  // instance is deleted. The following code takes care of case-(2). Case-(1)
  // is tricky as EagerContext does not have a way to access those rendezvous
  // instances.
  // TODO (tfrt-dev): Take care of case-(1) mentioned above.
  if (worker_env_ != nullptr && worker_env_->rendezvous_mgr != nullptr) {
    worker_env_->rendezvous_mgr->CleanupAll();
  }
#endif  // !IS_MOBILE_PLATFORM

  if (rendezvous_) {
    rendezvous_->Unref();
  }
  if (resource_deallocator_ != nullptr) {
    resource_deallocator_();
  }
}

bool EagerContext::FindFunctionByName(const string& name) const {
  return func_lib_def_.Find(name) != nullptr;
}

Status EagerContext::FindFunctionOpData(
    const string& name, const tensorflow::OpRegistrationData** op_data) {
  return func_lib_def_.LookUp(name, op_data);
}

const FunctionDef* EagerContext::FindFunctionDef(const string& name) const {
  return func_lib_def_.Find(name);
}

std::unique_ptr<RunMetadata> EagerContext::ExportRunMetadata() {
  mutex_lock ml(metadata_mu_);
  auto result = std::make_unique<RunMetadata>();
  run_metadata_.swap(result);
  return result;
}

bool EagerContext::UsesTFRT() { return false; }

bool EagerContext::RunEagerOpAsFunction() const {
  VLOG(3) << "RunEagerOpAsFunction: " << run_eager_op_as_function_;
  return run_eager_op_as_function_;
}

void EagerContext::SetRunEagerOpAsFunction(bool enable) {
  run_eager_op_as_function_ = enable;
}

bool EagerContext::JitCompileRewrite() const {
  VLOG(3) << "JitCompileRewrite: " << jit_compile_rewrite_;
  return jit_compile_rewrite_;
}

void EagerContext::SetJitCompileRewrite(bool enable) {
  jit_compile_rewrite_ = enable;
}

void EagerContext::ListDevices(
    std::vector<tensorflow::DeviceAttributes>* device_attributes) {
  std::vector<Device*> devices = ListAllTfDevices();
  device_attributes->reserve(devices.size());
  for (const auto& dev : devices) {
    device_attributes->emplace_back(dev->attributes());
  }
}

std::vector<Device*> EagerContext::ListAllTfDevices() {
  // Since remote_device_mgr may also contain local devices, make sure no
  // duplicated device is returned.
  std::vector<Device*> devices;
  std::unordered_set<string> dev_names;

  if (local_device_mgr()) {
    for (const auto& dev : local_device_mgr()->ListDevices()) {
      devices.emplace_back(dev);
      dev_names.emplace(dev->attributes().name());
    }
  }

  // TODO (b/197281777): Include local devices in remote_device_mgr on the
  // client-side in single-client deployment.
  if (remote_device_mgr()) {
    for (const auto& dev : remote_device_mgr()->ListDevices()) {
      Device* device = nullptr;
      if (local_device_mgr()->LookupDevice(dev->name(), &device) !=
          OkStatus()) {
        // Include this device from remote_device_mgr only if it does not exist
        // in local_device_mgr.
        devices.emplace_back(dev);
      }
    }
  }

  return devices;
}

Status EagerContext::AddDevices(std::vector<std::unique_ptr<Device>> devices) {
  std::vector<std::unique_ptr<Device>> local_devices, remote_devices;
  while (!devices.empty()) {
    if (devices.front()->IsLocal()) {
      local_devices.push_back(std::move(devices.front()));
    } else {
      remote_devices.push_back(std::move(devices.front()));
    }
    devices.erase(devices.begin());
  }
  TF_RETURN_IF_ERROR(
      reinterpret_cast<DynamicDeviceMgr*>(local_device_manager_.Get())
          ->AddDevices(std::move(local_devices)));

  if (!remote_devices.empty()) {
    if (!remote_device_mgr()) {
      remote_device_manager_.Reset(
          std::make_unique<tensorflow::DynamicDeviceMgr>());
    }
    TF_RETURN_IF_ERROR(
        reinterpret_cast<DynamicDeviceMgr*>(remote_device_manager_.Get())
            ->AddDevices(std::move(remote_devices)));
  }

  // Add the devices to pflr's device set.
  pflr_->InitializeDeviceAndFlr();
  InitPrioritizedDeviceTypeList();
  return OkStatus();
}

void EagerContext::StartStep() {
  mutex_lock ml(metadata_mu_);
  num_active_steps_++;
}

void EagerContext::EndStep() {
  mutex_lock ml(metadata_mu_);
  num_active_steps_--;
  if (num_active_steps_ == 0) {
    // TODO(b/139809335): This does not properly clean up remote resources
    // Clean up the previous step container and create a new one.
    step_container_.reset(new ScopedStepContainer(
        0, [this](const string& name) { ClearResourceContainer(name); }));
  }
}

ScopedStepContainer* EagerContext::StepContainer() {
  mutex_lock ml(metadata_mu_);
  return step_container_.get();
}

Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) {
  // Only the client context can register functions on remote worker contexts.
  if (!remote_device_manager_.Owned()) return OkStatus();
#if !defined(IS_MOBILE_PLATFORM)
  std::shared_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
  request->set_context_id(GetContextId());

  eager::RegisterFunctionOp* register_function =
      request->add_queue()->mutable_register_function();
  *register_function->mutable_function_def() = fdef;
  StripDefaultAttributes(
      *OpRegistry::Global(),
      register_function->mutable_function_def()->mutable_node_def());

  auto remote_contexts = GetRemoteContexts();
  for (const auto& target : remote_contexts) {
    core::RefCountPtr<eager::EagerClient> eager_client;
    TF_RETURN_IF_ERROR(GetClient(target, &eager_client));

    eager::EnqueueResponse* response = new eager::EnqueueResponse();
    eager_client->StreamingEnqueueAsync(
        this->Executor().StreamingEnqueue(),
        /*call_opts=*/nullptr, request.get(), response,
        [request, response](const Status& status) {
          if (!status.ok()) {
            LOG(ERROR) << "Failed to register function remotely due to "
                       << status.error_message()
                       << "\nThis could happen if the remote target has been "
                          "disconnected from the client.";
          }
          delete response;
        });
  }
#endif  // !IS_MOBILE_PLATFORM
  return OkStatus();
}

Status EagerContext::RegisterExistingFunctionsOnRemoteWorkers(
    const std::vector<string>& remote_workers) {
#if !defined(IS_MOBILE_PLATFORM)
  // Register multiple functions on selected remote workers.
  uint64 context_id = GetContextId();
  FunctionDefLibrary function_defs = func_lib_def_.ToProto();
  std::vector<std::shared_ptr<eager::EnqueueRequest>> requests(
      function_defs.function_size());
  for (int i = 0; i < function_defs.function_size(); i++) {
    requests[i] = std::make_shared<eager::EnqueueRequest>();
    requests[i]->set_context_id(context_id);
    eager::RegisterFunctionOp* register_function =
        requests[i]->add_queue()->mutable_register_function();
    *register_function->mutable_function_def() =
        std::move(*function_defs.mutable_function(i));
    StripDefaultAttributes(
        *OpRegistry::Global(),
        register_function->mutable_function_def()->mutable_node_def());
  }

  for (auto& remote_worker : remote_workers) {
    core::RefCountPtr<eager::EagerClient> eager_client;
    Status s = GetClient(remote_worker, &eager_client);
    if (!s.ok()) {
      continue;
    }
    for (int i = 0; i < requests.size(); i++) {
      auto response = std::make_shared<eager::EnqueueResponse>();
      eager_client->StreamingEnqueueAsync(
          this->Executor().StreamingEnqueue(),
          /*call_opts=*/nullptr, requests[i].get(), response.get(),
          [request = requests[i], response](const Status& s) {
            if (!s.ok()) {
              LOG(ERROR) << "Failed to register function remotely due to "
                         << s.error_message()
                         << "\nThis could happen if the remote target has been "
                            "disconnected from the client.";
            }
          });
    }
  }
#endif  // !IS_MOBILE_PLATFORM
  return OkStatus();
}

Status EagerContext::AddFunctionDefWithStackTraces(
    const FunctionDef& fdef, const StackTracesMap& stack_traces) {
  return AddFunctionDef(fdef, FunctionDefLibrary(),
                        /*add_to_local_only=*/false, stack_traces);
}

Status EagerContext::AddFunctionDef(const FunctionDef& fdef) {
  return AddFunctionDef(fdef, FunctionDefLibrary(),
                        /*add_to_local_only=*/false);
}

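// Registers `fdef` (and any functions in `library`) with this context. If the
// function was registered before, its reference count is incremented and the
// definitions must match; on the first registration the function is also
// pushed to remote workers unless `add_to_local_only` is set.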
Status EagerContext::AddFunctionDef(const FunctionDef& fdef,
                                    const FunctionDefLibrary& library,
                                    const bool add_to_local_only,
                                    const StackTracesMap& stack_traces) {
  bool is_first_ref = false;
  {
    mutex_lock l(cache_mu_);
    auto* registered_function =
        gtl::FindPtrOrNull(registered_functions_, fdef.signature().name());
    if (registered_function == nullptr) {
      registered_function = new RegisteredFunction;
      registered_function->cached_kernel_keys =
          std::make_unique<std::vector<Fprint128>>();
      gtl::InsertOrUpdate(&registered_functions_, fdef.signature().name(),
                          registered_function);
    } else {
      // The function has been registered before. If the function is the same,
      // then we take a Ref(); otherwise we error out.
      const FunctionDef* prev_fdef =
          func_lib_def_.Find(fdef.signature().name());
      if (prev_fdef == nullptr) {
        return errors::Internal("Function: ", fdef.signature().name(),
                                " is in the cache but not in the library");
      }
      if (!FunctionDefsEqual(fdef, *prev_fdef)) {
        return errors::InvalidArgument(
            "Attempting to add a duplicate function with name: ",
            fdef.signature().name(), " where the previous and current ",
            "definitions differ. Previous definition: ",
            prev_fdef->DebugString(),
            " and current definition: ", fdef.DebugString());
      }
      registered_function->Ref();
    }
    is_first_ref = registered_function->RefCountIsOne();
    if (is_first_ref) {
      TF_RETURN_IF_ERROR(func_lib_def_.AddFunctionDef(fdef, stack_traces));
      TF_RETURN_IF_ERROR(func_lib_def_.AddLibrary(library));
    }
  }
  if (is_first_ref && !add_to_local_only) {
    return MaybeRegisterFunctionRemotely(fdef);
  }
  return OkStatus();
}

const FunctionDef* EagerContext::GetFunctionDef(const string& function_name) {
  return func_lib_def_.Find(function_name);
}

std::vector<string> EagerContext::ListFunctionNames() {
  return func_lib_def_.ListFunctionNames();
}

Status EagerContext::RemoveFunction(const string& func) {
  // TODO(mdan): The context owns these functions. Why check refcount then?
  mutex_lock l(cache_mu_);
  auto* registered_function = gtl::FindPtrOrNull(registered_functions_, func);
  if (registered_function == nullptr) {
    return errors::InvalidArgument("Tried to remove non-existent function '",
                                   func, "'.");
  }
  bool is_last_ref = registered_function->RefCountIsOne();
  if (is_last_ref) {
    for (auto& key : *registered_function->cached_kernel_keys) {
      kernel_cache_.erase(key);
    }
    registered_functions_.erase(func);
  }
  registered_function->Unref();
  if (is_last_ref) {
    // TODO(fishx): Remove remote function as well.
    return func_lib_def_.RemoveFunction(func);
  }
  return OkStatus();
}

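// Waits for all pending nodes on the default executor, every thread-local
// executor, and every remote executor, clearing errors along the way and
// resetting the global rendezvous for functions. Returns a summary status.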
Status EagerContext::SyncExecutors() {
  VLOG(6) << "Calling SyncExecutors";
  StatusGroup sg;
  // Synchronize on context default executor
  sg.Update(default_executor_.WaitForAllPendingNodes());
  default_executor_.ClearError();

  // Synchronize thread local executors on client
  std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
  {
    mutex_lock l(executor_map_mu_);
    executors_copy = thread_local_executor_;
  }
  for (const auto& entry : executors_copy) {
    sg.Update(entry.second->WaitForAllPendingNodes());
    entry.second->ClearError();
  }

#if !defined(IS_MOBILE_PLATFORM)
  auto remote_contexts = GetRemoteContexts();
  // Synchronize executors on remote workers
  eager::EnqueueRequest request;
  request.set_context_id(GetContextId());
  request.add_queue()->mutable_sync_remote_executor_for_stream();
  BlockingCounter counter(static_cast<int>(remote_contexts.size()));
  std::vector<Status> statuses(remote_contexts.size());

  for (int i = 0; i < remote_contexts.size(); i++) {
    const auto& target = remote_contexts[i];
    core::RefCountPtr<eager::EagerClient> eager_client;
    TF_RETURN_IF_ERROR(GetClient(target, &eager_client));

    eager::EnqueueResponse* response = new eager::EnqueueResponse();
    eager_client->StreamingEnqueueAsync(
        this->Executor().StreamingEnqueue(),
        /*call_opts=*/nullptr, &request, response,
        [response, target, &counter, &s = statuses[i]](const Status& status) {
          s = status;
          delete response;
          counter.DecrementCount();
        });
  }
  counter.Wait();
  for (const Status& s : statuses) {
    sg.Update(s);
  }
#endif  // !IS_MOBILE_PLATFORM

  // Reset the global rendezvous, which otherwise stores a failure state.
  ResetGlobalRendezvousForFunction();

  return sg.as_summary_status();
}

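// Looks up a kernel in the cache by its fingerprint key and returns a new
// reference to it, or nullptr if it is not cached.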
core::RefCountPtr<KernelAndDevice> EagerContext::GetCachedKernel(
    Fprint128 cache_key) {
  tf_shared_lock l(cache_mu_);
  auto iter = kernel_cache_.find(cache_key);
  if (iter == kernel_cache_.end()) {
    return nullptr;
  }
  core::RefCountPtr<KernelAndDevice> new_ref(iter->second.get());
  new_ref->Ref();
  return new_ref;
}

Device* EagerContext::GetCachedDevice(Fprint128 device_cache_key) {
  tf_shared_lock l(device_cache_mu_);
  auto iter = device_cache_.find(device_cache_key);
  if (iter == device_cache_.end()) return nullptr;
  return iter->second;
}

void EagerContext::AddKernelToCache(Fprint128 cache_key,
                                    KernelAndDevice* kernel) {
  mutex_lock ml(cache_mu_);
  core::RefCountPtr<KernelAndDevice> new_ref(kernel);
  new_ref->Ref();
  kernel_cache_[cache_key] = std::move(new_ref);
  auto* registered_function =
      gtl::FindPtrOrNull(registered_functions_, kernel->name());
  // The kernel name can be either a primitive op or a function.
  if (registered_function != nullptr) {
    registered_function->cached_kernel_keys->emplace_back(cache_key);
  }
}

void EagerContext::AddDeviceToCache(Fprint128 device_cache_key,
                                    Device* device) {
  mutex_lock l(device_cache_mu_);
  device_cache_[device_cache_key] = device;
}

bool EagerContext::ShouldStoreGraphs() { return should_store_graphs_.load(); }

void EagerContext::SetShouldStoreGraphs(bool value) {
  mutex_lock ml(metadata_mu_);
  should_store_graphs_.store(value);
  if (!value) {
    run_metadata_.reset(new RunMetadata);
  }
}

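// Resolves `device_name` to a Device, defaulting to the host CPU when the
// name is null or empty. Falls back to the remote device manager if the
// device is not found locally.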
Status EagerContext::FindDeviceFromName(const char* device_name,
                                        Device** device) const {
  *device = HostCPU();
  if (device_name == nullptr || strlen(device_name) == 0) {
    return OkStatus();
  }

  auto status = local_device_mgr()->LookupDevice(device_name, device);
  if (status.ok()) {
    return status;
  }

  if (remote_device_mgr() != nullptr) {
    return remote_device_mgr()->LookupDevice(device_name, device);
  }

  return status;
}

Status EagerContext::FindCompositeDeviceFromName(
    StringPiece device_name, CompositeDevice** device) const {
  tf_shared_lock l(composite_devices_mu_);
  for (const auto& d : composite_devices_) {
    if (d.second->name() == device_name) {
      *device = d.second.get();
      return OkStatus();
    }
  }
  return errors::NotFound("Unknown composite device: ", device_name);
}

Status EagerContext::RegisterCustomDevice(
    const string& device_name, std::unique_ptr<CustomDevice> device) {
  Device* existing_physical_device = nullptr;
  if (FindDeviceFromName(device_name.c_str(), &existing_physical_device).ok()) {
    return errors::AlreadyExists(device_name,
                                 " already registered as a physical device.");
  }
  return custom_device_op_handler_.RegisterCustomDevice(device_name,
                                                        std::move(device));
}

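// Returns an existing CompositeDevice for `underlying_devices` (looked up by
// name or by a fingerprint of the device list), or creates a new one and
// registers it with the process function library runtime.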
Status EagerContext::FindOrCreateCompositeDevice(
    const std::vector<string>& underlying_devices, const string& device_name,
    CompositeDevice** composite_device) {
  if (!device_name.empty() &&
      FindCompositeDeviceFromName(device_name, composite_device).ok()) {
    return OkStatus();
  }

  const uint64 hash_key = Fingerprint64(absl::StrJoin(underlying_devices, ","));

  mutex_lock l(composite_devices_mu_);
  auto iter = composite_devices_.find(hash_key);
  if (iter != composite_devices_.end()) {
    *composite_device = iter->second.get();
    return OkStatus();
  }

  Status s;
  std::unique_ptr<CompositeDevice> device;
  if (device_name.empty()) {
    // Create a CompositeDevice on the same task as the host CPU, in order to
    // trigger packed TensorHandle copy from a client to a remote worker.
    device = CompositeDevice::MakeDevice(underlying_devices,
                                         composite_devices_.size(),
                                         HostCPU()->parsed_name(), &s);
  } else {
    device = CompositeDevice::MakeDevice(underlying_devices, device_name, &s);
  }
  TF_RETURN_IF_ERROR(s);
  *composite_device = device.get();
  pflr_->AddCompositeDevice(*composite_device);
  composite_devices_.emplace(hash_key, std::move(device));
  return OkStatus();
}

bool EagerContext::OnSameTask(const Device* first, const Device* second) const {
  if (first == nullptr) first = HostCPU();
  if (second == nullptr) second = HostCPU();
  return first->parsed_name().job == second->parsed_name().job &&
         first->parsed_name().replica == second->parsed_name().replica &&
         first->parsed_name().task == second->parsed_name().task;
}

// Gets the CPU device on the task of device.
Status EagerContext::CPUDeviceOnTask(const Device* device,
                                     Device** cpu_device) const {
  string cpu_device_name;
  TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
      device->name(), &cpu_device_name));

  return FindDeviceFromName(cpu_device_name.c_str(), cpu_device);
}

void EagerContext::ClearResourceContainer(const string& name) {
  // TODO(b/139809335): This does not properly clean up remote resources
  auto local_devices = local_device_mgr()->ListDevices();
  for (Device* device : local_devices) {
    // Only ignore container not found errors.
    device->resource_manager()->Cleanup(name).IgnoreError();
  }
}

Status EagerContext::GetGlobalRendezvousForFunctionLocalRendezvousStatus() {
  mutex_lock l(global_rendezvous_mu_);
  IntraProcessRendezvous* rendezvous =
      local_rendezvous_table_->Find(kGlobalRendezvousId);
  if (rendezvous == nullptr) return OkStatus();
  Status s = rendezvous->GetLocalRendezvousStatus();
  rendezvous->Unref();
  return s;
}

void EagerContext::UpdateGlobalRendezvousDeviceManager(
    tensorflow::DeviceMgr* device_mgr) {
  mutex_lock l(global_rendezvous_mu_);
  IntraProcessRendezvous* rendezvous =
      local_rendezvous_table_->Find(kGlobalRendezvousId);
  if (rendezvous == nullptr) return;
  rendezvous->UpdateDeviceManager(device_mgr);
  rendezvous->Unref();
}

namespace {
Status GetTaskName(Device* d, string* task_name) {
  string ignored;
  if (!DeviceNameUtils::SplitDeviceName(d->name(), task_name, &ignored)) {
    return errors::InvalidArgument("Unable to parse device name: ", d->name());
  }

  return OkStatus();
}
}  // namespace

#if !defined(IS_MOBILE_PLATFORM)
Status EagerContext::GetClient(Device* device,
                               core::RefCountPtr<eager::EagerClient>* client) {
  return GetClient(device->parsed_name(), client);
}

Status EagerContext::GetClient(const DeviceNameUtils::ParsedName& device_name,
                               core::RefCountPtr<eager::EagerClient>* client) {
  string device_task_name;
  if (!DeviceNameUtils::GetTaskName(device_name, &device_task_name)) {
    return errors::InvalidArgument(
        "Task is not fully specified in device name: ",
        DeviceNameUtils::ParsedNameToString(device_name));
  }

  {
    tf_shared_lock l(remote_state_mu_);
    if (remote_eager_workers_ == nullptr) {
      return errors::Internal(
          "Haven't set up remote eager worker in this eager context yet.");
    }
    TF_RETURN_IF_ERROR(
        remote_eager_workers_->GetClient(device_task_name, client));

    if (*client == nullptr) {
      return errors::InvalidArgument(
          "Unable to find eager client corresponding to device ",
          DeviceNameUtils::ParsedNameToString(device_name));
    }
    if (std::find(remote_contexts_.begin(), remote_contexts_.end(),
                  device_task_name) == remote_contexts_.end()) {
      return errors::Internal("Unable to find a context for handle on task: ",
                              device_task_name, ". This should not happen.");
    }
  }

  return OkStatus();
}

Status EagerContext::GetClient(const string& remote_task,
                               core::RefCountPtr<eager::EagerClient>* client) {
  {
    tf_shared_lock l(remote_state_mu_);
    if (remote_eager_workers_ == nullptr) {
      return errors::Internal(
          "Haven't set up remote eager worker in this eager context yet.");
    }
    TF_RETURN_IF_ERROR(remote_eager_workers_->GetClient(remote_task, client));
  }

  if (*client == nullptr) {
    return errors::InvalidArgument(
        "Unable to find eager client corresponding to target ", remote_task);
  }
  return OkStatus();
}

uint64 EagerContext::GetContextId() const {
  tf_shared_lock l(remote_state_mu_);
  return context_id_;
}

uint64 EagerContext::GetContextViewId() const {
  tf_shared_lock l(remote_state_mu_);
  return context_view_id_;
}

void EagerContext::IncrementContextViewId() {
  mutex_lock l(remote_state_mu_);
  context_view_id_ += 1;
}

Status EagerContext::EnableCollectiveOps(const ServerDef& server_def) {
  return distributed_manager_->EnableCollectiveOps(server_def);
}

// Set collective ops related state in the context. Passing nullptr to
// `new_server` will reuse the existing GRPC server in context.
Status EagerContext::StoreCollectiveOpsServer(
    std::unique_ptr<ServerInterface> new_server, DeviceMgr* device_mgr,
    CollectiveExecutorMgrInterface* rpc_collective_executor_mgr) {
  collective_executor_mgr_.Reset(rpc_collective_executor_mgr);

  if (device_mgr != local_device_manager_.Get()) {
    if (local_device_manager_.Owned()) {
      old_local_device_managers_.push_back(
          std::move(local_device_manager_.owned_object));
    }
    local_device_manager_.Reset(device_mgr);
    UpdateGlobalRendezvousDeviceManager(local_device_manager_.Get());
    if (rendezvous_ != nullptr) rendezvous_->Unref();
    rendezvous_ = CreateRendezvous(-1);
  }
  host_cpu_device_ = local_device_manager_.Get()->HostCPU();

  InitPrioritizedDeviceTypeList();
  ClearCachesAndThreadExecutors();
  default_executor_.ClearError();
  {
    tensorflow::mutex_lock l(executor_map_mu_);
    for (auto& entry : thread_local_executor_) {
      entry.second->ClearError();
    }
  }

  const ConfigProto* config = pflr_ ? pflr_->config() : nullptr;
  ResetPFLR(
      local_device_manager_.Get(), env_, /*config=*/config,
      TF_GRAPH_DEF_VERSION, &func_lib_def_,
      /*optimizer_options=*/
      config ? config->graph_options().optimizer_options() : OptimizerOptions(),
      thread_pool_.get());

  if (new_server != nullptr) {
    // Memory leak!
    if (server_ != nullptr) {
      LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
                      "Servers don't support clean shutdown.";
      server_.release();
    }
    server_ = std::move(new_server);
  }
  DCHECK(server_ != nullptr);

  return OkStatus();
}

Status EagerContext::SetRemoteDeviceFilters(
    const string& remote_worker, const std::vector<string>& device_filters) {
  // Get fully specified task name for remote worker
  string remote_worker_task_name;
  DeviceNameUtils::ParsedName pw;
  if (!DeviceNameUtils::ParseFullName(remote_worker, &pw)) {
    return tensorflow::errors::InvalidArgument(
        "Remote worker task name is invalid ", remote_worker);
  }
  // Force set a replica as the key in cluster device filters map. I.e., if the
  // remote worker is `/job:worker/task:0` it then becomes
  // `/job:worker/replica:0/task:0`.
  pw.has_replica = true;
  if (!DeviceNameUtils::GetTaskName(pw, &remote_worker_task_name)) {
    return tensorflow::errors::InvalidArgument(
        "Job name and task index must be specified for worker ", remote_worker);
  }

  std::vector<DeviceNameUtils::ParsedName> parsed_filters;
  for (auto& filter : device_filters) {
    DeviceNameUtils::ParsedName parsed_filter;
    if (DeviceNameUtils::ParseFullName(filter, &parsed_filter)) {
      parsed_filters.emplace_back(parsed_filter);
    } else {
      return tensorflow::errors::InvalidArgument("Invalid filter: ", filter);
    }
  }

  if (VLOG_IS_ON(1)) {
    VLOG(1) << "Setting device filters for " << remote_worker << ":";
    for (auto& filter : device_filters) {
      VLOG(1) << " " << filter;
    }
  }
  mutex_lock l(remote_state_mu_);
  cluster_device_filters_.emplace(remote_worker_task_name, parsed_filters);
  return OkStatus();
}

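// Computes a visibility mask over `device_attrs` for `remote_worker` based on
// the registered cluster device filters. Devices on the worker's own task are
// always visible; if no filters were registered, every device is visible.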
void EagerContext::FilterDevicesForRemoteWorkers(
    const string& remote_worker,
    const protobuf::RepeatedPtrField<DeviceAttributes>& device_attrs,
    std::vector<bool>* filtered_device_mask) {
  filtered_device_mask->resize(device_attrs.size());
  std::fill(filtered_device_mask->begin(), filtered_device_mask->end(), false);

  tf_shared_lock l(remote_state_mu_);
  auto it = cluster_device_filters_.find(remote_worker);
  // If no filters were specified, all devices should be visible to the worker.
  if (it == cluster_device_filters_.end() || it->second.empty()) {
    std::fill(filtered_device_mask->begin(), filtered_device_mask->end(), true);
    return;
  }

  const std::vector<DeviceNameUtils::ParsedName>& parsed_filters = it->second;
  DeviceNameUtils::ParsedName parsed_remote_worker;
  DeviceNameUtils::ParseFullName(remote_worker, &parsed_remote_worker);
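  // A device is visible to the worker if it matches any filter. A name field
  // that is unspecified on either side acts as a wildcard, so e.g. the filter
  // `/job:ps` matches every device on any ps task.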
  for (int i = 0; i < device_attrs.size(); i++) {
    DeviceNameUtils::ParsedName pn;
    DeviceNameUtils::ParseFullName(device_attrs[i].name(), &pn);
    if (DeviceNameUtils::IsSameAddressSpace(parsed_remote_worker, pn)) {
      // If this device is on the remote worker itself, it should be visible
      // regardless of device filters.
      filtered_device_mask->at(i) = true;
      continue;
    }
    for (const auto& pf : parsed_filters) {
      if ((!pn.has_job || !pf.has_job || pn.job == pf.job) &&
          (!pn.has_replica || !pf.has_replica || pn.replica == pf.replica) &&
          (!pn.has_task || !pf.has_task || pn.task == pf.task) &&
          (!pn.has_type || !pf.has_type || pn.type == pf.type) &&
          (!pn.has_id || !pf.has_id || pn.id == pf.id)) {
        // Found a match; make the device visible and stop processing further
        // device filters.
        filtered_device_mask->at(i) = true;
        break;
      }
    }
  }
}

void EagerContext::SetWorkerEnv(WorkerEnv* worker_env,
                                std::shared_ptr<WorkerSession> worker_session) {
  worker_env_ = worker_env;
  worker_session_ = worker_session;
}

Status EagerContext::InitializeRemoteMaster(
    std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
    std::shared_ptr<WorkerSession> worker_session,
    std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
    std::unique_ptr<DynamicDeviceMgr> remote_device_manager,
    const std::vector<string>& remote_contexts, uint64 context_id,
    Rendezvous* r, DeviceMgr* local_device_mgr, int keep_alive_secs,
    DistributedFunctionLibraryRuntime* cluster_flr,
    std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
        remote_mgr) {
  if (context_id == kInvalidContextId) {
    return errors::InvalidArgument(
        "Failed to initialize remote for master context due to invalid ",
        "context id");
  }

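  // Tear down any contexts left over from a previous initialization before
  // recording the new set of remote contexts.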
  if (!IsRemoteContextsEmpty()) {
    CloseAndClearAllRemoteContexts();
  }
  {
    mutex_lock l(remote_state_mu_);
    remote_contexts_ = remote_contexts;
  }

  return SetMasterContextState(
      std::move(server), worker_env, std::move(worker_session),
      std::move(remote_eager_workers), std::move(remote_device_manager),
      context_id, 0, r, local_device_mgr, keep_alive_secs, cluster_flr,
      std::move(remote_mgr));
}

Status EagerContext::UpdateRemoteMaster(
    uint64 context_id,
    std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
    const std::vector<string>& add_remote_contexts,
    const std::vector<string>& remove_remote_contexts) {
  {
    tf_shared_lock l(remote_state_mu_);
    if (context_id != context_id_) {
      return errors::InvalidArgument(
          "Failed to update remote master context due to invalid context id. ",
          "Request id = ", context_id, " but current id = ", context_id_);
    }
  }

  if (!remove_remote_contexts.empty()) {
    // N.B. remove_remote_contexts includes both removed and replaced workers.
    // In the case where a worker is replaced by one that resolves to the same
    // `hostname:port`, it is safe to close the context with the current view
    // id, since the newly created context on the remote worker will hold a
    // larger view id and will ignore this request.
    CloseRemoteContexts(remove_remote_contexts, context_id, GetContextViewId());
    mutex_lock l(remote_state_mu_);
    for (const string& remote_context : remove_remote_contexts) {
      remote_contexts_.erase(
          std::remove(remote_contexts_.begin(), remote_contexts_.end(),
                      remote_context),
          remote_contexts_.end());
    }
  }
  if (!add_remote_contexts.empty()) {
    mutex_lock l(remote_state_mu_);
    remote_contexts_.insert(std::end(remote_contexts_),
                            std::begin(add_remote_contexts),
                            std::end(add_remote_contexts));
  }

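  // Advance the context view id so that subsequent requests are tagged with
  // the updated cluster view, then install the new eager client cache and
  // refresh device and function library state.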
  {
    mutex_lock l(remote_state_mu_);
    context_view_id_++;

    remote_eager_workers_ = std::move(remote_eager_workers);
    pflr_->InitializeDeviceAndFlr();
    InitPrioritizedDeviceTypeList();

    default_executor_.ClearError();
    {
      tensorflow::mutex_lock l(executor_map_mu_);
      for (auto& entry : thread_local_executor_) {
        entry.second->ClearError();
      }
    }
  }

  // Register existing functions to the newly added remote workers. Note that
  // this should happen only after updating `remote_contexts_` because new
  // functions might be registered while we update the context. When that
  // happens, this ordering ensures that `MaybeRegisterFunctionRemotely` will
  // register the new functions on all remote workers (including the newly
  // added ones), and `RegisterExistingFunctionsOnRemoteWorkers` will take
  // care of registering existing functions, where duplicate registrations
  // will be ignored by the remote workers.
  TF_RETURN_IF_ERROR(
      RegisterExistingFunctionsOnRemoteWorkers(add_remote_contexts));
  return OkStatus();
}

// Set distributed execution-related state in the master context.
Status EagerContext::SetMasterContextState(
    std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
    std::shared_ptr<WorkerSession> worker_session,
    std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
    std::unique_ptr<DynamicDeviceMgr> remote_device_manager, uint64 context_id,
    uint64 context_view_id, Rendezvous* r, DeviceMgr* local_device_mgr,
    int keep_alive_secs, DistributedFunctionLibraryRuntime* cluster_flr,
    std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
        remote_mgr) {
  mutex_lock l(remote_state_mu_);
  is_master_ = true;
  context_id_ = context_id;
  context_view_id_ = context_view_id;

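  // Whether to use the SendTensor RPC when shipping tensors to remote
  // workers; defaults to true and can be overridden via the
  // TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC environment variable.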
  use_send_tensor_rpc_ =
      ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", true);

  if (local_device_mgr != local_device_manager_.Get()) {
    if (local_device_manager_.Owned()) {
      old_local_device_managers_.push_back(
          std::move(local_device_manager_.owned_object));
    }
    local_device_manager_.Reset(local_device_mgr);
    UpdateGlobalRendezvousDeviceManager(local_device_manager_.Get());
  }
  host_cpu_device_ = local_device_manager_.Get()->HostCPU();

  if (rendezvous_ != nullptr) rendezvous_->Unref();
  rendezvous_ = r;

  // Memory leak!
  if (server_ != nullptr) {
    LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
                    "Servers don't support clean shutdown.";
    server_.release();
  }
  server_ = std::move(server);

  remote_mgr_ = std::move(remote_mgr);
  worker_env_ = worker_env;
  worker_session_ = std::move(worker_session);
  remote_eager_workers_ = std::move(remote_eager_workers);

  remote_device_manager_.Reset(std::move(remote_device_manager));
  ResetClusterFLR(cluster_flr);

  InitPrioritizedDeviceTypeList();

  ClearCachesAndThreadExecutors();
  default_executor_.ClearError();
  {
    tensorflow::mutex_lock l(executor_map_mu_);
    for (auto& entry : thread_local_executor_) {
      entry.second->ClearError();
    }
  }
  const auto* config = pflr_->config();
  ResetPFLR(local_device_manager_.Get(), env_, config, TF_GRAPH_DEF_VERSION,
            &func_lib_def_, config->graph_options().optimizer_options(),
            thread_pool_.get(), cluster_flr_.Get());

  keep_alive_secs_ = keep_alive_secs;
  sleep_for_secs_ = std::max(1, keep_alive_secs_ / 2);
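  // The keep-alive thread periodically pings every known remote context so
  // that remote workers do not garbage-collect this master's context while it
  // is idle. It wakes up roughly every keep_alive_secs_ / 2 seconds and exits
  // once shutting_down_ is set.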
  // Only schedule a single closure.
  if (keep_alive_thread_ == nullptr) {
    keep_alive_thread_.reset(
        env_->StartThread({}, "EagerKeepAliveThread", [this]() {
          while (true) {
            {
              {
                mutex_lock l(keep_alive_thread_shutdown_mu_);

                if (shutting_down_) {
                  return;
                }

                keep_alive_thread_cv_.wait_for(
                    l, std::chrono::seconds(sleep_for_secs_));

                if (shutting_down_) {
                  return;
                }
              }
              {
                mutex_lock l(remote_state_mu_);
                if (keep_alive_secs_ > 0) {
                  {
                    for (const auto& worker : remote_contexts_) {
                      core::RefCountPtr<eager::EagerClient> client;
                      Status s =
                          remote_eager_workers_->GetClient(worker, &client);

                      if (!s.ok()) {
                        LOG(WARNING) << "Keep-alive thread was unable to find "
                                        "a client for target "
                                     << worker << ". Got error: " << s;
                        continue;
                      }

                      eager::KeepAliveRequest* request =
                          new eager::KeepAliveRequest;
                      eager::KeepAliveResponse* response =
                          new eager::KeepAliveResponse;

                      request->set_context_id(context_id_);
                      client->KeepAliveAsync(
                          request, response,
                          [request, response](const Status& s) {
                            delete request;
                            delete response;
                          });
                    }
                  }
                }
              }
            }
          }
        }));
  }
  return OkStatus();
}

Status EagerContext::InitializeRemoteWorker(
    std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
    DynamicDeviceMgr* remote_device_mgr,
    const std::vector<string>& remote_contexts, uint64 context_id,
    uint64 context_view_id,
    std::function<Rendezvous*(const int64_t)> rendezvous_creator,
    DistributedFunctionLibraryRuntime* cluster_flr,
    std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
        remote_mgr,
    std::function<void()> resource_deallocator) {
  if (context_id == kInvalidContextId) {
    return errors::InvalidArgument(
        "Failed to initialize remote for worker context due to invalid ",
        "context id");
  }
  mutex_lock l(remote_state_mu_);

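  // This context must not already be serving as a master: owning remote
  // devices, a server, or a keep-alive thread indicates master-side
  // initialization has already happened.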
  if (remote_device_manager_.Owned() || server_ != nullptr ||
      keep_alive_thread_ != nullptr) {
    return errors::FailedPrecondition(
        "EagerContext::InitializeRemoteWorker Failed. ",
        "Already initialized remote as a master context.");
  }
  is_master_ = false;

  remote_contexts_ = remote_contexts;
  context_id_ = context_id;
  context_view_id_ = context_view_id;

  rendezvous_creator_ = std::move(rendezvous_creator);
  remote_eager_workers_ = std::move(remote_eager_workers);
  remote_mgr_ = std::move(remote_mgr);
  ResetClusterFLR(cluster_flr);

  remote_device_manager_.Reset(remote_device_mgr);

  const auto* config = pflr_->config();
  ResetPFLR(local_device_manager_.Get(), env_, config, TF_GRAPH_DEF_VERSION,
            &func_lib_def_, config->graph_options().optimizer_options(),
            thread_pool_.get(), cluster_flr_.Get());
  InitPrioritizedDeviceTypeList();

  ClearCachesAndThreadExecutors();
  default_executor_.ClearError();
  {
    tensorflow::mutex_lock l(executor_map_mu_);
    for (auto& entry : thread_local_executor_) {
      entry.second->ClearError();
    }
  }

  resource_deallocator_ = std::move(resource_deallocator);

  return OkStatus();
}

Status EagerContext::UpdateRemoteWorker(
    std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
    const std::vector<string>& remote_contexts, uint64 context_id) {
  {
    mutex_lock l(remote_state_mu_);
    if (context_id != context_id_) {
      return errors::InvalidArgument(
          "Failed to update remote for worker context due to invalid ",
          "context id. Request id = ", context_id,
          " but current id = ", context_id_);
    }
    context_view_id_++;

    remote_contexts_ = remote_contexts;
    remote_eager_workers_ = std::move(remote_eager_workers);
    InitPrioritizedDeviceTypeList();
    pflr_->InitializeDeviceAndFlr();
  }

  // No need to update remote_device_manager_ since it is not owned for a
  // remote worker context (it is owned by the corresponding worker session).
  if (remote_device_manager_.Owned()) {
    return errors::FailedPrecondition(
        "EagerContext::UpdateRemoteWorker failed because the context was "
        "initialized as a master context.");
  }

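  // Drop cached kernels and reset per-thread executor errors so that ops
  // placed under the previous cluster view are not reused.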
  ClearCachesAndThreadExecutors();
  default_executor_.ClearError();
  {
    tensorflow::mutex_lock l(executor_map_mu_);
    for (auto& entry : thread_local_executor_) {
      entry.second->ClearError();
    }
  }
  return OkStatus();
}
#endif  // !IS_MOBILE_PLATFORM

}  // namespace tensorflow