1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/eager/context.h"
17
18 #include <memory>
19 #include <vector>
20
21 // clang-format off
22 // Required for IS_MOBILE_PLATFORM
23 #include "tensorflow/c/eager/immediate_execution_context.h"
24 #include "tensorflow/core/common_runtime/function.h"
25 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
26 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
27 #include "tensorflow/core/lib/core/refcount.h"
28 #include "tensorflow/core/lib/gtl/map_util.h"
29 #include "tensorflow/core/nccl/collective_communicator.h"
30 #include "tensorflow/core/platform/errors.h"
31 #include "tensorflow/core/platform/mutex.h"
32 #include "tensorflow/core/platform/platform.h"
33 // clang-format on
34
35 #include "tensorflow/c/tf_tensor.h"
36 #include "tensorflow/c/tf_tensor_internal.h"
37 #include "tensorflow/core/common_runtime/collective_executor_mgr.h"
38 #include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
39 #include "tensorflow/core/common_runtime/colocation_graph.h"
40 #include "tensorflow/core/common_runtime/device_resolver_local.h"
41 #include "tensorflow/core/common_runtime/device_set.h"
42 #include "tensorflow/core/common_runtime/process_util.h"
43 #include "tensorflow/core/framework/function.h"
44 #include "tensorflow/core/framework/graph_def_util.h"
45 #include "tensorflow/core/lib/core/errors.h"
46 #include "tensorflow/core/protobuf/config.pb.h"
47 #include "tensorflow/core/public/version.h"
48 #include "tensorflow/core/util/device_name_utils.h"
49 #if !defined(IS_MOBILE_PLATFORM)
50 #include "tensorflow/core/distributed_runtime/cluster_function_library_runtime.h"
51 #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
52 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
53 #include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
54 #endif // !IS_MOBILE_PLATFORM
55 #include "tensorflow/core/framework/resource_mgr.h"
56 #include "tensorflow/core/lib/core/blocking_counter.h"
57 #include "tensorflow/core/lib/monitoring/gauge.h"
58 #include "tensorflow/core/util/env_var.h"
59
60 namespace tensorflow {
61 namespace {
62
ReadBoolFromEnvVar(StringPiece env_var_name,bool default_val)63 bool ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val) {
64 bool val;
65 if (tensorflow::ReadBoolFromEnvVar(env_var_name, default_val, &val).ok()) {
66 return val;
67 }
68 return default_val;
69 }
70
// Boolean gauge registered under "/tensorflow/core/eager_context_created".
// The EagerContext constructor sets its cell to true.
auto* eager_context_created =
    monitoring::Gauge<bool, 0>::New("/tensorflow/core/eager_context_created",
                                    "True if an eager context was created.");
74
75 } // namespace
76
// Constructs an eager context on top of `device_mgr`.
//
// The member initializers record the session options, the default placement
// policy, the device manager (together with the `device_mgr_owned` flag), the
// rendezvous, the cluster FLR and the collective executor manager passed in
// by the caller, and create a thread pool from `opts`. The body then builds
// the process function library runtime, marks context creation in the
// monitoring gauge, computes the prioritized device type list, installs the
// closure runner, and creates a local collective executor manager if the
// caller did not supply one.
EagerContext::EagerContext(
    const SessionOptions& opts,
    ContextDevicePlacementPolicy default_device_placement_policy, bool async,
    DeviceMgr* device_mgr, bool device_mgr_owned, Rendezvous* rendezvous,
    DistributedFunctionLibraryRuntime* cluster_flr,
    CollectiveExecutorMgrInterface* collective_executor_mgr,
    bool run_eager_op_as_function)
    : ImmediateExecutionContext(kEager),
      opts_(opts),
      default_device_placement_policy_(default_device_placement_policy),
      local_device_manager_(device_mgr, device_mgr_owned),
      host_cpu_device_(device_mgr->HostCPU()),
      rendezvous_(rendezvous),
      thread_pool_(NewThreadPoolFromSessionOptions(opts)),
      cluster_flr_(cluster_flr),
      log_device_placement_(opts.config.log_device_placement()),
      allow_soft_placement_(opts.config.allow_soft_placement()),
      num_active_steps_(0),
      step_container_(std::make_unique<ScopedStepContainer>(
          0, [this](const string& name) { ClearResourceContainer(name); })),
      default_executor_(async),
      log_memory_(LogMemory::IsEnabled()),
      env_(opts.env),
      collective_executor_mgr_(collective_executor_mgr, /*owned=*/false),
      use_send_tensor_rpc_(false),
      pin_small_ops_to_cpu_(ReadBoolFromEnvVar(
          "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING", false)),
      run_eager_op_as_function_(run_eager_op_as_function) {
  ResetPFLR(device_mgr, opts.env, &opts.config, TF_GRAPH_DEF_VERSION,
            &func_lib_def_, opts.config.graph_options().optimizer_options(),
            thread_pool_.get(), cluster_flr);
  // Starts exporting metrics through a platform-specific monitoring API (if
  // provided). For builds using "tensorflow/core/platform/default", this is
  // currently a no-op.
  eager_context_created->GetCell()->Set(true);
  InitPrioritizedDeviceTypeList();
  // Closures handed to `runner_` are scheduled on this context's thread pool.
  runner_ = [this](std::function<void()> closure) {
    this->thread_pool_->Schedule(std::move(closure));
  };

  run_metadata_ = std::make_unique<RunMetadata>();

#if !defined(IS_MOBILE_PLATFORM)
  // Start without a remote context: invalid id and view id 0.
  context_id_ = kInvalidContextId;
  context_view_id_ = 0;
#endif  // !IS_MOBILE_PLATFORM

  // TODO(yuefengz): consider creating a new RpcCollectiveExecutorMgr each
  // time.
  if (collective_executor_mgr_.Get() == nullptr) {
    collective_executor_mgr_.Reset(CreateProdLocalCollectiveExecutorMgr(
        opts.config, local_device_mgr(),
        MaybeCreateNcclCommunicator(opts.config)));
  }
  // Rendezvous created with step id -1 and kept for the context's lifetime.
  global_rendezvous_for_functions_ =
      core::RefCountPtr<Rendezvous>(CreateRendezvous(-1));
}
134
CreateInt64Scalar(int64_t value)135 AbstractTensorInterface* EagerContext::CreateInt64Scalar(int64_t value) {
136 return new TensorInterface(Tensor(value));
137 }
138
CreateUint64Scalar(uint64 value)139 AbstractTensorInterface* EagerContext::CreateUint64Scalar(uint64 value) {
140 return new TensorInterface(Tensor(value));
141 }
142
CreateInt32Scalar(int32_t value)143 AbstractTensorInterface* EagerContext::CreateInt32Scalar(int32_t value) {
144 return new TensorInterface(Tensor(value));
145 }
146
CreateFloatScalar(float value)147 AbstractTensorInterface* EagerContext::CreateFloatScalar(float value) {
148 return new TensorInterface(Tensor(value));
149 }
150
CreateDoubleScalar(double value)151 AbstractTensorInterface* EagerContext::CreateDoubleScalar(double value) {
152 return new TensorInterface(Tensor(value));
153 }
154
CreateHalfScalar(Eigen::half value)155 AbstractTensorInterface* EagerContext::CreateHalfScalar(Eigen::half value) {
156 return new TensorInterface(Tensor(value));
157 }
158
CreateStringScalar(tstring value)159 AbstractTensorInterface* EagerContext::CreateStringScalar(tstring value) {
160 return new TensorInterface(Tensor(value));
161 }
162
CreateComplex128Scalar(complex128 value)163 AbstractTensorInterface* EagerContext::CreateComplex128Scalar(
164 complex128 value) {
165 return new TensorInterface(Tensor(value));
166 }
167
CreateBoolScalar(bool value)168 AbstractTensorInterface* EagerContext::CreateBoolScalar(bool value) {
169 return new TensorInterface(Tensor(value));
170 }
171
CreateTensor(DataType dtype,absl::Span<const int64> dim_sizes)172 AbstractTensorInterface* EagerContext::CreateTensor(
173 DataType dtype, absl::Span<const int64> dim_sizes) {
174 return new TensorInterface(Tensor(dtype, TensorShape(dim_sizes)));
175 }
176
CreateTensor(DataType dtype,const int64_t * dims,int num_dims,void * data,size_t len,MemoryReleaser memory_releaser,void * memory_releaser_arg)177 AbstractTensorInterface* EagerContext::CreateTensor(
178 DataType dtype, const int64_t* dims, int num_dims, void* data, size_t len,
179 MemoryReleaser memory_releaser, void* memory_releaser_arg) {
180 TF_Tensor* tensor_wrapper =
181 TF_NewTensor(static_cast<TF_DataType>(dtype), dims, num_dims, data, len,
182 memory_releaser, memory_releaser_arg);
183
184 AbstractTensorInterface* result = nullptr;
185 std::swap(result, tensor_wrapper->tensor);
186 TF_DeleteTensor(tensor_wrapper);
187 return result;
188 }
189
ResetPFLR(const DeviceMgr * device_mgr,Env * env,const ConfigProto * config,int graph_def_version,const FunctionLibraryDefinition * lib_def,const OptimizerOptions & optimizer_options,thread::ThreadPool * thread_pool,DistributedFunctionLibraryRuntime * cluster_flr)190 void EagerContext::ResetPFLR(const DeviceMgr* device_mgr, Env* env,
191 const ConfigProto* config, int graph_def_version,
192 const FunctionLibraryDefinition* lib_def,
193 const OptimizerOptions& optimizer_options,
194 thread::ThreadPool* thread_pool,
195 DistributedFunctionLibraryRuntime* cluster_flr) {
196 Rendezvous::Factory rendezvous_factory{
197 [this](const int64_t step_id, const DeviceMgr*, Rendezvous** r) {
198 *r = CreateRendezvous(step_id);
199 return Status::OK();
200 }};
201 pflr_.reset(new ProcessFunctionLibraryRuntime(
202 device_mgr, env, config, graph_def_version, lib_def, optimizer_options,
203 thread_pool, cluster_flr,
204 /*session_metadata=*/nullptr, std::move(rendezvous_factory)));
205 }
206
InitPrioritizedDeviceTypeList()207 void EagerContext::InitPrioritizedDeviceTypeList() {
208 DeviceSet ds;
209 for (Device* d : local_device_mgr()->ListDevices()) {
210 ds.AddDevice(d);
211 }
212 auto remote_device_manager = remote_device_mgr();
213 if (remote_device_manager != nullptr) {
214 for (Device* d : remote_device_manager->ListDevices()) {
215 ds.AddDevice(d);
216 }
217 }
218 mutex_lock l(device_type_list_mu_);
219 prioritized_device_type_list_ =
220 std::make_shared<std::vector<DeviceType>>(ds.PrioritizedDeviceTypeList());
221 }
222
223 namespace {
224 // Using absl::StrJoin with lambda does not work in tf-lite builds.
225 // TODO(b/148160441): Replace with absl::StrJoin once DeviceBase has operator<<.
DevicesToString(const PrioritizedDeviceVector & devices)226 std::vector<string> DevicesToString(const PrioritizedDeviceVector& devices) {
227 std::vector<string> v;
228 v.reserve(devices.size());
229 for (const auto& p : devices) {
230 v.push_back(p.first->name());
231 }
232 return v;
233 }
234
DeviceTypesToString(const PrioritizedDeviceTypeVector & types)235 std::vector<string> DeviceTypesToString(
236 const PrioritizedDeviceTypeVector& types) {
237 std::vector<string> v;
238 v.reserve(types.size());
239 for (const auto& p : types) {
240 v.push_back(p.first.type_string());
241 }
242 return v;
243 }
244
245 // Selects the "best" device that both exists and is supported.
246 //
247 // The `existing` argument specifies the available devices in the system, in
248 // priority order. The `supported` argument specifies the supported device types
249 // and their priorities, lower index types having higher priority.
250 // Currently the type priority defined by the `supported` parameter takes
251 // precedence over system device priorities from `existing`.
252 //
253 // TODO(b/148213212): Allow setting default device in eager context.
SelectBestMatchingDevice(const DeviceNameUtils::ParsedName & pattern,const PrioritizedDeviceVector & existing,const PrioritizedDeviceTypeVector & supported)254 Device* SelectBestMatchingDevice(const DeviceNameUtils::ParsedName& pattern,
255 const PrioritizedDeviceVector& existing,
256 const PrioritizedDeviceTypeVector& supported) {
257 for (const std::pair<DeviceType, int32>& prioritized_type : supported) {
258 for (const std::pair<Device*, int32>& prioritized_device : existing) {
259 Device* dev = prioritized_device.first;
260 if (DeviceType(dev->attributes().device_type()) ==
261 prioritized_type.first &&
262 DeviceNameUtils::IsCompleteSpecification(pattern,
263 dev->parsed_name())) {
264 return dev;
265 }
266 }
267 }
268 return nullptr;
269 }
270
271 } // namespace
272
SelectDevice(DeviceNameUtils::ParsedName preferred,const NodeDef & ndef,Device ** out) const273 Status EagerContext::SelectDevice(DeviceNameUtils::ParsedName preferred,
274 const NodeDef& ndef, Device** out) const {
275 DCHECK(out != nullptr);
276
277 PrioritizedDeviceTypeVector supported_devs;
278 auto device_type_list = prioritized_device_type_list();
279 TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
280 *device_type_list, ndef, &supported_devs, &HostCPU()->parsed_name()));
281 if (supported_devs.empty()) {
282 return errors::NotFound("Could not find device for node: ",
283 errors::FormatNodeNameForError(ndef.name()), " = ",
284 ndef.op(), "[", SummarizeAttrs(ndef), "]",
285 "\nAll kernels registered for op ", ndef.op(),
286 ":\n", KernelsRegisteredForOp(ndef.op()));
287 }
288
289 // Select the first matching registered device from the supported device
290 // list. If nothing matches and soft placement is enabled, pick a suitable
291 // device from the available ones.
292 const auto pflr_device_set = pflr()->device_set();
293 const PrioritizedDeviceVector& existing =
294 pflr_device_set->prioritized_devices();
295 *out = SelectBestMatchingDevice(preferred, existing, supported_devs);
296 if (*out != nullptr) {
297 return Status::OK();
298 }
299
300 if (AllowSoftPlacement()) {
301 DeviceNameUtils::ParsedName soft_device_name = preferred;
302 soft_device_name.type.clear();
303 soft_device_name.has_type = false;
304 soft_device_name.has_id = false;
305 // TODO(b/148213746): Soft placement logic picks up another task if the
306 // requested does not exist.
307 *out = SelectBestMatchingDevice(soft_device_name, existing, supported_devs);
308 if (*out != nullptr) {
309 return Status::OK();
310 }
311 }
312
313 if (DeviceNameUtils::HasSomeDetails(preferred)) {
314 return errors::InvalidArgument(
315 "Could not satisfy device specification '", preferred,
316 "'. enable_soft_placement=", AllowSoftPlacement(),
317 ". Supported device types [",
318 absl::StrJoin(DeviceTypesToString(supported_devs), ", "),
319 "]. All available devices [",
320 absl::StrJoin(DevicesToString(existing), ", "), "].");
321 }
322 return errors::InvalidArgument(
323 "No supported device found in available devices [",
324 absl::StrJoin(DevicesToString(existing), ", "),
325 "]. enable_soft_placement=", AllowSoftPlacement(),
326 ". Supported devices types [",
327 absl::StrJoin(DeviceTypesToString(supported_devs), ", "), "].");
328 }
329
// Replaces the cluster function library runtime held by this context with
// `cluster_flr`, marking the new pointer as owned.
void EagerContext::ResetClusterFLR(
    DistributedFunctionLibraryRuntime* cluster_flr) {
  cluster_flr_.Reset(cluster_flr, /*owned=*/true);
}
334
Executor()335 EagerExecutor& EagerContext::Executor() {
336 tf_shared_lock l(executor_map_mu_);
337 return *gtl::FindWithDefault(thread_local_executor_,
338 std::this_thread::get_id(), &default_executor_);
339 }
340
SetExecutorForThread(EagerExecutor * executor)341 void EagerContext::SetExecutorForThread(EagerExecutor* executor) {
342 tensorflow::mutex_lock l(executor_map_mu_);
343 if (executor == &default_executor_) {
344 thread_local_executor_.erase(std::this_thread::get_id());
345 } else {
346 auto thread_id = std::this_thread::get_id();
347 thread_local_executor_[thread_id] = executor;
348 auto& executors_with_cleanups = has_cleanup_[thread_id];
349 if (executors_with_cleanups.find(executor) ==
350 executors_with_cleanups.end()) {
351 executors_with_cleanups.insert(executor);
352 // If the executor is deleted before this context, we need to remove it
353 // from the map to avoid attempting to sync it in our destructor.
354 std::function<void()> cleanup([this, thread_id, executor]() {
355 {
356 tensorflow::mutex_lock l(executor_map_mu_);
357 auto existing = thread_local_executor_.find(thread_id);
358 if (existing != thread_local_executor_.end() &&
359 existing->second == executor) {
360 thread_local_executor_.erase(thread_id);
361 }
362 has_cleanup_[thread_id].erase(executor);
363 }
364 });
365 executor->AddCleanup(reinterpret_cast<intptr_t>(this),
366 std::move(cleanup));
367 }
368 }
369 }
370
ClearCachesAndThreadExecutors()371 void EagerContext::ClearCachesAndThreadExecutors() {
372 std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
373 {
374 mutex_lock l(executor_map_mu_);
375 executors_copy = thread_local_executor_;
376 }
377 for (const auto& entry : executors_copy) {
378 entry.second->WaitForAllPendingNodes().IgnoreError();
379 }
380 ClearCachesAndDefaultExecutor();
381 }
382
ClearCachesAndDefaultExecutor()383 void EagerContext::ClearCachesAndDefaultExecutor() {
384 // The executor stores pointers to kernels, so we need to make sure that no
385 // async eager ops are still executing. We lock the cache during this time
386 // as well.
387 mutex_lock ml(cache_mu_);
388 default_executor_.WaitForAllPendingNodes().IgnoreError();
389 kernel_cache_.clear();
390 for (auto& entry : registered_functions_) {
391 entry.second->cached_kernel_keys->clear();
392 }
393 {
394 mutex_lock ml(metadata_mu_);
395 step_container_.reset(new ScopedStepContainer(
396 0, [this](const string& name) { ClearResourceContainer(name); }));
397 }
398 }
399
SetThreadLocalDevicePlacementPolicy(ContextDevicePlacementPolicy policy)400 void EagerContext::SetThreadLocalDevicePlacementPolicy(
401 ContextDevicePlacementPolicy policy) {
402 mutex_lock ml(policy_map_mu_);
403 device_placement_policy_[std::this_thread::get_id()] = policy;
404 }
405
GetDevicePlacementPolicy() const406 ContextDevicePlacementPolicy EagerContext::GetDevicePlacementPolicy() const {
407 tf_shared_lock l(policy_map_mu_);
408 auto policy_map_it =
409 device_placement_policy_.find(std::this_thread::get_id());
410 if (policy_map_it != device_placement_policy_.end()) {
411 return policy_map_it->second;
412 }
413 return default_device_placement_policy_;
414 }
415
416 #if !defined(IS_MOBILE_PLATFORM)
// Returns a copy of `remote_contexts_`, taken under a shared lock on
// `remote_state_mu_`.
std::vector<string> EagerContext::GetRemoteContexts() {
  tf_shared_lock l(remote_state_mu_);
  return remote_contexts_;
}
421
// Returns true when `remote_contexts_` has no entries; the check is made
// under a shared lock on `remote_state_mu_`.
bool EagerContext::IsRemoteContextsEmpty() {
  tf_shared_lock l(remote_state_mu_);
  return remote_contexts_.empty();
}
426
CloseAndClearAllRemoteContexts()427 void EagerContext::CloseAndClearAllRemoteContexts() {
428 uint64 context_id;
429 uint64 context_view_id;
430 std::vector<string> remote_contexts_copy;
431 {
432 mutex_lock l(remote_state_mu_);
433 if (!is_master_) return;
434 context_id = context_id_;
435 context_view_id = context_view_id_;
436 context_id_ = kInvalidContextId;
437 // Forget the current view id and reset to the starting value 0.
438 context_view_id_ = 0;
439
440 // Make a copy of remote targets to avoid holding the lock when sending
441 // close context requests.
442 remote_contexts_copy = remote_contexts_;
443 remote_contexts_.clear();
444 }
445 CloseRemoteContexts(remote_contexts_copy, context_id, context_view_id);
446 }
447
CloseRemoteContexts(const std::vector<string> & remote_contexts,uint64 context_id,uint64 context_view_id)448 void EagerContext::CloseRemoteContexts(
449 const std::vector<string>& remote_contexts, uint64 context_id,
450 uint64 context_view_id) {
451 // Close all remote contexts.
452 eager::CloseContextRequest request;
453 request.set_context_id(context_id);
454 request.set_context_view_id(context_view_id);
455 // Setting context_id to a new value can avoid us issuing DestroyTensorHandle
456 // request to closed remote workers.
457 std::vector<eager::CloseContextResponse> responses(remote_contexts.size());
458 BlockingCounter counter(static_cast<int>(remote_contexts.size()));
459
460 int i = 0;
461 for (const auto& worker : remote_contexts) {
462 core::RefCountPtr<eager::EagerClient> client;
463 Status s = GetClient(worker, &client);
464
465 client->CloseContextAsync(
466 &request, &responses[i],
467 [&worker, &counter, context_id](const Status& s) {
468 if (!s.ok()) {
469 LOG(ERROR) << "Unable to close remote context with ID "
470 << context_id << " for worker: " << worker << " due to "
471 << s.error_message();
472 }
473 counter.DecrementCount();
474 });
475 i++;
476 }
477
478 counter.Wait();
479 }
480
481 #endif // !IS_MOBILE_PLATFORM
482
WaitForAndCloseRemoteContexts()483 void EagerContext::WaitForAndCloseRemoteContexts() {
484 ClearCachesAndThreadExecutors();
485
486 #if !defined(IS_MOBILE_PLATFORM)
487 {
488 mutex_lock l(keep_alive_thread_shutdown_mu_);
489 shutting_down_ = true;
490 keep_alive_thread_cv_.notify_all();
491 }
492 keep_alive_thread_.reset();
493
494 if (!IsRemoteContextsEmpty()) {
495 CloseAndClearAllRemoteContexts();
496 }
497
498 {
499 mutex_lock l(remote_state_mu_);
500
501 default_executor_.ShutDown().IgnoreError();
502 std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
503 {
504 mutex_lock l(executor_map_mu_);
505 executors_copy = thread_local_executor_;
506 }
507 for (const auto& it : executors_copy) {
508 it.second->ShutDown().IgnoreError();
509 }
510
511 // This shuts down the completion queue and joins the thread polling it.
512 // The thread exits only after the completion queue has been drained of all
513 // the events. These events' completion should invoke all remaining RPC
514 // callbacks.
515 // This also deletes all EagerClient instances. There should not be any
516 // references to EagerClients left after all RPCs and async ops have been
517 // finished.
518 remote_eager_workers_ = nullptr;
519 }
520 #endif // !IS_MOBILE_PLATFORM
521 }
522
~EagerContext()523 EagerContext::~EagerContext() {
524 // TODO(iga): Add a separate API method to shutdown EagerContext so that we
525 // don't send RPCs and block in destructor.
526 WaitForAndCloseRemoteContexts();
527
528 // Custom devices may have obtained references to various context components
529 // (executors, thread pool). It's safer to run their destructors early.
530 custom_device_op_handler_.Clear();
531
532 ClearCachesAndThreadExecutors();
533 std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
534 {
535 mutex_lock l(executor_map_mu_);
536 executors_copy = thread_local_executor_;
537 }
538 for (const auto& entry : executors_copy) {
539 // Let the executor know that its cleanup closure is no longer valid.
540 entry.second->RemoveCleanups(reinterpret_cast<intptr_t>(this));
541 }
542 for (auto& entry : registered_functions_) {
543 while (!entry.second->Unref()) {
544 // remove all references.
545 }
546 }
547 registered_functions_.clear();
548
549 #if !defined(IS_MOBILE_PLATFORM)
550 if (server_) {
551 // TODO(b/136478427): Fix this.
552 LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
553 "Servers don't support clean shutdown.";
554 server_.release();
555 }
556
557 {
558 mutex_lock l(keep_alive_thread_shutdown_mu_);
559 shutting_down_ = true;
560 keep_alive_thread_cv_.notify_all();
561 }
562 keep_alive_thread_.reset();
563 if (!remote_contexts_.empty()) {
564 CloseAndClearAllRemoteContexts();
565 }
566 #endif // !IS_MOBILE_PLATFORM
567
568 if (rendezvous_) {
569 rendezvous_->Unref();
570 }
571 if (resource_deallocator_ != nullptr) {
572 resource_deallocator_();
573 }
574 }
575
FindFunctionByName(const string & name) const576 bool EagerContext::FindFunctionByName(const string& name) const {
577 return func_lib_def_.Find(name) != nullptr;
578 }
579
// Looks up the op registration data for `name` in the context's function
// library, forwarding `op_data` and returning the status of the lookup.
Status EagerContext::FindFunctionOpData(
    const string& name, const tensorflow::OpRegistrationData** op_data) {
  return func_lib_def_.LookUp(name, op_data);
}
584
// Returns whatever the context's function library finds for `name`. Other
// code in this file treats a nullptr result as "not present".
const FunctionDef* EagerContext::FindFunctionDef(const string& name) const {
  return func_lib_def_.Find(name);
}
588
ExportRunMetadata()589 std::unique_ptr<RunMetadata> EagerContext::ExportRunMetadata() {
590 mutex_lock ml(metadata_mu_);
591 auto result = std::make_unique<RunMetadata>();
592 run_metadata_.swap(result);
593 return result;
594 }
595
// Always returns false for this context implementation.
bool EagerContext::UsesTFRT() { return false; }
597
// Returns the `run_eager_op_as_function` flag this context was constructed
// with.
bool EagerContext::RunEagerOpAsFunction() const {
  return run_eager_op_as_function_;
}
601
ListDevices(std::vector<tensorflow::DeviceAttributes> * devices)602 void EagerContext::ListDevices(
603 std::vector<tensorflow::DeviceAttributes>* devices) {
604 local_device_mgr()->ListDeviceAttributes(devices);
605 if (remote_device_mgr()) {
606 remote_device_mgr()->ListDeviceAttributes(devices);
607 }
608 }
609
AddDevices(std::vector<std::unique_ptr<Device>> devices)610 Status EagerContext::AddDevices(std::vector<std::unique_ptr<Device>> devices) {
611 std::vector<std::unique_ptr<Device>> local_devices, remote_devices;
612 while (!devices.empty()) {
613 if (devices.front()->IsLocal()) {
614 local_devices.push_back(std::move(devices.front()));
615 } else {
616 remote_devices.push_back(std::move(devices.front()));
617 }
618 devices.erase(devices.begin());
619 }
620 TF_RETURN_IF_ERROR(
621 reinterpret_cast<DynamicDeviceMgr*>(local_device_manager_.Get())
622 ->AddDevices(std::move(local_devices)));
623
624 if (!remote_devices.empty()) {
625 if (!remote_device_mgr()) {
626 remote_device_manager_.Reset(
627 std::make_unique<tensorflow::DynamicDeviceMgr>());
628 }
629 TF_RETURN_IF_ERROR(
630 reinterpret_cast<DynamicDeviceMgr*>(remote_device_manager_.Get())
631 ->AddDevices(std::move(remote_devices)));
632 }
633
634 // Add the devices to pflr's device set.
635 pflr_->InitializeDeviceAndFlr();
636 InitPrioritizedDeviceTypeList();
637 return Status::OK();
638 }
639
StartStep()640 void EagerContext::StartStep() {
641 mutex_lock ml(metadata_mu_);
642 num_active_steps_++;
643 }
644
EndStep()645 void EagerContext::EndStep() {
646 mutex_lock ml(metadata_mu_);
647 num_active_steps_--;
648 if (num_active_steps_ == 0) {
649 // TODO(b/139809335): This does not properly clean up remote resources
650 // Clean up the previous step container and create a new one.
651 step_container_.reset(new ScopedStepContainer(
652 0, [this](const string& name) { ClearResourceContainer(name); }));
653 }
654 }
655
StepContainer()656 ScopedStepContainer* EagerContext::StepContainer() {
657 mutex_lock ml(metadata_mu_);
658 return step_container_.get();
659 }
660
MaybeRegisterFunctionRemotely(const FunctionDef & fdef)661 Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) {
662 // Only client context can register function on remote worker context.
663 if (!remote_device_manager_.Owned()) return Status::OK();
664 #if !defined(IS_MOBILE_PLATFORM)
665 std::shared_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
666 request->set_context_id(GetContextId());
667
668 eager::RegisterFunctionOp* register_function =
669 request->add_queue()->mutable_register_function();
670 *register_function->mutable_function_def() = fdef;
671 StripDefaultAttributes(
672 *OpRegistry::Global(),
673 register_function->mutable_function_def()->mutable_node_def());
674
675 auto remote_contexts = GetRemoteContexts();
676 for (const auto& target : remote_contexts) {
677 core::RefCountPtr<eager::EagerClient> eager_client;
678 TF_RETURN_IF_ERROR(GetClient(target, &eager_client));
679
680 eager::EnqueueResponse* response = new eager::EnqueueResponse();
681 eager_client->StreamingEnqueueAsync(
682 /*call_opts=*/nullptr, request.get(), response,
683 [request, response](const Status& status) {
684 if (!status.ok()) {
685 LOG(ERROR) << "Failed to register function remotely due to "
686 << status.error_message()
687 << "\nThis could happen if the remote target has been "
688 "disconnected from the client.";
689 }
690 delete response;
691 });
692 }
693 #endif // !IS_MOBILE_PLATFORM
694 return Status::OK();
695 }
696
RegisterExistingFunctionsOnRemoteWorkers(const std::vector<string> & remote_workers)697 Status EagerContext::RegisterExistingFunctionsOnRemoteWorkers(
698 const std::vector<string>& remote_workers) {
699 #if !defined(IS_MOBILE_PLATFORM)
700 // Register multiple functions on selected remote workers.
701 uint64 context_id = GetContextId();
702 FunctionDefLibrary function_defs = func_lib_def_.ToProto();
703 std::vector<std::shared_ptr<eager::EnqueueRequest>> requests(
704 function_defs.function_size());
705 for (int i = 0; i < function_defs.function_size(); i++) {
706 requests[i] = std::make_shared<eager::EnqueueRequest>();
707 requests[i]->set_context_id(context_id);
708 eager::RegisterFunctionOp* register_function =
709 requests[i]->add_queue()->mutable_register_function();
710 *register_function->mutable_function_def() =
711 std::move(*function_defs.mutable_function(i));
712 StripDefaultAttributes(
713 *OpRegistry::Global(),
714 register_function->mutable_function_def()->mutable_node_def());
715 }
716
717 for (auto& remote_worker : remote_workers) {
718 core::RefCountPtr<eager::EagerClient> eager_client;
719 Status s = GetClient(remote_worker, &eager_client);
720 if (!s.ok()) {
721 continue;
722 }
723 for (int i = 0; i < requests.size(); i++) {
724 auto response = std::make_shared<eager::EnqueueResponse>();
725 eager_client->StreamingEnqueueAsync(
726 /*call_opts=*/nullptr, requests[i].get(), response.get(),
727 [request = requests[i], response](const Status& s) {
728 if (!s.ok()) {
729 LOG(ERROR) << "Failed to register function remotely due to "
730 << s.error_message()
731 << "\nThis could happen if the remote target has been "
732 "disconnected from the client.";
733 }
734 });
735 }
736 }
737 #endif // !IS_MOBILE_PLATFORM
738 return Status::OK();
739 }
740
AddFunctionDefWithStackTraces(const FunctionDef & fdef,const StackTracesMap & stack_traces)741 Status EagerContext::AddFunctionDefWithStackTraces(
742 const FunctionDef& fdef, const StackTracesMap& stack_traces) {
743 return AddFunctionDef(fdef, FunctionDefLibrary(),
744 /* add_to_local_only=*/false, stack_traces);
745 }
746
AddFunctionDef(const FunctionDef & fdef)747 Status EagerContext::AddFunctionDef(const FunctionDef& fdef) {
748 return AddFunctionDef(fdef, FunctionDefLibrary(),
749 /* add_to_local_only=*/false);
750 }
751
AddFunctionDef(const FunctionDef & fdef,const FunctionDefLibrary & library,const bool add_to_local_only,const StackTracesMap & stack_traces)752 Status EagerContext::AddFunctionDef(const FunctionDef& fdef,
753 const FunctionDefLibrary& library,
754 const bool add_to_local_only,
755 const StackTracesMap& stack_traces) {
756 bool is_first_ref = false;
757 {
758 mutex_lock l(cache_mu_);
759 auto* registered_function =
760 gtl::FindPtrOrNull(registered_functions_, fdef.signature().name());
761 if (registered_function == nullptr) {
762 registered_function = new RegisteredFunction;
763 registered_function->cached_kernel_keys =
764 absl::make_unique<std::vector<Fprint128>>();
765 gtl::InsertOrUpdate(®istered_functions_, fdef.signature().name(),
766 registered_function);
767 } else {
768 // The function has been registered before. If the function is the same,
769 // then we take a Ref() otherwise we error out.
770 const FunctionDef* prev_fdef =
771 func_lib_def_.Find(fdef.signature().name());
772 if (prev_fdef == nullptr) {
773 return errors::Internal("Function: ", fdef.signature().name(),
774 " is in the cache but not in the library");
775 }
776 if (!FunctionDefsEqual(fdef, *prev_fdef)) {
777 return errors::InvalidArgument(
778 "Attempting to add a duplicate function with name: ",
779 fdef.signature().name(), " where the previous and current ",
780 "definitions differ. Previous definition: ",
781 prev_fdef->DebugString(),
782 " and current definition: ", fdef.DebugString());
783 }
784 registered_function->Ref();
785 }
786 is_first_ref = registered_function->RefCountIsOne();
787 }
788 if (is_first_ref) {
789 TF_RETURN_IF_ERROR(func_lib_def_.AddFunctionDef(fdef, stack_traces));
790 TF_RETURN_IF_ERROR(func_lib_def_.AddLibrary(library));
791 if (!add_to_local_only) {
792 return MaybeRegisterFunctionRemotely(fdef);
793 }
794 }
795 return Status::OK();
796 }
797
GetFunctionDef(const string & function_name)798 const FunctionDef* EagerContext::GetFunctionDef(const string& function_name) {
799 return func_lib_def_.Find(function_name);
800 }
801
ListFunctionNames()802 std::vector<string> EagerContext::ListFunctionNames() {
803 return func_lib_def_.ListFunctionNames();
804 }
805
RemoveFunction(const string & func)806 Status EagerContext::RemoveFunction(const string& func) {
807 bool is_last_ref = false;
808 {
809 mutex_lock l(cache_mu_);
810 auto* registered_function = gtl::FindPtrOrNull(registered_functions_, func);
811 if (registered_function == nullptr) {
812 return errors::InvalidArgument("Tried to remove non-existent function '",
813 func, "'.");
814 }
815 is_last_ref = registered_function->RefCountIsOne();
816 if (is_last_ref) {
817 for (auto& key : *registered_function->cached_kernel_keys) {
818 kernel_cache_.erase(key);
819 }
820 registered_functions_.erase(func);
821 }
822 registered_function->Unref();
823 }
824 if (is_last_ref) {
825 // TODO(fishx): Remove remote function as well.
826 return func_lib_def_.RemoveFunction(func);
827 }
828 return Status::OK();
829 }
830
SyncExecutors()831 Status EagerContext::SyncExecutors() {
832 StatusGroup sg;
833 // Synchronize on context default executor
834 sg.Update(default_executor_.WaitForAllPendingNodes());
835 default_executor_.ClearError();
836
837 // Synchronize thread local executors on client
838 std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
839 {
840 mutex_lock l(executor_map_mu_);
841 executors_copy = thread_local_executor_;
842 }
843 for (const auto& entry : executors_copy) {
844 sg.Update(entry.second->WaitForAllPendingNodes());
845 entry.second->ClearError();
846 }
847
848 #if !defined(IS_MOBILE_PLATFORM)
849 auto remote_contexts = GetRemoteContexts();
850 // Synchronize executors on remote workers
851 eager::EnqueueRequest request;
852 request.set_context_id(GetContextId());
853 request.add_queue()->mutable_sync_remote_executor_for_stream();
854 BlockingCounter counter(static_cast<int>(remote_contexts.size()));
855 std::vector<Status> statuses(remote_contexts.size());
856
857 for (int i = 0; i < remote_contexts.size(); i++) {
858 const auto& target = remote_contexts[i];
859 core::RefCountPtr<eager::EagerClient> eager_client;
860 TF_RETURN_IF_ERROR(GetClient(target, &eager_client));
861
862 eager::EnqueueResponse* response = new eager::EnqueueResponse();
863 eager_client->StreamingEnqueueAsync(
864 /*call_opts=*/nullptr, &request, response,
865 [response, target, &counter, &s = statuses[i]](const Status& status) {
866 s = status;
867 delete response;
868 counter.DecrementCount();
869 });
870 }
871 counter.Wait();
872 for (const Status& s : statuses) {
873 sg.Update(s);
874 }
875 #endif // !IS_MOBILE_PLATFORM
876 {
877 // Reset the global function rendezvous, which otherwise stores a failure
878 // state.
879 mutex_lock l(global_rendezvous_mu_);
880 global_rendezvous_for_functions_ =
881 core::RefCountPtr<Rendezvous>(CreateRendezvous(-1));
882 }
883 return sg.as_summary_status();
884 }
885
GetCachedKernel(Fprint128 cache_key)886 core::RefCountPtr<KernelAndDevice> EagerContext::GetCachedKernel(
887 Fprint128 cache_key) {
888 tf_shared_lock l(cache_mu_);
889 auto iter = kernel_cache_.find(cache_key);
890 if (iter == kernel_cache_.end()) {
891 return nullptr;
892 }
893 core::RefCountPtr<KernelAndDevice> new_ref(iter->second.get());
894 new_ref->Ref();
895 return new_ref;
896 }
897
AddKernelToCache(Fprint128 cache_key,KernelAndDevice * kernel)898 void EagerContext::AddKernelToCache(Fprint128 cache_key,
899 KernelAndDevice* kernel) {
900 mutex_lock ml(cache_mu_);
901 core::RefCountPtr<KernelAndDevice> new_ref(kernel);
902 new_ref->Ref();
903 kernel_cache_[cache_key] = std::move(new_ref);
904 auto* registered_function =
905 gtl::FindPtrOrNull(registered_functions_, kernel->name());
906 // The kernel name can be either a primitive op or a function.
907 if (registered_function != nullptr) {
908 registered_function->cached_kernel_keys->emplace_back(cache_key);
909 }
910 }
911
ShouldStoreGraphs()912 bool EagerContext::ShouldStoreGraphs() { return should_store_graphs_.load(); }
913
SetShouldStoreGraphs(bool value)914 void EagerContext::SetShouldStoreGraphs(bool value) {
915 mutex_lock ml(metadata_mu_);
916 should_store_graphs_.store(value);
917 if (!value) {
918 run_metadata_.reset(new RunMetadata);
919 }
920 }
921
FindDeviceFromName(const char * device_name,Device ** device) const922 Status EagerContext::FindDeviceFromName(const char* device_name,
923 Device** device) const {
924 *device = HostCPU();
925 if (device_name == nullptr || strlen(device_name) == 0) {
926 return Status::OK();
927 }
928
929 auto status = local_device_mgr()->LookupDevice(device_name, device);
930 if (status.ok()) {
931 return status;
932 }
933
934 if (remote_device_mgr() != nullptr) {
935 return remote_device_mgr()->LookupDevice(device_name, device);
936 }
937
938 return status;
939 }
940
FindCompositeDeviceFromName(StringPiece device_name,CompositeDevice ** device) const941 Status EagerContext::FindCompositeDeviceFromName(
942 StringPiece device_name, CompositeDevice** device) const {
943 tf_shared_lock l(composite_devices_mu_);
944 for (const auto& d : composite_devices_) {
945 if (d.second->name() == device_name) {
946 *device = d.second.get();
947 return Status::OK();
948 }
949 }
950 return errors::NotFound("Unknown composite device: ", device_name);
951 }
952
RegisterCustomDevice(const string & device_name,std::unique_ptr<CustomDevice> device)953 Status EagerContext::RegisterCustomDevice(
954 const string& device_name, std::unique_ptr<CustomDevice> device) {
955 Device* existing_physical_device = nullptr;
956 if (FindDeviceFromName(device_name.c_str(), &existing_physical_device).ok()) {
957 return errors::AlreadyExists(device_name,
958 " already registered as a physical device.");
959 }
960 return custom_device_op_handler_.RegisterCustomDevice(device_name,
961 std::move(device));
962 }
963
FindOrCreateCompositeDevice(const std::vector<string> & underlying_devices,const string & device_name,CompositeDevice ** composite_device)964 Status EagerContext::FindOrCreateCompositeDevice(
965 const std::vector<string>& underlying_devices, const string& device_name,
966 CompositeDevice** composite_device) {
967 if (!device_name.empty() &&
968 FindCompositeDeviceFromName(device_name, composite_device).ok()) {
969 return Status::OK();
970 }
971
972 const uint64 hash_key = Fingerprint64(absl::StrJoin(underlying_devices, ","));
973
974 mutex_lock l(composite_devices_mu_);
975 auto iter = composite_devices_.find(hash_key);
976 if (iter != composite_devices_.end()) {
977 *composite_device = iter->second.get();
978 return Status::OK();
979 }
980
981 Status s;
982 std::unique_ptr<CompositeDevice> device;
983 if (device_name.empty()) {
984 // Create a CompositeDevice on the same task as the host CPU, in order to
985 // trigger packed TensorHandle copy from a client to a remote worker.
986 device = CompositeDevice::MakeDevice(underlying_devices,
987 composite_devices_.size(),
988 HostCPU()->parsed_name(), &s);
989 } else {
990 device = CompositeDevice::MakeDevice(underlying_devices, device_name, &s);
991 }
992 TF_RETURN_IF_ERROR(s);
993 *composite_device = device.get();
994 pflr_->AddCompositeDevice(*composite_device);
995 composite_devices_.emplace(hash_key, std::move(device));
996 return Status::OK();
997 }
998
OnSameTask(const Device * first,const Device * second) const999 bool EagerContext::OnSameTask(const Device* first, const Device* second) const {
1000 if (first == nullptr) first = HostCPU();
1001 if (second == nullptr) second = HostCPU();
1002 return first->parsed_name().job == second->parsed_name().job &&
1003 first->parsed_name().replica == second->parsed_name().replica &&
1004 first->parsed_name().task == second->parsed_name().task;
1005 }
1006
1007 // Gets the CPU device on the task of device.
CPUDeviceOnTask(const Device * device,Device ** cpu_device) const1008 Status EagerContext::CPUDeviceOnTask(const Device* device,
1009 Device** cpu_device) const {
1010 string cpu_device_name;
1011 TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
1012 device->name(), &cpu_device_name));
1013
1014 return FindDeviceFromName(cpu_device_name.c_str(), cpu_device);
1015 }
1016
ClearResourceContainer(const string & name)1017 void EagerContext::ClearResourceContainer(const string& name) {
1018 // TODO(b/139809335): This does not properly clean up remote resources
1019 auto local_devices = local_device_mgr()->ListDevices();
1020 for (Device* device : local_devices) {
1021 // Only ignore container not found errors.
1022 device->resource_manager()->Cleanup(name).IgnoreError();
1023 }
1024 }
1025
1026 namespace {
GetTaskName(Device * d,string * task_name)1027 Status GetTaskName(Device* d, string* task_name) {
1028 string ignored;
1029 if (!DeviceNameUtils::SplitDeviceName(d->name(), task_name, &ignored)) {
1030 return errors::InvalidArgument("Unable to parse device name: ", d->name());
1031 }
1032
1033 return Status::OK();
1034 }
1035 } // namespace
1036
1037 #if !defined(IS_MOBILE_PLATFORM)
GetClient(Device * device,core::RefCountPtr<eager::EagerClient> * client)1038 Status EagerContext::GetClient(Device* device,
1039 core::RefCountPtr<eager::EagerClient>* client) {
1040 return GetClient(device->parsed_name(), client);
1041 }
1042
GetClient(const DeviceNameUtils::ParsedName & device_name,core::RefCountPtr<eager::EagerClient> * client)1043 Status EagerContext::GetClient(const DeviceNameUtils::ParsedName& device_name,
1044 core::RefCountPtr<eager::EagerClient>* client) {
1045 string device_task_name;
1046 if (!DeviceNameUtils::GetTaskName(device_name, &device_task_name)) {
1047 return errors::InvalidArgument(
1048 "Task is not fully specified in device name: ",
1049 DeviceNameUtils::ParsedNameToString(device_name));
1050 }
1051
1052 {
1053 tf_shared_lock l(remote_state_mu_);
1054 if (remote_eager_workers_ == nullptr) {
1055 return errors::Internal(
1056 "Haven't set up remote eager worker in this eager context yet.");
1057 }
1058 TF_RETURN_IF_ERROR(
1059 remote_eager_workers_->GetClient(device_task_name, client));
1060
1061 if (*client == nullptr) {
1062 return errors::InvalidArgument(
1063 "Unable to find eager client corresponding to device ",
1064 DeviceNameUtils::ParsedNameToString(device_name));
1065 }
1066 if (std::find(remote_contexts_.begin(), remote_contexts_.end(),
1067 device_task_name) == remote_contexts_.end()) {
1068 return errors::Internal("Unable to find a context for handle on task: ",
1069 device_task_name, ". This should not happen.");
1070 }
1071 }
1072
1073 return Status::OK();
1074 }
1075
GetClient(const string & remote_task,core::RefCountPtr<eager::EagerClient> * client)1076 Status EagerContext::GetClient(const string& remote_task,
1077 core::RefCountPtr<eager::EagerClient>* client) {
1078 {
1079 tf_shared_lock l(remote_state_mu_);
1080 if (remote_eager_workers_ == nullptr) {
1081 return errors::Internal(
1082 "Haven't set up remote eager worker in this eager context yet.");
1083 }
1084 TF_RETURN_IF_ERROR(remote_eager_workers_->GetClient(remote_task, client));
1085 }
1086
1087 if (*client == nullptr) {
1088 return errors::InvalidArgument(
1089 "Unable to find eager client corresponding to target ", remote_task);
1090 }
1091 return Status::OK();
1092 }
1093
GetContextId() const1094 uint64 EagerContext::GetContextId() const {
1095 tf_shared_lock l(remote_state_mu_);
1096 return context_id_;
1097 }
1098
GetContextViewId() const1099 uint64 EagerContext::GetContextViewId() const {
1100 tf_shared_lock l(remote_state_mu_);
1101 return context_view_id_;
1102 }
1103
IncrementContextViewId()1104 void EagerContext::IncrementContextViewId() {
1105 mutex_lock l(remote_state_mu_);
1106 context_view_id_ += 1;
1107 }
1108
EnableCollectiveOps(const ServerDef & server_def)1109 Status EagerContext::EnableCollectiveOps(const ServerDef& server_def) {
1110 return distributed_manager_->EnableCollectiveOps(server_def);
1111 }
1112
// Set collective ops related state in the context. Passing nullptr to
// `new_server` will reuse the existing GRPC server in context.
//
// Steps, in order:
//  1. Install `rpc_collective_executor_mgr` as the collective executor mgr.
//  2. If `device_mgr` differs from the current local device manager, switch
//     to it (retaining a previously owned manager in
//     `old_local_device_managers_`) and replace `rendezvous_`.
//  3. Refresh the host CPU device, the prioritized device type list, the
//     caches / thread executors and all executor error states.
//  4. Rebuild the process function library runtime, reusing the config of
//     the existing one when there is one.
//  5. Install `new_server` if one was provided.
//
// Always returns OK.
Status EagerContext::StoreCollectiveOpsServer(
    std::unique_ptr<ServerInterface> new_server, DeviceMgr* device_mgr,
    CollectiveExecutorMgrInterface* rpc_collective_executor_mgr) {
  collective_executor_mgr_.Reset(rpc_collective_executor_mgr);

  if (device_mgr != local_device_manager_.Get()) {
    // An owned device manager is moved into `old_local_device_managers_`
    // rather than destroyed here.
    if (local_device_manager_.Owned()) {
      old_local_device_managers_.push_back(
          std::move(local_device_manager_.owned_object));
    }
    local_device_manager_.Reset(device_mgr);
    // Drop our reference to the previous rendezvous before creating a new one.
    if (rendezvous_ != nullptr) rendezvous_->Unref();
    rendezvous_ = CreateRendezvous(-1);
  }
  host_cpu_device_ = local_device_manager_.Get()->HostCPU();

  InitPrioritizedDeviceTypeList();
  ClearCachesAndThreadExecutors();
  // Clear error state on the default executor and on every thread-local one.
  default_executor_.ClearError();
  {
    tensorflow::mutex_lock l(executor_map_mu_);
    for (auto& entry : thread_local_executor_) {
      entry.second->ClearError();
    }
  }

  // `pflr_` may be unset here; fall back to a null config and default
  // optimizer options in that case.
  const ConfigProto* config = pflr_ ? pflr_->config() : nullptr;
  ResetPFLR(
      local_device_manager_.Get(), env_, /*config=*/config,
      TF_GRAPH_DEF_VERSION, &func_lib_def_,
      /*optimizer_options=*/
      config ? config->graph_options().optimizer_options() : OptimizerOptions(),
      thread_pool_.get());

  if (new_server != nullptr) {
    // Memory leak! The previous server is intentionally released (not
    // destroyed) because servers cannot be shut down cleanly.
    if (server_ != nullptr) {
      LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
                      "Servers don't support clean shutdown.";
      server_.release();
    }
    server_ = std::move(new_server);
  }
  // Either a new server was installed above or one must already exist.
  DCHECK(server_ != nullptr);

  return Status::OK();
}
1162
SetRemoteDeviceFilters(const string & remote_worker,const std::vector<string> & device_filters)1163 Status EagerContext::SetRemoteDeviceFilters(
1164 const string& remote_worker, const std::vector<string>& device_filters) {
1165 // Get fully specified task name for remote worker
1166 string remote_worker_task_name;
1167 DeviceNameUtils::ParsedName pw;
1168 if (!DeviceNameUtils::ParseFullName(remote_worker, &pw)) {
1169 return tensorflow::errors::InvalidArgument(
1170 "Remote worker task name is invalid ", remote_worker);
1171 }
1172 // Force set a replica as the key in cluster device filters map. I.e., if the
1173 // remote worker is `/job:worker/task:0` it then becomes
1174 // `/job:worker/replica:0/task:0`.
1175 pw.has_replica = true;
1176 if (!DeviceNameUtils::GetTaskName(pw, &remote_worker_task_name)) {
1177 return tensorflow::errors::InvalidArgument(
1178 "Job name and task index must be specified for worker ", remote_worker);
1179 }
1180
1181 std::vector<DeviceNameUtils::ParsedName> parsed_filters;
1182 for (auto& filter : device_filters) {
1183 DeviceNameUtils::ParsedName parsed_filter;
1184 if (DeviceNameUtils::ParseFullName(filter, &parsed_filter)) {
1185 parsed_filters.emplace_back(parsed_filter);
1186 } else {
1187 return tensorflow::errors::InvalidArgument("Invalid filter: ", filter);
1188 }
1189 }
1190
1191 if (VLOG_IS_ON(1)) {
1192 VLOG(1) << "Setting device filters for " << remote_worker << ":";
1193 for (auto& filter : device_filters) {
1194 VLOG(1) << " " << filter;
1195 }
1196 }
1197 mutex_lock l(remote_state_mu_);
1198 cluster_device_filters_.emplace(remote_worker_task_name, parsed_filters);
1199 return Status::OK();
1200 }
1201
FilterDevicesForRemoteWorkers(const string & remote_worker,const protobuf::RepeatedPtrField<DeviceAttributes> & device_attrs,std::vector<bool> * filtered_device_mask)1202 void EagerContext::FilterDevicesForRemoteWorkers(
1203 const string& remote_worker,
1204 const protobuf::RepeatedPtrField<DeviceAttributes>& device_attrs,
1205 std::vector<bool>* filtered_device_mask) {
1206 filtered_device_mask->resize(device_attrs.size());
1207 std::fill(filtered_device_mask->begin(), filtered_device_mask->end(), false);
1208
1209 tf_shared_lock l(remote_state_mu_);
1210 auto it = cluster_device_filters_.find(remote_worker);
1211 // If no filters were specified, all devices should be visible to the worker
1212 if (it == cluster_device_filters_.end() || it->second.empty()) {
1213 std::fill(filtered_device_mask->begin(), filtered_device_mask->end(), true);
1214 return;
1215 }
1216
1217 const std::vector<DeviceNameUtils::ParsedName>& parsed_filters = it->second;
1218 DeviceNameUtils::ParsedName parsed_remote_worker;
1219 DeviceNameUtils::ParseFullName(remote_worker, &parsed_remote_worker);
1220 for (int i = 0; i < device_attrs.size(); i++) {
1221 DeviceNameUtils::ParsedName pn;
1222 DeviceNameUtils::ParseFullName(device_attrs[i].name(), &pn);
1223 if (DeviceNameUtils::IsSameAddressSpace(parsed_remote_worker, pn)) {
1224 // If this device is on the remote worker itself, it should be visible
1225 // regardless of device filters
1226 filtered_device_mask->at(i) = true;
1227 continue;
1228 }
1229 for (const auto& pf : parsed_filters) {
1230 if ((!pn.has_job || !pf.has_job || pn.job == pf.job) &&
1231 (!pn.has_replica || !pf.has_replica || pn.replica == pf.replica) &&
1232 (!pn.has_task || !pf.has_task || pn.task == pf.task) &&
1233 (!pn.has_type || !pf.has_type || pn.type == pf.type) &&
1234 (!pn.has_id || !pf.has_id || pn.id == pf.id)) {
1235 // Found a match, make it visible, stop processing more device filters
1236 filtered_device_mask->at(i) = true;
1237 break;
1238 }
1239 }
1240 }
1241 }
1242
SetWorkerEnv(WorkerEnv * worker_env,std::shared_ptr<WorkerSession> worker_session)1243 void EagerContext::SetWorkerEnv(WorkerEnv* worker_env,
1244 std::shared_ptr<WorkerSession> worker_session) {
1245 worker_env_ = worker_env;
1246 worker_session_ = worker_session;
1247 }
1248
InitializeRemoteMaster(std::unique_ptr<ServerInterface> server,WorkerEnv * worker_env,std::shared_ptr<WorkerSession> worker_session,std::unique_ptr<eager::EagerClientCache> remote_eager_workers,std::unique_ptr<DynamicDeviceMgr> remote_device_manager,const std::vector<string> & remote_contexts,uint64 context_id,Rendezvous * r,DeviceMgr * local_device_mgr,int keep_alive_secs,DistributedFunctionLibraryRuntime * cluster_flr,std::unique_ptr<eager::RemoteMgr,std::function<void (eager::RemoteMgr *)>> remote_mgr)1249 Status EagerContext::InitializeRemoteMaster(
1250 std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
1251 std::shared_ptr<WorkerSession> worker_session,
1252 std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
1253 std::unique_ptr<DynamicDeviceMgr> remote_device_manager,
1254 const std::vector<string>& remote_contexts, uint64 context_id,
1255 Rendezvous* r, DeviceMgr* local_device_mgr, int keep_alive_secs,
1256 DistributedFunctionLibraryRuntime* cluster_flr,
1257 std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
1258 remote_mgr) {
1259 if (context_id == kInvalidContextId) {
1260 return errors::InvalidArgument(
1261 "Failed to initialize remote for master context due to invalid ",
1262 "context id");
1263 }
1264
1265 if (!IsRemoteContextsEmpty()) {
1266 CloseAndClearAllRemoteContexts();
1267 }
1268 {
1269 mutex_lock l(remote_state_mu_);
1270 remote_contexts_ = remote_contexts;
1271 }
1272
1273 return SetMasterContextState(
1274 std::move(server), worker_env, std::move(worker_session),
1275 std::move(remote_eager_workers), std::move(remote_device_manager),
1276 context_id, 0, r, local_device_mgr, keep_alive_secs, cluster_flr,
1277 std::move(remote_mgr));
1278 }
1279
// Updates the set of remote workers of a master context.
//
// `context_id` must equal the current context id, otherwise InvalidArgument
// is returned and nothing is changed. The contexts in
// `remove_remote_contexts` are closed and dropped from `remote_contexts_`,
// those in `add_remote_contexts` are appended, the context view id is
// incremented and `remote_eager_workers` replaces the current client cache.
// Finally, existing functions are registered on the newly added workers; the
// status of that registration is returned.
Status EagerContext::UpdateRemoteMaster(
    uint64 context_id,
    std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
    const std::vector<string>& add_remote_contexts,
    const std::vector<string>& remove_remote_contexts) {
  {
    // Validate the request against the current context id.
    tf_shared_lock l(remote_state_mu_);
    if (context_id != context_id_) {
      return errors::InvalidArgument(
          "Failed to update remote master context due to invalid context id. ",
          "Request id = ", context_id, " but current id = ", context_id_);
    }
  }

  if (!remove_remote_contexts.empty()) {
    // N.B. remove_remote_contexts include both removed and replaced workers.
    // In the case where a worker is replaced by one that resolves to the same
    // `hostname:port`, it is safe to close context with the current view id,
    // since the newly created context on the remote worker will be holding
    // a larger view id and ignores this request.
    CloseRemoteContexts(remove_remote_contexts, context_id, GetContextViewId());
    mutex_lock l(remote_state_mu_);
    // Erase every occurrence of each removed context (erase-remove).
    for (const string& remote_context : remove_remote_contexts) {
      remote_contexts_.erase(
          std::remove(remote_contexts_.begin(), remote_contexts_.end(),
                      remote_context),
          remote_contexts_.end());
    }
  }
  if (!add_remote_contexts.empty()) {
    mutex_lock l(remote_state_mu_);
    remote_contexts_.insert(std::end(remote_contexts_),
                            std::begin(add_remote_contexts),
                            std::end(add_remote_contexts));
  }

  {
    mutex_lock l(remote_state_mu_);
    context_view_id_++;

    remote_eager_workers_ = std::move(remote_eager_workers);
    pflr_->InitializeDeviceAndFlr();
    InitPrioritizedDeviceTypeList();

    default_executor_.ClearError();
    {
      // Note: this `l` shadows the outer lock; `executor_map_mu_` is acquired
      // while `remote_state_mu_` is still held.
      tensorflow::mutex_lock l(executor_map_mu_);
      for (auto& entry : thread_local_executor_) {
        entry.second->ClearError();
      }
    }
  }

  // Register existing functions to the newly added remote workers. Note that
  // this should happen only after updating `remote_contexts_` because new
  // functions might be registered while we update the context. When that
  // happens, this ordering ensures that `MaybeRegisterFunctionRemotely` will
  // register the new functions on all remote workers (including the newly added
  // ones), and `RegisterExistingFunctionsOnRemoteWorkers` will take care of
  // registering existing functions, where duplicate registrations will be
  // ignored by the remote workers.
  TF_RETURN_IF_ERROR(
      RegisterExistingFunctionsOnRemoteWorkers(add_remote_contexts));
  return Status::OK();
}
1345
// Set distributed execution related state in the master context.
//
// Holds `remote_state_mu_` for the whole call. Installs the given ids,
// server, worker environment/session, client cache, remote device manager,
// cluster FLR, rendezvous and remote manager; switches the local device
// manager if `local_device_mgr` differs from the current one; clears caches
// and executor errors; rebuilds the process function library runtime; and
// starts the keep-alive thread if it is not already running.
//
// Always returns OK.
Status EagerContext::SetMasterContextState(
    std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
    std::shared_ptr<WorkerSession> worker_session,
    std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
    std::unique_ptr<DynamicDeviceMgr> remote_device_manager, uint64 context_id,
    uint64 context_view_id, Rendezvous* r, DeviceMgr* local_device_mgr,
    int keep_alive_secs, DistributedFunctionLibraryRuntime* cluster_flr,
    std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
        remote_mgr) {
  mutex_lock l(remote_state_mu_);
  is_master_ = true;
  context_id_ = context_id;
  context_view_id_ = context_view_id;

  // Defaults to true when the environment variable is not set.
  use_send_tensor_rpc_ =
      ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", true);

  if (local_device_mgr != local_device_manager_.Get()) {
    // An owned device manager is moved into `old_local_device_managers_`
    // rather than destroyed here.
    if (local_device_manager_.Owned()) {
      old_local_device_managers_.push_back(
          std::move(local_device_manager_.owned_object));
    }
    local_device_manager_.Reset(local_device_mgr);
  }
  host_cpu_device_ = local_device_manager_.Get()->HostCPU();

  // Drop our reference to the previous rendezvous and adopt `r`.
  if (rendezvous_ != nullptr) rendezvous_->Unref();
  rendezvous_ = r;

  // Memory leak! The previous server is intentionally released (not
  // destroyed) because servers cannot be shut down cleanly.
  if (server_ != nullptr) {
    LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
                    "Servers don't support clean shutdown.";
    server_.release();
  }
  server_ = std::move(server);

  remote_mgr_ = std::move(remote_mgr);
  worker_env_ = worker_env;
  worker_session_ = std::move(worker_session);
  remote_eager_workers_ = std::move(remote_eager_workers);

  remote_device_manager_.Reset(std::move(remote_device_manager));
  ResetClusterFLR(cluster_flr);

  InitPrioritizedDeviceTypeList();

  ClearCachesAndThreadExecutors();
  // Clear error state on the default executor and on every thread-local one.
  default_executor_.ClearError();
  {
    tensorflow::mutex_lock l(executor_map_mu_);
    for (auto& entry : thread_local_executor_) {
      entry.second->ClearError();
    }
  }
  // NOTE(review): unlike StoreCollectiveOpsServer, `pflr_` and its config are
  // dereferenced here without a null check -- presumably both are always set
  // by the time this runs; confirm.
  const auto* config = pflr_->config();
  ResetPFLR(local_device_manager_.Get(), env_, config, TF_GRAPH_DEF_VERSION,
            &func_lib_def_, config->graph_options().optimizer_options(),
            thread_pool_.get(), cluster_flr_.Get());

  keep_alive_secs_ = keep_alive_secs;
  // Sleep for half the keep-alive period between pings, but at least 1 second.
  sleep_for_secs_ = std::max(1, keep_alive_secs_ / 2);
  // Only schedule a single closure.
  if (keep_alive_thread_ == nullptr) {
    keep_alive_thread_.reset(
        env_->StartThread({}, "EagerKeepAliveThread", [this]() {
          // Loop until `shutting_down_` is observed: sleep, then send one
          // keep-alive request to every remote context.
          while (true) {
            {
              {
                mutex_lock l(keep_alive_thread_shutdown_mu_);

                if (shutting_down_) {
                  return;
                }

                // Wait up to `sleep_for_secs_` seconds on the condition
                // variable.
                keep_alive_thread_cv_.wait_for(
                    l, std::chrono::seconds(sleep_for_secs_));

                if (shutting_down_) {
                  return;
                }
              }
              {
                mutex_lock l(remote_state_mu_);
                // Keep-alive pings are skipped when the period is not
                // positive.
                if (keep_alive_secs_ > 0) {
                  {
                    for (const auto& worker : remote_contexts_) {
                      core::RefCountPtr<eager::EagerClient> client;
                      Status s =
                          remote_eager_workers_->GetClient(worker, &client);

                      if (!s.ok()) {
                        LOG(WARNING) << "Keep-alive thread was unable to find "
                                        "a client for target "
                                     << worker << ". Got error: " << s;
                        continue;
                      }

                      // The request/response are heap allocated and freed by
                      // the completion callback below.
                      eager::KeepAliveRequest* request =
                          new eager::KeepAliveRequest;
                      eager::KeepAliveResponse* response =
                          new eager::KeepAliveResponse;

                      request->set_context_id(context_id_);
                      // The result of the keep-alive call is ignored.
                      client->KeepAliveAsync(
                          request, response,
                          [request, response](const Status& s) {
                            delete request;
                            delete response;
                          });
                    }
                  }
                }
              }
            }
          }
        }));
  }
  return Status::OK();
}
1467
// Sets this context up as a remote worker context.
//
// Returns InvalidArgument for `kInvalidContextId`, and FailedPrecondition if
// the context already carries master state (an owned remote device manager,
// a server, or a keep-alive thread). Otherwise installs the given remote
// state under `remote_state_mu_`, rebuilds the process function library
// runtime, clears caches and executor errors, stores `resource_deallocator`,
// and returns OK.
Status EagerContext::InitializeRemoteWorker(
    std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
    DynamicDeviceMgr* remote_device_mgr,
    const std::vector<string>& remote_contexts, uint64 context_id,
    uint64 context_view_id,
    std::function<Rendezvous*(const int64_t)> rendezvous_creator,
    DistributedFunctionLibraryRuntime* cluster_flr,
    std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
        remote_mgr,
    std::function<void()> resource_deallocator) {
  if (context_id == kInvalidContextId) {
    return errors::InvalidArgument(
        "Failed to initialize remote for worker context due to invalid ",
        "context id");
  }
  mutex_lock l(remote_state_mu_);

  // Any of these being set indicates the context was initialized as a master.
  if (remote_device_manager_.Owned() || server_ != nullptr ||
      keep_alive_thread_ != nullptr) {
    return errors::FailedPrecondition(
        "EagerContext::InitializeRemoteWorker Failed. ",
        "Already initialized remote as a master context.");
  }
  is_master_ = false;

  remote_contexts_ = remote_contexts;
  context_id_ = context_id;
  context_view_id_ = context_view_id;

  rendezvous_creator_ = std::move(rendezvous_creator);
  remote_eager_workers_ = std::move(remote_eager_workers);
  remote_mgr_ = std::move(remote_mgr);
  ResetClusterFLR(cluster_flr);

  // Raw-pointer Reset: the remote device manager is stored without being
  // passed in as an owning pointer.
  remote_device_manager_.Reset(remote_device_mgr);

  // NOTE(review): `pflr_` and its config are dereferenced without a null
  // check -- presumably both are always set by the time this runs; confirm.
  const auto* config = pflr_->config();
  ResetPFLR(local_device_manager_.Get(), env_, config, TF_GRAPH_DEF_VERSION,
            &func_lib_def_, config->graph_options().optimizer_options(),
            thread_pool_.get(), cluster_flr_.Get());
  InitPrioritizedDeviceTypeList();

  ClearCachesAndThreadExecutors();
  // Clear error state on the default executor and on every thread-local one.
  default_executor_.ClearError();
  {
    tensorflow::mutex_lock l(executor_map_mu_);
    for (auto& entry : thread_local_executor_) {
      entry.second->ClearError();
    }
  }

  resource_deallocator_ = std::move(resource_deallocator);

  return Status::OK();
}
1523
UpdateRemoteWorker(std::unique_ptr<eager::EagerClientCache> remote_eager_workers,const std::vector<string> & remote_contexts,uint64 context_id)1524 Status EagerContext::UpdateRemoteWorker(
1525 std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
1526 const std::vector<string>& remote_contexts, uint64 context_id) {
1527 {
1528 mutex_lock l(remote_state_mu_);
1529 if (context_id != context_id_) {
1530 return errors::InvalidArgument(
1531 "Failed to update remote for worker context due to invalid ",
1532 "context id. Request id = ", context_id,
1533 " but current id = ", context_id_);
1534 }
1535 context_view_id_++;
1536
1537 remote_contexts_ = remote_contexts;
1538 remote_eager_workers_ = std::move(remote_eager_workers);
1539 InitPrioritizedDeviceTypeList();
1540 pflr_->InitializeDeviceAndFlr();
1541 }
1542
1543 // No need to update remote_device_manager_ since it's not owned for remote
1544 // worker context (owned by the corresponding worker session).
1545 if (remote_device_manager_.Owned()) {
1546 return errors::FailedPrecondition(
1547 "EagerContext::UpdateRemoteWorker failed because the context was "
1548 "initialized as a master context.");
1549 }
1550
1551 ClearCachesAndThreadExecutors();
1552 default_executor_.ClearError();
1553 {
1554 tensorflow::mutex_lock l(executor_map_mu_);
1555 for (auto& entry : thread_local_executor_) {
1556 entry.second->ClearError();
1557 }
1558 }
1559 return Status::OK();
1560 }
1561 #endif // !IS_MOBILE_PLATFORM
1562
1563 } // namespace tensorflow
1564