1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/eager/context.h"
17
18 #include <memory>
19 #include <vector>
20
21 // clang-format off
22 // Required for IS_MOBILE_PLATFORM
23 #include "tensorflow/c/eager/immediate_execution_context.h"
24 #include "tensorflow/core/common_runtime/function.h"
25 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
26 #include "tensorflow/core/lib/core/refcount.h"
27 #include "tensorflow/core/lib/gtl/map_util.h"
28 #include "tensorflow/core/nccl/collective_communicator.h"
29 #include "tensorflow/core/platform/errors.h"
30 #include "tensorflow/core/platform/mutex.h"
31 #include "tensorflow/core/platform/platform.h"
32 // clang-format on
33
34 #include "tensorflow/c/tf_tensor.h"
35 #include "tensorflow/c/tf_tensor_internal.h"
36 #include "tensorflow/core/common_runtime/collective_executor_mgr.h"
37 #include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
38 #include "tensorflow/core/common_runtime/colocation_graph.h"
39 #include "tensorflow/core/common_runtime/device_resolver_local.h"
40 #include "tensorflow/core/common_runtime/device_set.h"
41 #include "tensorflow/core/common_runtime/process_util.h"
42 #include "tensorflow/core/framework/graph_def_util.h"
43 #include "tensorflow/core/framework/function.h"
44 #include "tensorflow/core/lib/core/errors.h"
45 #include "tensorflow/core/protobuf/config.pb.h"
46 #include "tensorflow/core/public/version.h"
47 #include "tensorflow/core/util/device_name_utils.h"
48 #if !defined(IS_MOBILE_PLATFORM)
49 #include "tensorflow/core/distributed_runtime/cluster_function_library_runtime.h"
50 #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
51 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
52 #include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
53 #endif // !IS_MOBILE_PLATFORM
54 #include "tensorflow/core/framework/resource_mgr.h"
55 #include "tensorflow/core/lib/core/blocking_counter.h"
56 #include "tensorflow/core/lib/monitoring/gauge.h"
57 #include "tensorflow/core/util/env_var.h"
58
59 namespace tensorflow {
60 namespace {
61
ReadBoolFromEnvVar(StringPiece env_var_name,bool default_val)62 bool ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val) {
63 bool val;
64 if (tensorflow::ReadBoolFromEnvVar(env_var_name, default_val, &val).ok()) {
65 return val;
66 }
67 return default_val;
68 }
69
70 auto* eager_context_created =
71 monitoring::Gauge<bool, 0>::New("/tensorflow/core/eager_context_created",
72 "True if an eager context was created.");
73
74 } // namespace
75
EagerContext(const SessionOptions & opts,ContextDevicePlacementPolicy default_device_placement_policy,bool async,const DeviceMgr * device_mgr,bool device_mgr_owned,Rendezvous * rendezvous,DistributedFunctionLibraryRuntime * cluster_flr)76 EagerContext::EagerContext(
77 const SessionOptions& opts,
78 ContextDevicePlacementPolicy default_device_placement_policy, bool async,
79 const DeviceMgr* device_mgr, bool device_mgr_owned, Rendezvous* rendezvous,
80 DistributedFunctionLibraryRuntime* cluster_flr)
81 : ImmediateExecutionContext(kEager),
82 opts_(opts),
83 default_device_placement_policy_(default_device_placement_policy),
84 local_device_manager_(device_mgr, device_mgr_owned),
85 host_cpu_device_(device_mgr->HostCPU()),
86 rendezvous_(rendezvous),
87 thread_pool_(NewThreadPoolFromSessionOptions(opts)),
88 cluster_flr_(cluster_flr),
89 log_device_placement_(opts.config.log_device_placement()),
90 allow_soft_placement_(opts.config.allow_soft_placement()),
91 num_active_steps_(0),
92 step_container_(std::make_unique<ScopedStepContainer>(
93 0, [this](const string& name) { ClearResourceContainer(name); })),
94 default_executor_(async),
95 log_memory_(LogMemory::IsEnabled()),
96 env_(opts.env),
97 use_send_tensor_rpc_(false),
98 pin_small_ops_to_cpu_(ReadBoolFromEnvVar(
99 "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING", false)) {
100 ResetPFLR(device_mgr, opts.env, &opts.config, TF_GRAPH_DEF_VERSION,
101 &func_lib_def_, opts.config.graph_options().optimizer_options(),
102 thread_pool_.get(), cluster_flr);
103 // Starts exporting metrics through a platform-specific monitoring API (if
104 // provided). For builds using "tensorflow/core/platform/default", this is
105 // currently a no-op.
106 eager_context_created->GetCell()->Set(true);
107 InitPrioritizedDeviceTypeList();
__anonceee8e1b0302(std::function<void()> closure) 108 runner_ = [this](std::function<void()> closure) {
109 this->thread_pool_->Schedule(std::move(closure));
110 };
111
112 run_metadata_ = std::make_unique<RunMetadata>();
113
114 #if !defined(IS_MOBILE_PLATFORM)
115 context_id_ = kInvalidContextId;
116 context_view_id_ = 0;
117 #endif // IS_MOBILE_PLATFORM
118
119 std::unique_ptr<DeviceResolverInterface> drl(
120 new DeviceResolverLocal(local_device_mgr()));
121 std::unique_ptr<ParamResolverInterface> cprl(new CollectiveParamResolverLocal(
122 opts.config, local_device_mgr(), drl.get(),
123 "/job:localhost/replica:0/task:0"));
124 collective_executor_mgr_.Reset(
125 new CollectiveExecutorMgr(opts.config, local_device_mgr(), std::move(drl),
126 std::move(cprl), MaybeCreateNcclCommunicator()),
127 /*owned=*/true);
128 }
129
CreateInt64Scalar(int64 value)130 AbstractTensorInterface* EagerContext::CreateInt64Scalar(int64 value) {
131 return new TensorInterface(Tensor(value));
132 }
133
CreateUint64Scalar(uint64 value)134 AbstractTensorInterface* EagerContext::CreateUint64Scalar(uint64 value) {
135 return new TensorInterface(Tensor(value));
136 }
137
CreateInt32Scalar(int32 value)138 AbstractTensorInterface* EagerContext::CreateInt32Scalar(int32 value) {
139 return new TensorInterface(Tensor(value));
140 }
141
CreateFloatScalar(float value)142 AbstractTensorInterface* EagerContext::CreateFloatScalar(float value) {
143 return new TensorInterface(Tensor(value));
144 }
145
CreateDoubleScalar(double value)146 AbstractTensorInterface* EagerContext::CreateDoubleScalar(double value) {
147 return new TensorInterface(Tensor(value));
148 }
149
CreateHalfScalar(Eigen::half value)150 AbstractTensorInterface* EagerContext::CreateHalfScalar(Eigen::half value) {
151 return new TensorInterface(Tensor(value));
152 }
153
CreateStringScalar(tstring value)154 AbstractTensorInterface* EagerContext::CreateStringScalar(tstring value) {
155 return new TensorInterface(Tensor(value));
156 }
157
CreateComplex128Scalar(complex128 value)158 AbstractTensorInterface* EagerContext::CreateComplex128Scalar(
159 complex128 value) {
160 return new TensorInterface(Tensor(value));
161 }
162
CreateBoolScalar(bool value)163 AbstractTensorInterface* EagerContext::CreateBoolScalar(bool value) {
164 return new TensorInterface(Tensor(value));
165 }
166
CreateTensor(DataType dtype,absl::Span<const int64> dim_sizes)167 AbstractTensorInterface* EagerContext::CreateTensor(
168 DataType dtype, absl::Span<const int64> dim_sizes) {
169 return new TensorInterface(Tensor(dtype, TensorShape(dim_sizes)));
170 }
171
CreateTensor(DataType dtype,const int64_t * dims,int num_dims,void * data,size_t len,MemoryReleaser memory_releaser,void * memory_releaser_arg)172 AbstractTensorInterface* EagerContext::CreateTensor(
173 DataType dtype, const int64_t* dims, int num_dims, void* data, size_t len,
174 MemoryReleaser memory_releaser, void* memory_releaser_arg) {
175 TF_Tensor* tensor_wrapper =
176 TF_NewTensor(static_cast<TF_DataType>(dtype), dims, num_dims, data, len,
177 memory_releaser, memory_releaser_arg);
178
179 AbstractTensorInterface* result = nullptr;
180 std::swap(result, tensor_wrapper->tensor);
181 TF_DeleteTensor(tensor_wrapper);
182 return result;
183 }
184
ResetPFLR(const DeviceMgr * device_mgr,Env * env,const ConfigProto * config,int graph_def_version,const FunctionLibraryDefinition * lib_def,const OptimizerOptions & optimizer_options,thread::ThreadPool * thread_pool,DistributedFunctionLibraryRuntime * cluster_flr)185 void EagerContext::ResetPFLR(const DeviceMgr* device_mgr, Env* env,
186 const ConfigProto* config, int graph_def_version,
187 const FunctionLibraryDefinition* lib_def,
188 const OptimizerOptions& optimizer_options,
189 thread::ThreadPool* thread_pool,
190 DistributedFunctionLibraryRuntime* cluster_flr) {
191 Rendezvous::Factory rendezvous_factory{
192 [this](const int64 step_id, const DeviceMgr*, Rendezvous** r) {
193 *r = CreateRendezvous(step_id);
194 return Status::OK();
195 }};
196 pflr_.reset(new ProcessFunctionLibraryRuntime(
197 device_mgr, env, config, graph_def_version, lib_def, optimizer_options,
198 thread_pool, cluster_flr,
199 /*session_metadata=*/nullptr, std::move(rendezvous_factory)));
200 }
201
InitPrioritizedDeviceTypeList()202 void EagerContext::InitPrioritizedDeviceTypeList() {
203 DeviceSet ds;
204 for (Device* d : local_device_mgr()->ListDevices()) {
205 ds.AddDevice(d);
206 }
207 auto remote_device_manager = remote_device_mgr();
208 if (remote_device_manager != nullptr) {
209 for (Device* d : remote_device_manager->ListDevices()) {
210 ds.AddDevice(d);
211 }
212 }
213 mutex_lock l(device_type_list_mu_);
214 prioritized_device_type_list_ =
215 std::make_shared<std::vector<DeviceType>>(ds.PrioritizedDeviceTypeList());
216 }
217
218 namespace {
219 // Using absl::StrJoin with lambda does not work in tf-lite builds.
220 // TODO(b/148160441): Replace with absl::StrJoin once DeviceBase has operator<<.
DevicesToString(const PrioritizedDeviceVector & devices)221 std::vector<string> DevicesToString(const PrioritizedDeviceVector& devices) {
222 std::vector<string> v;
223 v.reserve(devices.size());
224 for (const auto& p : devices) {
225 v.push_back(p.first->name());
226 }
227 return v;
228 }
229
DeviceTypesToString(const PrioritizedDeviceTypeVector & types)230 std::vector<string> DeviceTypesToString(
231 const PrioritizedDeviceTypeVector& types) {
232 std::vector<string> v;
233 v.reserve(types.size());
234 for (const auto& p : types) {
235 v.push_back(p.first.type_string());
236 }
237 return v;
238 }
239
240 // Selects the "best" device that both exists and is supported.
241 //
242 // The `existing` argument specifies the available devices in the system, in
243 // priority order. The `supported` argument specifies the supported device types
244 // and their priorities, lower index types having higher priority.
245 // Currently the type priority defined by the `supported` parameter takes
246 // precedence over system device priorities from `existing`.
247 //
248 // TODO(b/148213212): Allow setting default device in eager context.
SelectBestMatchingDevice(const DeviceNameUtils::ParsedName & pattern,const PrioritizedDeviceVector & existing,const PrioritizedDeviceTypeVector & supported)249 Device* SelectBestMatchingDevice(const DeviceNameUtils::ParsedName& pattern,
250 const PrioritizedDeviceVector& existing,
251 const PrioritizedDeviceTypeVector& supported) {
252 for (const std::pair<DeviceType, int32>& prioritized_type : supported) {
253 for (const std::pair<Device*, int32>& prioritized_device : existing) {
254 Device* dev = prioritized_device.first;
255 if (DeviceType(dev->attributes().device_type()) ==
256 prioritized_type.first &&
257 DeviceNameUtils::IsCompleteSpecification(pattern,
258 dev->parsed_name())) {
259 return dev;
260 }
261 }
262 }
263 return nullptr;
264 }
265
266 } // namespace
267
SelectDevice(DeviceNameUtils::ParsedName preferred,const NodeDef & ndef,Device ** out) const268 Status EagerContext::SelectDevice(DeviceNameUtils::ParsedName preferred,
269 const NodeDef& ndef, Device** out) const {
270 DCHECK(out != nullptr);
271
272 PrioritizedDeviceTypeVector supported_devs;
273 auto device_type_list = prioritized_device_type_list();
274 TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
275 *device_type_list, ndef, &supported_devs, &HostCPU()->parsed_name()));
276 if (supported_devs.empty()) {
277 return errors::NotFound("Could not find device for node: ",
278 errors::FormatNodeNameForError(ndef.name()), " = ",
279 ndef.op(), "[", SummarizeAttrs(ndef), "]",
280 "\nAll kernels registered for op ", ndef.op(),
281 ":\n", KernelsRegisteredForOp(ndef.op()));
282 }
283
284 // Select the first matching registered device from the supported device
285 // list. If nothing matches and soft placement is enabled, pick a suitable
286 // device from the available ones.
287 const auto pflr_device_set = pflr()->device_set();
288 const PrioritizedDeviceVector& existing =
289 pflr_device_set->prioritized_devices();
290 *out = SelectBestMatchingDevice(preferred, existing, supported_devs);
291 if (*out != nullptr) {
292 return Status::OK();
293 }
294
295 if (AllowSoftPlacement()) {
296 DeviceNameUtils::ParsedName soft_device_name = preferred;
297 soft_device_name.type.clear();
298 soft_device_name.has_type = false;
299 soft_device_name.has_id = false;
300 // TODO(b/148213746): Soft placement logic picks up another task if the
301 // requested does not exist.
302 *out = SelectBestMatchingDevice(soft_device_name, existing, supported_devs);
303 if (*out != nullptr) {
304 return Status::OK();
305 }
306 }
307
308 if (DeviceNameUtils::HasSomeDetails(preferred)) {
309 return errors::InvalidArgument(
310 "Could not satisfy device specification '", preferred,
311 "'. enable_soft_placement=", AllowSoftPlacement(),
312 ". Supported device types [",
313 absl::StrJoin(DeviceTypesToString(supported_devs), ", "),
314 "]. All available devices [",
315 absl::StrJoin(DevicesToString(existing), ", "), "].");
316 }
317 return errors::InvalidArgument(
318 "No supported device found in available devices [",
319 absl::StrJoin(DevicesToString(existing), ", "),
320 "]. enable_soft_placement=", AllowSoftPlacement(),
321 ". Supported devices types [",
322 absl::StrJoin(DeviceTypesToString(supported_devs), ", "), "].");
323 }
324
ResetClusterFLR(DistributedFunctionLibraryRuntime * cluster_flr)325 void EagerContext::ResetClusterFLR(
326 DistributedFunctionLibraryRuntime* cluster_flr) {
327 cluster_flr_.Reset(cluster_flr, /*owned=*/true);
328 }
329
Executor()330 EagerExecutor& EagerContext::Executor() {
331 tf_shared_lock l(executor_map_mu_);
332 return *gtl::FindWithDefault(thread_local_executor_,
333 std::this_thread::get_id(), &default_executor_);
334 }
335
SetExecutorForThread(EagerExecutor * executor)336 void EagerContext::SetExecutorForThread(EagerExecutor* executor) {
337 tensorflow::mutex_lock l(executor_map_mu_);
338 if (executor == &default_executor_) {
339 thread_local_executor_.erase(std::this_thread::get_id());
340 } else {
341 auto thread_id = std::this_thread::get_id();
342 thread_local_executor_[thread_id] = executor;
343 auto& executors_with_cleanups = has_cleanup_[thread_id];
344 if (executors_with_cleanups.find(executor) ==
345 executors_with_cleanups.end()) {
346 executors_with_cleanups.insert(executor);
347 // If the executor is deleted before this context, we need to remove it
348 // from the map to avoid attempting to sync it in our destructor.
349 std::function<void()> cleanup([this, thread_id, executor]() {
350 {
351 tensorflow::mutex_lock l(executor_map_mu_);
352 auto existing = thread_local_executor_.find(thread_id);
353 if (existing != thread_local_executor_.end() &&
354 existing->second == executor) {
355 thread_local_executor_.erase(thread_id);
356 }
357 has_cleanup_[thread_id].erase(executor);
358 }
359 });
360 executor->AddCleanup(reinterpret_cast<intptr_t>(this),
361 std::move(cleanup));
362 }
363 }
364 }
365
ClearCachesAndThreadExecutors()366 void EagerContext::ClearCachesAndThreadExecutors() {
367 std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
368 {
369 mutex_lock l(executor_map_mu_);
370 executors_copy = thread_local_executor_;
371 }
372 for (const auto& entry : executors_copy) {
373 entry.second->WaitForAllPendingNodes().IgnoreError();
374 }
375 ClearCachesAndDefaultExecutor();
376 }
377
ClearCachesAndDefaultExecutor()378 void EagerContext::ClearCachesAndDefaultExecutor() {
379 // The executor stores pointers to kernels, so we need to make sure that no
380 // async eager ops are still executing. We lock the cache during this time
381 // as well.
382 mutex_lock ml(cache_mu_);
383 default_executor_.WaitForAllPendingNodes().IgnoreError();
384 kernel_cache_.clear();
385 for (auto& entry : registered_functions_) {
386 entry.second->cached_kernel_keys->clear();
387 }
388 {
389 mutex_lock ml(metadata_mu_);
390 step_container_.reset(new ScopedStepContainer(
391 0, [this](const string& name) { ClearResourceContainer(name); }));
392 }
393 }
394
SetThreadLocalDevicePlacementPolicy(ContextDevicePlacementPolicy policy)395 void EagerContext::SetThreadLocalDevicePlacementPolicy(
396 ContextDevicePlacementPolicy policy) {
397 mutex_lock ml(policy_map_mu_);
398 device_placement_policy_[std::this_thread::get_id()] = policy;
399 }
400
GetDevicePlacementPolicy() const401 ContextDevicePlacementPolicy EagerContext::GetDevicePlacementPolicy() const {
402 tf_shared_lock l(policy_map_mu_);
403 auto policy_map_it =
404 device_placement_policy_.find(std::this_thread::get_id());
405 if (policy_map_it != device_placement_policy_.end()) {
406 return policy_map_it->second;
407 }
408 return default_device_placement_policy_;
409 }
410
411 #if !defined(IS_MOBILE_PLATFORM)
GetRemoteContexts()412 std::vector<string> EagerContext::GetRemoteContexts() {
413 tf_shared_lock l(remote_state_mu_);
414 return remote_contexts_;
415 }
416
IsRemoteContextsEmpty()417 bool EagerContext::IsRemoteContextsEmpty() {
418 tf_shared_lock l(remote_state_mu_);
419 return remote_contexts_.empty();
420 }
421
CloseAndClearAllRemoteContexts()422 void EagerContext::CloseAndClearAllRemoteContexts() {
423 uint64 context_id;
424 uint64 context_view_id;
425 std::vector<string> remote_contexts_copy;
426 {
427 mutex_lock l(remote_state_mu_);
428 if (!is_master_) return;
429 context_id = context_id_;
430 context_view_id = context_view_id_;
431 context_id_ = kInvalidContextId;
432 // Forget the current view id and reset to the starting value 0.
433 context_view_id_ = 0;
434
435 // Make a copy of remote targets to avoid holding the lock when sending
436 // close context requests.
437 remote_contexts_copy = remote_contexts_;
438 remote_contexts_.clear();
439 }
440 CloseRemoteContexts(remote_contexts_copy, context_id, context_view_id);
441 }
442
CloseRemoteContexts(const std::vector<string> & remote_contexts,uint64 context_id,uint64 context_view_id)443 void EagerContext::CloseRemoteContexts(
444 const std::vector<string>& remote_contexts, uint64 context_id,
445 uint64 context_view_id) {
446 // Close all remote contexts.
447 eager::CloseContextRequest request;
448 request.set_context_id(context_id);
449 request.set_context_view_id(context_view_id);
450 // Setting context_id to a new value can avoid us issuing DestroyTensorHandle
451 // request to closed remote workers.
452 std::vector<eager::CloseContextResponse> responses(remote_contexts.size());
453 BlockingCounter counter(static_cast<int>(remote_contexts.size()));
454
455 int i = 0;
456 for (const auto& worker : remote_contexts) {
457 core::RefCountPtr<eager::EagerClient> client;
458 Status s = GetClient(worker, &client);
459
460 client->CloseContextAsync(
461 &request, &responses[i],
462 [&worker, &counter, context_id](const Status& s) {
463 if (!s.ok()) {
464 LOG(ERROR) << "Unable to close remote context with ID "
465 << context_id << " for worker: " << worker << " due to "
466 << s.error_message();
467 }
468 counter.DecrementCount();
469 });
470 i++;
471 }
472
473 counter.Wait();
474 }
475
476 #endif // !IS_MOBILE_PLATFORM
477
WaitForAndCloseRemoteContexts()478 void EagerContext::WaitForAndCloseRemoteContexts() {
479 ClearCachesAndThreadExecutors();
480
481 #if !defined(IS_MOBILE_PLATFORM)
482 {
483 mutex_lock l(keep_alive_thread_shutdown_mu_);
484 shutting_down_ = true;
485 keep_alive_thread_cv_.notify_all();
486 }
487 keep_alive_thread_.reset();
488
489 if (!IsRemoteContextsEmpty()) {
490 CloseAndClearAllRemoteContexts();
491 }
492
493 {
494 mutex_lock l(remote_state_mu_);
495
496 default_executor_.ShutDown().IgnoreError();
497 std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
498 {
499 mutex_lock l(executor_map_mu_);
500 executors_copy = thread_local_executor_;
501 }
502 for (const auto& it : executors_copy) {
503 it.second->ShutDown().IgnoreError();
504 }
505
506 // This shuts down the completion queue and joins the thread polling it.
507 // The thread exits only after the completion queue has been drained of all
508 // the events. These events' completion should invoke all remaining RPC
509 // callbacks.
510 // This also deletes all EagerClient instances. There should not be any
511 // references to EagerClients left after all RPCs and async ops have been
512 // finished.
513 remote_eager_workers_ = nullptr;
514 }
515 #endif // !IS_MOBILE_PLATFORM
516 }
517
~EagerContext()518 EagerContext::~EagerContext() {
519 // TODO(iga): Add a separate API method to shutdown EagerContext so that we
520 // don't send RPCs and block in destructor.
521 WaitForAndCloseRemoteContexts();
522
523 // Custom devices may have obtained references to various context components
524 // (executors, thread pool). It's safer to run their destructors early.
525 custom_device_op_handler_.Clear();
526
527 ClearCachesAndThreadExecutors();
528 std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
529 {
530 mutex_lock l(executor_map_mu_);
531 executors_copy = thread_local_executor_;
532 }
533 for (const auto& entry : executors_copy) {
534 // Let the executor know that its cleanup closure is no longer valid.
535 entry.second->RemoveCleanups(reinterpret_cast<intptr_t>(this));
536 }
537 for (auto& entry : registered_functions_) {
538 while (!entry.second->Unref()) {
539 // remove all references.
540 }
541 }
542 registered_functions_.clear();
543
544 #if !defined(IS_MOBILE_PLATFORM)
545 if (server_) {
546 // TODO(b/136478427): Fix this.
547 LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
548 "Servers don't support clean shutdown.";
549 server_.release();
550 }
551
552 {
553 mutex_lock l(keep_alive_thread_shutdown_mu_);
554 shutting_down_ = true;
555 keep_alive_thread_cv_.notify_all();
556 }
557 keep_alive_thread_.reset();
558 if (!remote_contexts_.empty()) {
559 CloseAndClearAllRemoteContexts();
560 }
561 #endif // !IS_MOBILE_PLATFORM
562
563 if (rendezvous_) {
564 rendezvous_->Unref();
565 }
566 if (resource_deallocator_ != nullptr) {
567 resource_deallocator_();
568 }
569 }
570
FindFunctionByName(const string & name) const571 bool EagerContext::FindFunctionByName(const string& name) const {
572 return func_lib_def_.Find(name) != nullptr;
573 }
574
FindFunctionOpData(const string & name,const tensorflow::OpRegistrationData ** op_data)575 Status EagerContext::FindFunctionOpData(
576 const string& name, const tensorflow::OpRegistrationData** op_data) {
577 return func_lib_def_.LookUp(name, op_data);
578 }
579
FindFunctionDef(const string & name) const580 const FunctionDef* EagerContext::FindFunctionDef(const string& name) const {
581 return func_lib_def_.Find(name);
582 }
583
ExportRunMetadata()584 std::unique_ptr<RunMetadata> EagerContext::ExportRunMetadata() {
585 mutex_lock ml(metadata_mu_);
586 auto result = std::make_unique<RunMetadata>();
587 run_metadata_.swap(result);
588 return result;
589 }
590
UsesTFRT()591 bool EagerContext::UsesTFRT() { return false; }
592
ListDevices(std::vector<tensorflow::DeviceAttributes> * devices)593 void EagerContext::ListDevices(
594 std::vector<tensorflow::DeviceAttributes>* devices) {
595 local_device_mgr()->ListDeviceAttributes(devices);
596 if (remote_device_mgr()) {
597 remote_device_mgr()->ListDeviceAttributes(devices);
598 }
599 }
600
StartStep()601 void EagerContext::StartStep() {
602 mutex_lock ml(metadata_mu_);
603 num_active_steps_++;
604 }
605
EndStep()606 void EagerContext::EndStep() {
607 mutex_lock ml(metadata_mu_);
608 num_active_steps_--;
609 if (num_active_steps_ == 0) {
610 // TODO(b/139809335): This does not properly clean up remote resources
611 // Clean up the previous step container and create a new one.
612 step_container_.reset(new ScopedStepContainer(
613 0, [this](const string& name) { ClearResourceContainer(name); }));
614 }
615 }
616
StepContainer()617 ScopedStepContainer* EagerContext::StepContainer() {
618 mutex_lock ml(metadata_mu_);
619 return step_container_.get();
620 }
621
MaybeRegisterFunctionRemotely(const FunctionDef & fdef)622 Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) {
623 // Only client context can register function on remote worker context.
624 if (!remote_device_manager_.Owned()) return Status::OK();
625 #if !defined(IS_MOBILE_PLATFORM)
626 std::shared_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
627 request->set_context_id(GetContextId());
628
629 eager::RegisterFunctionOp* register_function =
630 request->add_queue()->mutable_register_function();
631 *register_function->mutable_function_def() = fdef;
632 StripDefaultAttributes(
633 *OpRegistry::Global(),
634 register_function->mutable_function_def()->mutable_node_def());
635
636 auto remote_contexts = GetRemoteContexts();
637 for (const auto& target : remote_contexts) {
638 core::RefCountPtr<eager::EagerClient> eager_client;
639 TF_RETURN_IF_ERROR(GetClient(target, &eager_client));
640
641 eager::EnqueueResponse* response = new eager::EnqueueResponse();
642 eager_client->StreamingEnqueueAsync(
643 /*call_opts=*/nullptr, request.get(), response,
644 [request, response](const Status& status) {
645 if (!status.ok()) {
646 LOG(ERROR) << "Failed to register function remotely due to "
647 << status.error_message()
648 << "\nThis could happen if the remote target has been "
649 "disconnected from the client.";
650 }
651 delete response;
652 });
653 }
654 #endif // !IS_MOBILE_PLATFORM
655 return Status::OK();
656 }
657
RegisterExistingFunctionsOnRemoteWorkers(const std::vector<string> & remote_workers)658 Status EagerContext::RegisterExistingFunctionsOnRemoteWorkers(
659 const std::vector<string>& remote_workers) {
660 #if !defined(IS_MOBILE_PLATFORM)
661 // Register multiple functions on selected remote workers.
662 uint64 context_id = GetContextId();
663 FunctionDefLibrary function_defs = func_lib_def_.ToProto();
664 std::vector<std::shared_ptr<eager::EnqueueRequest>> requests(
665 function_defs.function_size());
666 for (int i = 0; i < function_defs.function_size(); i++) {
667 requests[i] = std::make_shared<eager::EnqueueRequest>();
668 requests[i]->set_context_id(context_id);
669 eager::RegisterFunctionOp* register_function =
670 requests[i]->add_queue()->mutable_register_function();
671 *register_function->mutable_function_def() =
672 std::move(*function_defs.mutable_function(i));
673 StripDefaultAttributes(
674 *OpRegistry::Global(),
675 register_function->mutable_function_def()->mutable_node_def());
676 }
677
678 for (auto& remote_worker : remote_workers) {
679 core::RefCountPtr<eager::EagerClient> eager_client;
680 Status s = GetClient(remote_worker, &eager_client);
681 if (!s.ok()) {
682 continue;
683 }
684 for (int i = 0; i < requests.size(); i++) {
685 auto response = std::make_shared<eager::EnqueueResponse>();
686 eager_client->StreamingEnqueueAsync(
687 /*call_opts=*/nullptr, requests[i].get(), response.get(),
688 [request = requests[i], response](const Status& s) {
689 if (!s.ok()) {
690 LOG(ERROR) << "Failed to register function remotely due to "
691 << s.error_message()
692 << "\nThis could happen if the remote target has been "
693 "disconnected from the client.";
694 }
695 });
696 }
697 }
698 #endif // !IS_MOBILE_PLATFORM
699 return Status::OK();
700 }
701
AddFunctionDefWithStackTraces(const FunctionDef & fdef,const StackTracesMap & stack_traces)702 Status EagerContext::AddFunctionDefWithStackTraces(
703 const FunctionDef& fdef, const StackTracesMap& stack_traces) {
704 return AddFunctionDef(fdef, FunctionDefLibrary(),
705 /* add_to_local_only=*/false, stack_traces);
706 }
707
AddFunctionDef(const FunctionDef & fdef)708 Status EagerContext::AddFunctionDef(const FunctionDef& fdef) {
709 return AddFunctionDef(fdef, FunctionDefLibrary(),
710 /* add_to_local_only=*/false);
711 }
712
AddFunctionDef(const FunctionDef & fdef,const FunctionDefLibrary & library,const bool add_to_local_only,const StackTracesMap & stack_traces)713 Status EagerContext::AddFunctionDef(const FunctionDef& fdef,
714 const FunctionDefLibrary& library,
715 const bool add_to_local_only,
716 const StackTracesMap& stack_traces) {
717 bool is_first_ref = false;
718 {
719 mutex_lock l(cache_mu_);
720 auto* registered_function =
721 gtl::FindPtrOrNull(registered_functions_, fdef.signature().name());
722 if (registered_function == nullptr) {
723 registered_function = new RegisteredFunction;
724 registered_function->cached_kernel_keys =
725 absl::make_unique<std::vector<Fprint128>>();
726 gtl::InsertOrUpdate(®istered_functions_, fdef.signature().name(),
727 registered_function);
728 } else {
729 // The function has been registered before. If the function is the same,
730 // then we take a Ref() otherwise we error out.
731 const FunctionDef* prev_fdef =
732 func_lib_def_.Find(fdef.signature().name());
733 if (prev_fdef == nullptr) {
734 return errors::Internal("Function: ", fdef.signature().name(),
735 " is in the cache but not in the library");
736 }
737 if (!FunctionDefsEqual(fdef, *prev_fdef)) {
738 return errors::InvalidArgument(
739 "Attempting to add a duplicate function with name: ",
740 fdef.signature().name(), " where the previous and current ",
741 "definitions differ. Previous definition: ",
742 prev_fdef->DebugString(),
743 " and current definition: ", fdef.DebugString());
744 }
745 registered_function->Ref();
746 }
747 is_first_ref = registered_function->RefCountIsOne();
748 }
749 if (is_first_ref) {
750 TF_RETURN_IF_ERROR(func_lib_def_.AddFunctionDef(fdef, stack_traces));
751 TF_RETURN_IF_ERROR(func_lib_def_.AddLibrary(library));
752 if (!add_to_local_only) {
753 return MaybeRegisterFunctionRemotely(fdef);
754 }
755 }
756 return Status::OK();
757 }
758
GetFunctionDef(const string & function_name)759 const FunctionDef* EagerContext::GetFunctionDef(const string& function_name) {
760 return func_lib_def_.Find(function_name);
761 }
762
ListFunctionNames()763 std::vector<string> EagerContext::ListFunctionNames() {
764 return func_lib_def_.ListFunctionNames();
765 }
766
RemoveFunction(const string & func)767 Status EagerContext::RemoveFunction(const string& func) {
768 bool is_last_ref = false;
769 {
770 mutex_lock l(cache_mu_);
771 auto* registered_function = gtl::FindPtrOrNull(registered_functions_, func);
772 if (registered_function == nullptr) {
773 return errors::InvalidArgument("Tried to remove non-existent function '",
774 func, "'.");
775 }
776 is_last_ref = registered_function->RefCountIsOne();
777 if (is_last_ref) {
778 for (auto& key : *registered_function->cached_kernel_keys) {
779 kernel_cache_.erase(key);
780 }
781 registered_functions_.erase(func);
782 }
783 registered_function->Unref();
784 }
785 if (is_last_ref) {
786 // TODO(fishx): Remove remote function as well.
787 return func_lib_def_.RemoveFunction(func);
788 }
789 return Status::OK();
790 }
791
SyncExecutors()792 Status EagerContext::SyncExecutors() {
793 StatusGroup sg;
794 // Synchronize on context default executor
795 sg.Update(default_executor_.WaitForAllPendingNodes());
796 default_executor_.ClearError();
797
798 // Synchronize thread local executors on client
799 std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
800 {
801 mutex_lock l(executor_map_mu_);
802 executors_copy = thread_local_executor_;
803 }
804 for (const auto& entry : executors_copy) {
805 sg.Update(entry.second->WaitForAllPendingNodes());
806 entry.second->ClearError();
807 }
808
809 #if !defined(IS_MOBILE_PLATFORM)
810 auto remote_contexts = GetRemoteContexts();
811 // Synchronize executors on remote workers
812 eager::EnqueueRequest request;
813 request.set_context_id(GetContextId());
814 request.add_queue()->mutable_sync_remote_executor_for_stream();
815 BlockingCounter counter(static_cast<int>(remote_contexts.size()));
816 std::vector<Status> statuses(remote_contexts.size());
817
818 for (int i = 0; i < remote_contexts.size(); i++) {
819 const auto& target = remote_contexts[i];
820 core::RefCountPtr<eager::EagerClient> eager_client;
821 TF_RETURN_IF_ERROR(GetClient(target, &eager_client));
822
823 eager::EnqueueResponse* response = new eager::EnqueueResponse();
824 eager_client->StreamingEnqueueAsync(
825 /*call_opts=*/nullptr, &request, response,
826 [response, target, &counter, &s = statuses[i]](const Status& status) {
827 s = status;
828 delete response;
829 counter.DecrementCount();
830 });
831 }
832 counter.Wait();
833 for (const Status& s : statuses) {
834 sg.Update(s);
835 }
836 #endif // !IS_MOBILE_PLATFORM
837 return sg.as_summary_status();
838 }
839
GetCachedKernel(Fprint128 cache_key)840 core::RefCountPtr<KernelAndDevice> EagerContext::GetCachedKernel(
841 Fprint128 cache_key) {
842 tf_shared_lock l(cache_mu_);
843 auto iter = kernel_cache_.find(cache_key);
844 if (iter == kernel_cache_.end()) {
845 return nullptr;
846 }
847 core::RefCountPtr<KernelAndDevice> new_ref(iter->second.get());
848 new_ref->Ref();
849 return new_ref;
850 }
851
AddKernelToCache(Fprint128 cache_key,KernelAndDevice * kernel)852 void EagerContext::AddKernelToCache(Fprint128 cache_key,
853 KernelAndDevice* kernel) {
854 mutex_lock ml(cache_mu_);
855 core::RefCountPtr<KernelAndDevice> new_ref(kernel);
856 new_ref->Ref();
857 kernel_cache_[cache_key] = std::move(new_ref);
858 auto* registered_function =
859 gtl::FindPtrOrNull(registered_functions_, kernel->name());
860 // The kernel name can be either a primitive op or a function.
861 if (registered_function != nullptr) {
862 registered_function->cached_kernel_keys->emplace_back(cache_key);
863 }
864 }
865
ShouldStoreGraphs()866 bool EagerContext::ShouldStoreGraphs() { return should_store_graphs_.load(); }
867
SetShouldStoreGraphs(bool value)868 void EagerContext::SetShouldStoreGraphs(bool value) {
869 mutex_lock ml(metadata_mu_);
870 should_store_graphs_.store(value);
871 if (!value) {
872 run_metadata_.reset(new RunMetadata);
873 }
874 }
875
FindDeviceFromName(const char * device_name,Device ** device) const876 Status EagerContext::FindDeviceFromName(const char* device_name,
877 Device** device) const {
878 *device = HostCPU();
879 if (device_name == nullptr || strlen(device_name) == 0) {
880 return Status::OK();
881 }
882
883 auto status = local_device_mgr()->LookupDevice(device_name, device);
884 if (status.ok()) {
885 return status;
886 }
887
888 if (remote_device_mgr() != nullptr) {
889 return remote_device_mgr()->LookupDevice(device_name, device);
890 }
891
892 return status;
893 }
894
FindCompositeDeviceFromName(StringPiece device_name,CompositeDevice ** device) const895 Status EagerContext::FindCompositeDeviceFromName(
896 StringPiece device_name, CompositeDevice** device) const {
897 tf_shared_lock l(composite_devices_mu_);
898 for (const auto& d : composite_devices_) {
899 if (d.second->name() == device_name) {
900 *device = d.second.get();
901 return Status::OK();
902 }
903 }
904 return errors::NotFound("Unknown composite device: ", device_name);
905 }
906
RegisterCustomDevice(const string & device_name,std::unique_ptr<CustomDevice> device)907 Status EagerContext::RegisterCustomDevice(
908 const string& device_name, std::unique_ptr<CustomDevice> device) {
909 Device* existing_physical_device = nullptr;
910 if (FindDeviceFromName(device_name.c_str(), &existing_physical_device).ok()) {
911 return errors::AlreadyExists(device_name,
912 " already registered as a physical device.");
913 }
914 return custom_device_op_handler_.RegisterCustomDevice(device_name,
915 std::move(device));
916 }
917
FindOrCreateCompositeDevice(const std::vector<string> & underlying_devices,const string & device_name,CompositeDevice ** composite_device)918 Status EagerContext::FindOrCreateCompositeDevice(
919 const std::vector<string>& underlying_devices, const string& device_name,
920 CompositeDevice** composite_device) {
921 if (!device_name.empty() &&
922 FindCompositeDeviceFromName(device_name, composite_device).ok()) {
923 return Status::OK();
924 }
925
926 const uint64 hash_key = Fingerprint64(absl::StrJoin(underlying_devices, ","));
927
928 mutex_lock l(composite_devices_mu_);
929 auto iter = composite_devices_.find(hash_key);
930 if (iter != composite_devices_.end()) {
931 *composite_device = iter->second.get();
932 return Status::OK();
933 }
934
935 Status s;
936 std::unique_ptr<CompositeDevice> device;
937 if (device_name.empty()) {
938 // Create a CompositeDevice on the same task as the host CPU, in order to
939 // trigger packed TensorHandle copy from a client to a remote worker.
940 device = CompositeDevice::MakeDevice(underlying_devices,
941 composite_devices_.size(),
942 HostCPU()->parsed_name(), &s);
943 } else {
944 device = CompositeDevice::MakeDevice(underlying_devices, device_name, &s);
945 }
946 TF_RETURN_IF_ERROR(s);
947 *composite_device = device.get();
948 pflr_->AddCompositeDevice(*composite_device);
949 composite_devices_.emplace(hash_key, std::move(device));
950 return Status::OK();
951 }
952
OnSameTask(const Device * first,const Device * second) const953 bool EagerContext::OnSameTask(const Device* first, const Device* second) const {
954 if (first == nullptr) first = HostCPU();
955 if (second == nullptr) second = HostCPU();
956 return first->parsed_name().job == second->parsed_name().job &&
957 first->parsed_name().replica == second->parsed_name().replica &&
958 first->parsed_name().task == second->parsed_name().task;
959 }
960
961 // Gets the CPU device on the task of device.
CPUDeviceOnTask(const Device * device,Device ** cpu_device) const962 Status EagerContext::CPUDeviceOnTask(const Device* device,
963 Device** cpu_device) const {
964 string cpu_device_name;
965 TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
966 device->name(), &cpu_device_name));
967
968 return FindDeviceFromName(cpu_device_name.c_str(), cpu_device);
969 }
970
ClearResourceContainer(const string & name)971 void EagerContext::ClearResourceContainer(const string& name) {
972 // TODO(b/139809335): This does not properly clean up remote resources
973 auto local_devices = local_device_mgr()->ListDevices();
974 for (Device* device : local_devices) {
975 // Only ignore container not found errors.
976 device->resource_manager()->Cleanup(name).IgnoreError();
977 }
978 }
979
980 namespace {
GetTaskName(Device * d,string * task_name)981 Status GetTaskName(Device* d, string* task_name) {
982 string ignored;
983 if (!DeviceNameUtils::SplitDeviceName(d->name(), task_name, &ignored)) {
984 return errors::InvalidArgument("Unable to parse device name: ", d->name());
985 }
986
987 return Status::OK();
988 }
989 } // namespace
990
991 #if !defined(IS_MOBILE_PLATFORM)
GetClient(Device * device,core::RefCountPtr<eager::EagerClient> * client)992 Status EagerContext::GetClient(Device* device,
993 core::RefCountPtr<eager::EagerClient>* client) {
994 return GetClient(device->parsed_name(), client);
995 }
996
GetClient(const DeviceNameUtils::ParsedName & device_name,core::RefCountPtr<eager::EagerClient> * client)997 Status EagerContext::GetClient(const DeviceNameUtils::ParsedName& device_name,
998 core::RefCountPtr<eager::EagerClient>* client) {
999 string device_task_name;
1000 if (!DeviceNameUtils::GetTaskName(device_name, &device_task_name)) {
1001 return errors::InvalidArgument(
1002 "Task is not fully specified in device name: ",
1003 DeviceNameUtils::ParsedNameToString(device_name));
1004 }
1005
1006 {
1007 tf_shared_lock l(remote_state_mu_);
1008 if (remote_eager_workers_ == nullptr) {
1009 return errors::Internal(
1010 "Haven't set up remote eager worker in this eager context yet.");
1011 }
1012 TF_RETURN_IF_ERROR(
1013 remote_eager_workers_->GetClient(device_task_name, client));
1014
1015 if (*client == nullptr) {
1016 return errors::InvalidArgument(
1017 "Unable to find eager client corresponding to device ",
1018 DeviceNameUtils::ParsedNameToString(device_name));
1019 }
1020 if (std::find(remote_contexts_.begin(), remote_contexts_.end(),
1021 device_task_name) == remote_contexts_.end()) {
1022 return errors::Internal("Unable to find a context for handle on task: ",
1023 device_task_name, ". This should not happen.");
1024 }
1025 }
1026
1027 return Status::OK();
1028 }
1029
GetClient(const string & remote_task,core::RefCountPtr<eager::EagerClient> * client)1030 Status EagerContext::GetClient(const string& remote_task,
1031 core::RefCountPtr<eager::EagerClient>* client) {
1032 {
1033 tf_shared_lock l(remote_state_mu_);
1034 if (remote_eager_workers_ == nullptr) {
1035 return errors::Internal(
1036 "Haven't set up remote eager worker in this eager context yet.");
1037 }
1038 TF_RETURN_IF_ERROR(remote_eager_workers_->GetClient(remote_task, client));
1039 }
1040
1041 if (*client == nullptr) {
1042 return errors::InvalidArgument(
1043 "Unable to find eager client corresponding to target ", remote_task);
1044 }
1045 return Status::OK();
1046 }
1047
GetContextId() const1048 uint64 EagerContext::GetContextId() const {
1049 tf_shared_lock l(remote_state_mu_);
1050 return context_id_;
1051 }
1052
GetContextViewId() const1053 uint64 EagerContext::GetContextViewId() const {
1054 tf_shared_lock l(remote_state_mu_);
1055 return context_view_id_;
1056 }
1057
IncrementContextViewId()1058 void EagerContext::IncrementContextViewId() {
1059 mutex_lock l(remote_state_mu_);
1060 context_view_id_ += 1;
1061 }
1062
1063 // Set collective ops related state in the context. Passing nullptr to
1064 // `new_server` will reuse the existing GRPC server in context.
StoreCollectiveOpsServer(std::unique_ptr<ServerInterface> new_server,const DeviceMgr * device_mgr,CollectiveExecutorMgrInterface * rpc_collective_executor_mgr)1065 Status EagerContext::StoreCollectiveOpsServer(
1066 std::unique_ptr<ServerInterface> new_server, const DeviceMgr* device_mgr,
1067 CollectiveExecutorMgrInterface* rpc_collective_executor_mgr) {
1068 collective_executor_mgr_.Reset(rpc_collective_executor_mgr);
1069
1070 if (device_mgr != local_device_manager_.Get()) {
1071 if (local_device_manager_.Owned()) {
1072 old_local_device_managers_.push_back(
1073 std::move(local_device_manager_.owned_object));
1074 }
1075 local_device_manager_.Reset(device_mgr);
1076 }
1077 host_cpu_device_ = local_device_manager_.Get()->HostCPU();
1078
1079 if (reuse_rendezvous_for_functions_) {
1080 // If reuse_rendezvous_for_functions_ is true, CreateRendezvous is
1081 // idempotent and ignores its step_id argument. Create a rendezvous now to
1082 // replace the old one, preventing the old one from getting used.
1083 if (rendezvous_ != nullptr) rendezvous_->Unref();
1084 rendezvous_ = CreateRendezvous(/*step_id=*/-1);
1085 return errors::Aborted("Cannot create a valid rendezvous.");
1086 }
1087
1088 InitPrioritizedDeviceTypeList();
1089 ClearCachesAndThreadExecutors();
1090 default_executor_.ClearError();
1091 {
1092 tensorflow::mutex_lock l(executor_map_mu_);
1093 for (auto& entry : thread_local_executor_) {
1094 entry.second->ClearError();
1095 }
1096 }
1097
1098 const ConfigProto* config = pflr_ ? pflr_->config() : nullptr;
1099 ResetPFLR(
1100 local_device_manager_.Get(), env_, /*config=*/config,
1101 TF_GRAPH_DEF_VERSION, &func_lib_def_,
1102 /*optimizer_options=*/
1103 config ? config->graph_options().optimizer_options() : OptimizerOptions(),
1104 thread_pool_.get());
1105
1106 if (new_server != nullptr) {
1107 // Memory leak!
1108 if (server_ != nullptr) {
1109 LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
1110 "Servers don't support clean shutdown.";
1111 server_.release();
1112 }
1113 server_ = std::move(new_server);
1114 }
1115 DCHECK(server_ != nullptr);
1116
1117 return Status::OK();
1118 }
1119
SetRemoteDeviceFilters(const string & remote_worker,const std::vector<string> & device_filters)1120 Status EagerContext::SetRemoteDeviceFilters(
1121 const string& remote_worker, const std::vector<string>& device_filters) {
1122 // Get fully specified task name for remote worker
1123 string remote_worker_task_name;
1124 DeviceNameUtils::ParsedName pw;
1125 if (!DeviceNameUtils::ParseFullName(remote_worker, &pw)) {
1126 return tensorflow::errors::InvalidArgument(
1127 "Remote worker task name is invalid ", remote_worker);
1128 }
1129 // Force set a replica as the key in cluster device filters map. I.e., if the
1130 // remote worker is `/job:worker/task:0` it then becomes
1131 // `/job:worker/replica:0/task:0`.
1132 pw.has_replica = true;
1133 if (!DeviceNameUtils::GetTaskName(pw, &remote_worker_task_name)) {
1134 return tensorflow::errors::InvalidArgument(
1135 "Job name and task index must be specified for worker ", remote_worker);
1136 }
1137
1138 std::vector<DeviceNameUtils::ParsedName> parsed_filters;
1139 for (auto& filter : device_filters) {
1140 DeviceNameUtils::ParsedName parsed_filter;
1141 if (DeviceNameUtils::ParseFullName(filter, &parsed_filter)) {
1142 parsed_filters.emplace_back(parsed_filter);
1143 } else {
1144 return tensorflow::errors::InvalidArgument("Invalid filter: ", filter);
1145 }
1146 }
1147
1148 if (VLOG_IS_ON(1)) {
1149 VLOG(1) << "Setting device filters for " << remote_worker << ":";
1150 for (auto& filter : device_filters) {
1151 VLOG(1) << " " << filter;
1152 }
1153 }
1154 mutex_lock l(remote_state_mu_);
1155 cluster_device_filters_.emplace(remote_worker_task_name, parsed_filters);
1156 return Status::OK();
1157 }
1158
FilterDevicesForRemoteWorkers(const string & remote_worker,const protobuf::RepeatedPtrField<DeviceAttributes> & device_attrs,std::vector<bool> * filtered_device_mask)1159 void EagerContext::FilterDevicesForRemoteWorkers(
1160 const string& remote_worker,
1161 const protobuf::RepeatedPtrField<DeviceAttributes>& device_attrs,
1162 std::vector<bool>* filtered_device_mask) {
1163 filtered_device_mask->resize(device_attrs.size());
1164 std::fill(filtered_device_mask->begin(), filtered_device_mask->end(), false);
1165
1166 tf_shared_lock l(remote_state_mu_);
1167 auto it = cluster_device_filters_.find(remote_worker);
1168 // If no filters were specified, all devices should be visible to the worker
1169 if (it == cluster_device_filters_.end() || it->second.empty()) {
1170 std::fill(filtered_device_mask->begin(), filtered_device_mask->end(), true);
1171 return;
1172 }
1173
1174 const std::vector<DeviceNameUtils::ParsedName>& parsed_filters = it->second;
1175 DeviceNameUtils::ParsedName parsed_remote_worker;
1176 DeviceNameUtils::ParseFullName(remote_worker, &parsed_remote_worker);
1177 for (int i = 0; i < device_attrs.size(); i++) {
1178 DeviceNameUtils::ParsedName pn;
1179 DeviceNameUtils::ParseFullName(device_attrs[i].name(), &pn);
1180 if (DeviceNameUtils::IsSameAddressSpace(parsed_remote_worker, pn)) {
1181 // If this device is on the remote worker itself, it should be visible
1182 // regardless of device filters
1183 filtered_device_mask->at(i) = true;
1184 continue;
1185 }
1186 for (const auto& pf : parsed_filters) {
1187 if ((!pn.has_job || !pf.has_job || pn.job == pf.job) &&
1188 (!pn.has_replica || !pf.has_replica || pn.replica == pf.replica) &&
1189 (!pn.has_task || !pf.has_task || pn.task == pf.task) &&
1190 (!pn.has_type || !pf.has_type || pn.type == pf.type) &&
1191 (!pn.has_id || !pf.has_id || pn.id == pf.id)) {
1192 // Found a match, make it visible, stop processing more device filters
1193 filtered_device_mask->at(i) = true;
1194 break;
1195 }
1196 }
1197 }
1198 }
1199
InitializeRemoteMaster(std::unique_ptr<ServerInterface> server,WorkerEnv * worker_env,std::shared_ptr<WorkerSession> worker_session,std::unique_ptr<eager::EagerClientCache> remote_eager_workers,std::unique_ptr<DynamicDeviceMgr> remote_device_manager,const std::vector<string> & remote_contexts,uint64 context_id,Rendezvous * r,const DeviceMgr * local_device_mgr,int keep_alive_secs,DistributedFunctionLibraryRuntime * cluster_flr,std::unique_ptr<eager::RemoteMgr,std::function<void (eager::RemoteMgr *)>> remote_mgr)1200 Status EagerContext::InitializeRemoteMaster(
1201 std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
1202 std::shared_ptr<WorkerSession> worker_session,
1203 std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
1204 std::unique_ptr<DynamicDeviceMgr> remote_device_manager,
1205 const std::vector<string>& remote_contexts, uint64 context_id,
1206 Rendezvous* r, const DeviceMgr* local_device_mgr, int keep_alive_secs,
1207 DistributedFunctionLibraryRuntime* cluster_flr,
1208 std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
1209 remote_mgr) {
1210 if (context_id == kInvalidContextId) {
1211 return errors::InvalidArgument(
1212 "Failed to initialize remote for master context due to invalid ",
1213 "context id");
1214 }
1215
1216 if (!IsRemoteContextsEmpty()) {
1217 CloseAndClearAllRemoteContexts();
1218 }
1219 {
1220 mutex_lock l(remote_state_mu_);
1221 remote_contexts_ = remote_contexts;
1222 }
1223
1224 return SetMasterContextState(
1225 std::move(server), worker_env, std::move(worker_session),
1226 std::move(remote_eager_workers), std::move(remote_device_manager),
1227 context_id, 0, r, local_device_mgr, keep_alive_secs, cluster_flr,
1228 std::move(remote_mgr));
1229 }
1230
UpdateRemoteMaster(uint64 context_id,std::unique_ptr<eager::EagerClientCache> remote_eager_workers,const std::vector<string> & add_remote_contexts,const std::vector<string> & remove_remote_contexts)1231 Status EagerContext::UpdateRemoteMaster(
1232 uint64 context_id,
1233 std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
1234 const std::vector<string>& add_remote_contexts,
1235 const std::vector<string>& remove_remote_contexts) {
1236 {
1237 tf_shared_lock l(remote_state_mu_);
1238 if (context_id != context_id_) {
1239 return errors::InvalidArgument(
1240 "Failed to update remote master context due to invalid context id. ",
1241 "Request id = ", context_id, " but current id = ", context_id_);
1242 }
1243 }
1244
1245 if (!remove_remote_contexts.empty()) {
1246 // N.B. remove_remote_contexts include both removed and replaced workers.
1247 // In the case where a worker is replaced by one that resolves to the same
1248 // `hostname:port`, it is safe to close context with the current view id,
1249 // since the newly created context on the remote worker will be holding
1250 // a larger view id and ignores this request.
1251 CloseRemoteContexts(remove_remote_contexts, context_id, GetContextViewId());
1252 mutex_lock l(remote_state_mu_);
1253 for (const string& remote_context : remove_remote_contexts) {
1254 remote_contexts_.erase(
1255 std::remove(remote_contexts_.begin(), remote_contexts_.end(),
1256 remote_context),
1257 remote_contexts_.end());
1258 }
1259 }
1260 if (!add_remote_contexts.empty()) {
1261 mutex_lock l(remote_state_mu_);
1262 remote_contexts_.insert(std::end(remote_contexts_),
1263 std::begin(add_remote_contexts),
1264 std::end(add_remote_contexts));
1265 }
1266
1267 {
1268 mutex_lock l(remote_state_mu_);
1269 context_view_id_++;
1270
1271 remote_eager_workers_ = std::move(remote_eager_workers);
1272 pflr_->InitializeDeviceAndFlr();
1273 InitPrioritizedDeviceTypeList();
1274
1275 default_executor_.ClearError();
1276 {
1277 tensorflow::mutex_lock l(executor_map_mu_);
1278 for (auto& entry : thread_local_executor_) {
1279 entry.second->ClearError();
1280 }
1281 }
1282 }
1283
1284 // Register existing functions to the newly added remote workers. Note that
1285 // this should happen only after updating `remote_contexts_` because new
1286 // functions might be registered while we update the context. When that
1287 // happens, this ordering ensures that `MaybeRegisterFunctionRemotely` will
1288 // register the new functions on all remote workers (including the newly added
1289 // ones), and `RegisterExistingFunctionsOnRemoteWorkers` will take care of
1290 // registering existing functions, where duplicate registrations will be
1291 // ignored by the remote workers.
1292 TF_RETURN_IF_ERROR(
1293 RegisterExistingFunctionsOnRemoteWorkers(add_remote_contexts));
1294 return Status::OK();
1295 }
1296
// Set distributed execution related state in the master context.
//
// While holding remote_state_mu_ for the whole call, this marks the context as
// master, records the context id / view id, installs the supplied local device
// manager, rendezvous, server, remote manager, worker environment and session,
// eager client cache, remote device manager and cluster FLR, clears caches and
// executor errors, rebuilds the process function library runtime, and starts
// the keep-alive thread if one is not already running.
//
// Always returns OK.
Status EagerContext::SetMasterContextState(
    std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
    std::shared_ptr<WorkerSession> worker_session,
    std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
    std::unique_ptr<DynamicDeviceMgr> remote_device_manager, uint64 context_id,
    uint64 context_view_id, Rendezvous* r, const DeviceMgr* local_device_mgr,
    int keep_alive_secs, DistributedFunctionLibraryRuntime* cluster_flr,
    std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
        remote_mgr) {
  mutex_lock l(remote_state_mu_);
  is_master_ = true;
  context_id_ = context_id;
  context_view_id_ = context_view_id;

  // Defaults to true when the environment variable is not set.
  use_send_tensor_rpc_ =
      ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", true);

  // Swap in the new local device manager. A previously owned manager is moved
  // into old_local_device_managers_ instead of being destroyed here.
  if (local_device_mgr != local_device_manager_.Get()) {
    if (local_device_manager_.Owned()) {
      old_local_device_managers_.push_back(
          std::move(local_device_manager_.owned_object));
    }
    local_device_manager_.Reset(local_device_mgr);
  }
  host_cpu_device_ = local_device_manager_.Get()->HostCPU();

  // Drop our reference on the previous rendezvous before storing `r`.
  // NOTE(review): no Ref() is taken on `r` here, so the caller presumably
  // hands over a reference -- confirm.
  if (rendezvous_ != nullptr) rendezvous_->Unref();
  rendezvous_ = r;

  // Memory leak!
  if (server_ != nullptr) {
    LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
                    "Servers don't support clean shutdown.";
    server_.release();
  }
  server_ = std::move(server);

  remote_mgr_ = std::move(remote_mgr);
  worker_env_ = worker_env;
  worker_session_ = std::move(worker_session);
  remote_eager_workers_ = std::move(remote_eager_workers);

  remote_device_manager_.Reset(std::move(remote_device_manager));
  ResetClusterFLR(cluster_flr);

  InitPrioritizedDeviceTypeList();

  ClearCachesAndThreadExecutors();
  default_executor_.ClearError();
  {
    tensorflow::mutex_lock l(executor_map_mu_);
    for (auto& entry : thread_local_executor_) {
      entry.second->ClearError();
    }
  }
  // Rebuild the process function library runtime with the existing config,
  // now pointing at the new device manager and cluster FLR.
  const auto* config = pflr_->config();
  ResetPFLR(local_device_manager_.Get(), env_, config, TF_GRAPH_DEF_VERSION,
            &func_lib_def_, config->graph_options().optimizer_options(),
            thread_pool_.get(), cluster_flr_.Get());

  keep_alive_secs_ = keep_alive_secs;
  // Wake up at half the keep-alive interval, but wait at least one second
  // between wake-ups.
  sleep_for_secs_ = std::max(1, keep_alive_secs_ / 2);
  // Only schedule a single closure.
  if (keep_alive_thread_ == nullptr) {
    // The thread loops until shutting_down_ is observed. Each iteration waits
    // up to sleep_for_secs_ seconds on keep_alive_thread_cv_ and then, if
    // keep_alive_secs_ is positive, sends a KeepAlive request to every remote
    // context.
    keep_alive_thread_.reset(
        env_->StartThread({}, "EagerKeepAliveThread", [this]() {
          while (true) {
            {
              {
                mutex_lock l(keep_alive_thread_shutdown_mu_);

                if (shutting_down_) {
                  return;
                }

                keep_alive_thread_cv_.wait_for(
                    l, std::chrono::seconds(sleep_for_secs_));

                // Re-check after waking: shutdown may have been requested
                // while this thread was waiting.
                if (shutting_down_) {
                  return;
                }
              }
              {
                // The shutdown mutex has been released by now; remote state is
                // read under its own lock.
                mutex_lock l(remote_state_mu_);
                if (keep_alive_secs_ > 0) {
                  {
                    for (const auto& worker : remote_contexts_) {
                      core::RefCountPtr<eager::EagerClient> client;
                      Status s =
                          remote_eager_workers_->GetClient(worker, &client);

                      // A missing client is logged and skipped; the remaining
                      // workers still receive their keep-alive.
                      if (!s.ok()) {
                        LOG(WARNING) << "Keep-alive thread was unable to find "
                                        "a client for target "
                                     << worker << ". Got error: " << s;
                        continue;
                      }

                      // Request and response are heap-allocated and deleted by
                      // the completion callback below.
                      eager::KeepAliveRequest* request =
                          new eager::KeepAliveRequest;
                      eager::KeepAliveResponse* response =
                          new eager::KeepAliveResponse;

                      request->set_context_id(context_id_);
                      // The RPC's status is not inspected by the callback.
                      client->KeepAliveAsync(
                          request, response,
                          [request, response](const Status& s) {
                            delete request;
                            delete response;
                          });
                    }
                  }
                }
              }
            }
          }
        }));
  }
  return Status::OK();
}
1418
InitializeRemoteWorker(std::unique_ptr<eager::EagerClientCache> remote_eager_workers,DynamicDeviceMgr * remote_device_mgr,const std::vector<string> & remote_contexts,uint64 context_id,uint64 context_view_id,std::function<Rendezvous * (const int64)> rendezvous_creator,DistributedFunctionLibraryRuntime * cluster_flr,std::unique_ptr<eager::RemoteMgr,std::function<void (eager::RemoteMgr *)>> remote_mgr,std::function<void ()> resource_deallocator)1419 Status EagerContext::InitializeRemoteWorker(
1420 std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
1421 DynamicDeviceMgr* remote_device_mgr,
1422 const std::vector<string>& remote_contexts, uint64 context_id,
1423 uint64 context_view_id,
1424 std::function<Rendezvous*(const int64)> rendezvous_creator,
1425 DistributedFunctionLibraryRuntime* cluster_flr,
1426 std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
1427 remote_mgr,
1428 std::function<void()> resource_deallocator) {
1429 if (context_id == kInvalidContextId) {
1430 return errors::InvalidArgument(
1431 "Failed to initialize remote for worker context due to invalid ",
1432 "context id");
1433 }
1434 mutex_lock l(remote_state_mu_);
1435
1436 if (remote_device_manager_.Owned() || server_ != nullptr ||
1437 keep_alive_thread_ != nullptr) {
1438 return errors::FailedPrecondition(
1439 "EagerContext::InitializeRemoteWorker Failed. ",
1440 "Already initialized remote as a master context.");
1441 }
1442 is_master_ = false;
1443
1444 remote_contexts_ = remote_contexts;
1445 context_id_ = context_id;
1446 context_view_id_ = context_view_id;
1447
1448 rendezvous_creator_ = std::move(rendezvous_creator);
1449 remote_eager_workers_ = std::move(remote_eager_workers);
1450 remote_mgr_ = std::move(remote_mgr);
1451 ResetClusterFLR(cluster_flr);
1452
1453 remote_device_manager_.Reset(remote_device_mgr);
1454
1455 const auto* config = pflr_->config();
1456 ResetPFLR(local_device_manager_.Get(), env_, config, TF_GRAPH_DEF_VERSION,
1457 &func_lib_def_, config->graph_options().optimizer_options(),
1458 thread_pool_.get(), cluster_flr_.Get());
1459 InitPrioritizedDeviceTypeList();
1460
1461 ClearCachesAndThreadExecutors();
1462 default_executor_.ClearError();
1463 {
1464 tensorflow::mutex_lock l(executor_map_mu_);
1465 for (auto& entry : thread_local_executor_) {
1466 entry.second->ClearError();
1467 }
1468 }
1469
1470 resource_deallocator_ = std::move(resource_deallocator);
1471
1472 return Status::OK();
1473 }
1474
UpdateRemoteWorker(std::unique_ptr<eager::EagerClientCache> remote_eager_workers,const std::vector<string> & remote_contexts,uint64 context_id)1475 Status EagerContext::UpdateRemoteWorker(
1476 std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
1477 const std::vector<string>& remote_contexts, uint64 context_id) {
1478 {
1479 mutex_lock l(remote_state_mu_);
1480 if (context_id != context_id_) {
1481 return errors::InvalidArgument(
1482 "Failed to update remote for worker context due to invalid ",
1483 "context id. Request id = ", context_id,
1484 " but current id = ", context_id_);
1485 }
1486 context_view_id_++;
1487
1488 remote_contexts_ = remote_contexts;
1489 remote_eager_workers_ = std::move(remote_eager_workers);
1490 InitPrioritizedDeviceTypeList();
1491 pflr_->InitializeDeviceAndFlr();
1492 }
1493
1494 // No need to update remote_device_manager_ since it's not owned for remote
1495 // worker context (owned by the corresponding worker session).
1496 if (remote_device_manager_.Owned()) {
1497 return errors::FailedPrecondition(
1498 "EagerContext::UpdateRemoteWorker failed because the context was "
1499 "initialized as a master context.");
1500 }
1501
1502 ClearCachesAndThreadExecutors();
1503 default_executor_.ClearError();
1504 {
1505 tensorflow::mutex_lock l(executor_map_mu_);
1506 for (auto& entry : thread_local_executor_) {
1507 entry.second->ClearError();
1508 }
1509 }
1510 return Status::OK();
1511 }
1512 #endif // !IS_MOBILE_PLATFORM
1513
1514 } // namespace tensorflow
1515