1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/eager/context.h"
17
18 #include <memory>
19 #include <vector>
20
21 // clang-format off
22 // Required for IS_MOBILE_PLATFORM
23 #include "tensorflow/core/common_runtime/function.h"
24 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
25 #include "tensorflow/core/lib/core/refcount.h"
26 #include "tensorflow/core/lib/gtl/map_util.h"
27 #include "tensorflow/core/platform/mutex.h"
28 #include "tensorflow/core/platform/platform.h"
29 // clang-format on
30
31 #include "tensorflow/core/common_runtime/collective_executor_mgr.h"
32 #include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
33 #include "tensorflow/core/common_runtime/colocation_graph.h"
34 #include "tensorflow/core/common_runtime/device_resolver_local.h"
35 #include "tensorflow/core/common_runtime/device_set.h"
36 #include "tensorflow/core/common_runtime/eager/process_function_library_runtime.h"
37 #include "tensorflow/core/common_runtime/process_util.h"
38 #include "tensorflow/core/framework/graph_def_util.h"
39 #include "tensorflow/core/framework/function.h"
40 #include "tensorflow/core/lib/core/errors.h"
41 #include "tensorflow/core/public/version.h"
42 #include "tensorflow/core/util/device_name_utils.h"
43 #if !defined(IS_MOBILE_PLATFORM)
44 #include "tensorflow/core/distributed_runtime/cluster_function_library_runtime.h"
45 #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
46 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
47 #include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
48 #endif // !IS_MOBILE_PLATFORM
49 #include "tensorflow/core/framework/resource_mgr.h"
50 #include "tensorflow/core/lib/core/blocking_counter.h"
51 #include "tensorflow/core/lib/monitoring/gauge.h"
52 #include "tensorflow/core/platform/monitoring.h"
53 #include "tensorflow/core/util/env_var.h"
54
55 namespace tensorflow {
56 namespace {
57
ReadBoolFromEnvVar(StringPiece env_var_name,bool default_val)58 bool ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val) {
59 bool val;
60 if (tensorflow::ReadBoolFromEnvVar(env_var_name, default_val, &val).ok()) {
61 return val;
62 }
63 return default_val;
64 }
65
66 auto* eager_context_created =
67 monitoring::Gauge<bool, 0>::New("/tensorflow/core/eager_context_created",
68 "True if an eager context was created.");
69
70 } // namespace
71
EagerContext(const SessionOptions & opts,ContextDevicePlacementPolicy default_device_placement_policy,ContextMirroringPolicy default_mirroring_policy,bool async,const bool lazy_copy_function_remote_inputs,const DeviceMgr * device_mgr,bool device_mgr_owned,Rendezvous * rendezvous,const CustomKernelCreator * custom_kernel_creator,DistributedFunctionLibraryRuntime * cluster_flr)72 EagerContext::EagerContext(
73 const SessionOptions& opts,
74 ContextDevicePlacementPolicy default_device_placement_policy,
75 ContextMirroringPolicy default_mirroring_policy, bool async,
76 const bool lazy_copy_function_remote_inputs, const DeviceMgr* device_mgr,
77 bool device_mgr_owned, Rendezvous* rendezvous,
78 const CustomKernelCreator* custom_kernel_creator,
79 DistributedFunctionLibraryRuntime* cluster_flr)
80 : default_device_placement_policy_(default_device_placement_policy),
81 default_mirroring_policy_(default_mirroring_policy),
82 local_device_manager_(device_mgr, device_mgr_owned),
83 host_cpu_device_(device_mgr->HostCPU()),
84 rendezvous_(rendezvous),
85 thread_pool_(NewThreadPoolFromSessionOptions(opts)),
86 custom_kernel_creator_(custom_kernel_creator),
87 cluster_flr_(cluster_flr),
88 log_device_placement_(opts.config.log_device_placement()),
89 allow_soft_placement_(opts.config.allow_soft_placement()),
90 num_active_steps_(0),
91 default_executor_(async),
92 log_memory_(LogMemory::IsEnabled()),
93 env_(opts.env),
94 lazy_copy_function_remote_inputs_(lazy_copy_function_remote_inputs),
95 use_send_tensor_rpc_(false),
96 pin_small_ops_to_cpu_(ReadBoolFromEnvVar(
97 "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING", false)) {
98 ResetPFLR(device_mgr, opts.env, &opts.config, TF_GRAPH_DEF_VERSION,
99 &func_lib_def_, opts.config.graph_options().optimizer_options(),
100 thread_pool_.get(), cluster_flr, custom_kernel_creator_);
101 // Starts exporting metrics through a platform-specific monitoring API (if
102 // provided). For builds using "tensorflow/core/platform/default", this is
103 // currently a no-op.
104 eager_context_created->GetCell()->Set(true);
105 monitoring::StartExporter();
106 InitPrioritizedDeviceTypeList();
107 runner_ = [this](std::function<void()> closure) {
108 this->thread_pool_->Schedule(std::move(closure));
109 };
110
111 #if !defined(IS_MOBILE_PLATFORM)
112 context_id_ = kInvalidContextId;
113 #endif // IS_MOBILE_PLATFORM
114
115 std::unique_ptr<DeviceResolverInterface> drl(
116 new DeviceResolverLocal(local_device_mgr()));
117 std::unique_ptr<ParamResolverInterface> cprl(new CollectiveParamResolverLocal(
118 opts.config, local_device_mgr(), drl.get(),
119 "/job:localhost/replica:0/task:0"));
120 collective_executor_mgr_.Reset(
121 new CollectiveExecutorMgr(opts.config, local_device_mgr(), std::move(drl),
122 std::move(cprl)),
123 /*owned=*/true);
124 }
125
ResetPFLR(const DeviceMgr * device_mgr,Env * env,const ConfigProto * config,int graph_def_version,const FunctionLibraryDefinition * lib_def,const OptimizerOptions & optimizer_options,thread::ThreadPool * thread_pool,DistributedFunctionLibraryRuntime * cluster_flr,const CustomKernelCreator * custom_kernel_creator)126 void EagerContext::ResetPFLR(const DeviceMgr* device_mgr, Env* env,
127 const ConfigProto* config, int graph_def_version,
128 const FunctionLibraryDefinition* lib_def,
129 const OptimizerOptions& optimizer_options,
130 thread::ThreadPool* thread_pool,
131 DistributedFunctionLibraryRuntime* cluster_flr,
132 const CustomKernelCreator* custom_kernel_creator) {
133 if (lazy_copy_function_remote_inputs_) {
134 pflr_.reset(new eager::EagerProcessFunctionLibraryRuntime(
135 device_mgr, env, config, graph_def_version, lib_def, optimizer_options,
136 thread_pool, cluster_flr, custom_kernel_creator));
137 } else {
138 pflr_.reset(new ProcessFunctionLibraryRuntime(
139 device_mgr, env, config, graph_def_version, lib_def, optimizer_options,
140 thread_pool, cluster_flr, custom_kernel_creator));
141 }
142 }
143
InitPrioritizedDeviceTypeList()144 void EagerContext::InitPrioritizedDeviceTypeList() {
145 DeviceSet ds;
146 for (Device* d : local_device_mgr()->ListDevices()) {
147 ds.AddDevice(d);
148 }
149 auto remote_device_manager = remote_device_mgr();
150 if (remote_device_manager != nullptr) {
151 for (Device* d : remote_device_manager->ListDevices()) {
152 ds.AddDevice(d);
153 }
154 }
155 prioritized_device_type_list_ = ds.PrioritizedDeviceTypeList();
156 }
157
158 namespace {
159 // Using absl::StrJoin with lambda does not work in tf-lite builds.
160 // TODO(b/148160441): Replace with absl::StrJoin once DeviceBase has operator<<.
DevicesToString(const std::vector<Device * > & devices)161 std::vector<string> DevicesToString(const std::vector<Device*>& devices) {
162 std::vector<string> v;
163 v.reserve(devices.size());
164 for (Device* d : devices) {
165 v.push_back(d->name());
166 }
167 return v;
168 }
169 } // namespace
170
SelectDevice(DeviceNameUtils::ParsedName preferred,const PrioritizedDeviceTypeVector & supported,const DataType dtype,Device ** device) const171 Status EagerContext::SelectDevice(DeviceNameUtils::ParsedName preferred,
172 const PrioritizedDeviceTypeVector& supported,
173 const DataType dtype, Device** device) const {
174 std::vector<Device*> selected;
175 const DeviceSet& pflr_devices = *pflr()->device_set();
176
177 // We always place string tensors on the CPU device if we're allowed to.
178 if (dtype == DT_STRING && AllowSoftPlacement()) {
179 preferred = HostCPU()->parsed_name();
180 }
181
182 // If there are no preferred devices, select the first registered device from
183 // the supported device list.
184 if (!DeviceNameUtils::HasSomeDetails(preferred)) {
185 // TODO(b/148213212): Allow setting default device in eager context.
186 selected = ColocationGraph::FilterSupportedDevices(
187 pflr_devices.devices(), supported, /*default_local_device=*/nullptr);
188 if (selected.empty()) {
189 return errors::InvalidArgument(
190 "No supported device found in available devices [",
191 absl::StrJoin(DevicesToString(pflr_devices.devices()), ", "), "].");
192 }
193 *device = selected[0];
194 return Status::OK();
195 }
196
197 // If the caller specified a preferred device, select the first matching
198 // registered device from the supported device list. If nothing matches and
199 // soft placement is enabled, pick a suitable device from the available ones.
200 pflr_devices.FindMatchingDevices(preferred, &selected);
201
202 if (!selected.empty()) {
203 selected = ColocationGraph::FilterSupportedDevices(
204 selected, supported, /*default_local_device=*/nullptr);
205 }
206
207 if (selected.empty() && AllowSoftPlacement()) {
208 DeviceNameUtils::ParsedName soft_device_name = preferred;
209 soft_device_name.type.clear();
210 soft_device_name.has_type = false;
211 soft_device_name.has_id = false;
212 // TODO(b/148213746): Soft placement logic picks up another task if the
213 // requested does not exist.
214 pflr_devices.FindMatchingDevices(soft_device_name, &selected);
215 if (!selected.empty()) {
216 selected = ColocationGraph::FilterSupportedDevices(
217 selected, supported, /*default_local_device=*/nullptr);
218 }
219 }
220
221 if (selected.empty()) {
222 return errors::InvalidArgument(
223 "Could not satisfy device specification '", preferred,
224 "'. All available devices [",
225 absl::StrJoin(DevicesToString(pflr_devices.devices()), ", "), "].");
226 }
227
228 *device = selected[0];
229 return Status::OK();
230 }
231
ResetClusterFLR(DistributedFunctionLibraryRuntime * cluster_flr)232 void EagerContext::ResetClusterFLR(
233 DistributedFunctionLibraryRuntime* cluster_flr) {
234 cluster_flr_.Reset(cluster_flr, lazy_copy_function_remote_inputs_);
235 }
236
Executor()237 EagerExecutor& EagerContext::Executor() {
238 tf_shared_lock l(executor_map_mu_);
239 return *gtl::FindWithDefault(thread_local_executor_,
240 std::this_thread::get_id(), &default_executor_);
241 }
242
SetExecutorForThread(EagerExecutor * executor)243 void EagerContext::SetExecutorForThread(EagerExecutor* executor) {
244 tensorflow::mutex_lock l(executor_map_mu_);
245 if (executor == &default_executor_) {
246 thread_local_executor_.erase(std::this_thread::get_id());
247 } else {
248 thread_local_executor_[std::this_thread::get_id()] = executor;
249 }
250 }
251
ClearCachesAndThreadExecutors()252 void EagerContext::ClearCachesAndThreadExecutors() {
253 std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
254 {
255 mutex_lock l(executor_map_mu_);
256 executors_copy = thread_local_executor_;
257 }
258 for (const auto& entry : executors_copy) {
259 entry.second->WaitForAllPendingNodes().IgnoreError();
260 }
261 ClearCachesAndDefaultExecutor();
262 }
263
ClearCachesAndDefaultExecutor()264 void EagerContext::ClearCachesAndDefaultExecutor() {
265 // The executor stores pointers to kernels, so we need to make sure that no
266 // async eager ops are still executing. We lock the cache during this time
267 // as well.
268 mutex_lock ml(cache_mu_);
269 default_executor_.WaitForAllPendingNodes().IgnoreError();
270 kernel_cache_.clear();
271 for (auto& entry : registered_functions_) {
272 entry.second->cached_kernel_keys->clear();
273 }
274 }
275
SetThreadLocalDevicePlacementPolicy(ContextDevicePlacementPolicy policy)276 void EagerContext::SetThreadLocalDevicePlacementPolicy(
277 ContextDevicePlacementPolicy policy) {
278 mutex_lock ml(policy_map_mu_);
279 device_placement_policy_[std::this_thread::get_id()] = policy;
280 }
281
GetDevicePlacementPolicy() const282 ContextDevicePlacementPolicy EagerContext::GetDevicePlacementPolicy() const {
283 tf_shared_lock l(policy_map_mu_);
284 auto policy_map_it =
285 device_placement_policy_.find(std::this_thread::get_id());
286 if (policy_map_it != device_placement_policy_.end()) {
287 return policy_map_it->second;
288 }
289 return default_device_placement_policy_;
290 }
291
SetThreadLocalMirroringPolicy(ContextMirroringPolicy policy)292 void EagerContext::SetThreadLocalMirroringPolicy(
293 ContextMirroringPolicy policy) {
294 mutex_lock ml(policy_map_mu_);
295 mirroring_policy_[std::this_thread::get_id()] = policy;
296 }
297
GetMirroringPolicy() const298 ContextMirroringPolicy EagerContext::GetMirroringPolicy() const {
299 tf_shared_lock l(policy_map_mu_);
300 auto policy_map_it = mirroring_policy_.find(std::this_thread::get_id());
301 if (policy_map_it != mirroring_policy_.end()) {
302 return policy_map_it->second;
303 }
304 return default_mirroring_policy_;
305 }
306
MirrorTensors() const307 bool EagerContext::MirrorTensors() const {
308 return GetMirroringPolicy() == MIRRORING_ALL;
309 }
310
LazyCopyFunctionRemoteInputs() const311 bool EagerContext::LazyCopyFunctionRemoteInputs() const {
312 return lazy_copy_function_remote_inputs_;
313 }
314
315 #if !defined(IS_MOBILE_PLATFORM)
CloseAndClearAllRemoteContexts()316 void EagerContext::CloseAndClearAllRemoteContexts() {
317 uint64 context_id;
318 uint64 context_view_id;
319 {
320 mutex_lock l(remote_state_mu_);
321 if (!is_master_) return;
322 context_id = context_id_;
323 context_view_id = context_view_id_;
324 context_id_ = kInvalidContextId;
325 // Forget the current view id and reset to the starting value 0.
326 context_view_id_ = 0;
327 }
328 CloseRemoteContexts(remote_contexts_, context_id, context_view_id);
329 remote_contexts_.clear();
330 }
331
CloseRemoteContexts(const std::vector<string> & remote_contexts,uint64 context_id,uint64 context_view_id)332 void EagerContext::CloseRemoteContexts(
333 const std::vector<string>& remote_contexts, uint64 context_id,
334 uint64 context_view_id) {
335 // Close all remote contexts.
336 eager::CloseContextRequest request;
337 request.set_context_id(context_id);
338 request.set_context_view_id(context_view_id);
339 // Setting context_id to a new value can avoid us issuing DestroyTensorHandle
340 // request to closed remote workers.
341 std::vector<eager::CloseContextResponse> responses(remote_contexts.size());
342 BlockingCounter counter(static_cast<int>(remote_contexts.size()));
343
344 int i = 0;
345 for (const auto& worker : remote_contexts) {
346 core::RefCountPtr<eager::EagerClient> client;
347 Status s = remote_eager_workers_->GetClient(worker, &client);
348
349 client->CloseContextAsync(
350 &request, &responses[i],
351 [&worker, &counter, context_id](const Status& s) {
352 if (!s.ok()) {
353 LOG(ERROR) << "Unable to close remote context with ID "
354 << context_id << " for worker: " << worker << " due to "
355 << s.error_message();
356 }
357 counter.DecrementCount();
358 });
359 i++;
360 }
361
362 counter.Wait();
363 }
364
365 #endif // !IS_MOBILE_PLATFORM
366
WaitForAndCloseRemoteContexts()367 void EagerContext::WaitForAndCloseRemoteContexts() {
368 ClearCachesAndThreadExecutors();
369
370 #if !defined(IS_MOBILE_PLATFORM)
371 {
372 mutex_lock l(keep_alive_thread_shutdown_mu_);
373 shutting_down_ = true;
374 keep_alive_thread_cv_.notify_all();
375 }
376 keep_alive_thread_.reset();
377
378 if (!remote_contexts_.empty()) {
379 CloseAndClearAllRemoteContexts();
380 }
381
382 {
383 mutex_lock l(remote_state_mu_);
384
385 default_executor_.ShutDown().IgnoreError();
386 std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
387 {
388 mutex_lock l(executor_map_mu_);
389 executors_copy = thread_local_executor_;
390 }
391 for (const auto& it : executors_copy) {
392 it.second->ShutDown().IgnoreError();
393 }
394 }
395
396 // This shuts down the completion queue and joins the thread polling it.
397 // The thread exits only after the completion queue has been drained of all
398 // the events. These events' completion should invoke all remaining RPC
399 // callbacks.
400 // This also deletes all EagerClient instances. There should not be any
401 // references to EagerClients left after all RPCs and async ops have been
402 // finished.
403 remote_eager_workers_ = nullptr;
404 #endif // !IS_MOBILE_PLATFORM
405 }
406
~EagerContext()407 EagerContext::~EagerContext() {
408 // TODO(iga): Add a separate API method to shutdown EagerContext so that we
409 // don't send RPCs and block in destructor.
410 WaitForAndCloseRemoteContexts();
411
412 ClearCachesAndThreadExecutors();
413 for (auto& entry : registered_functions_) {
414 while (!entry.second->Unref()) {
415 // remove all references.
416 }
417 }
418 registered_functions_.clear();
419
420 #if !defined(IS_MOBILE_PLATFORM)
421 if (server_) {
422 // TODO(b/136478427): Fix this.
423 LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
424 "Servers don't support clean shutdown.";
425 server_.release();
426 }
427
428 {
429 mutex_lock l(keep_alive_thread_shutdown_mu_);
430 shutting_down_ = true;
431 keep_alive_thread_cv_.notify_all();
432 }
433 keep_alive_thread_.reset();
434 if (!remote_contexts_.empty()) {
435 CloseAndClearAllRemoteContexts();
436 }
437 #endif // !IS_MOBILE_PLATFORM
438
439 if (rendezvous_) {
440 rendezvous_->Unref();
441 }
442 if (resource_deallocator_ != nullptr) {
443 resource_deallocator_();
444 }
445 }
446
FindFunctionByName(const string & name) const447 bool EagerContext::FindFunctionByName(const string& name) const {
448 return func_lib_def_.Find(name) != nullptr;
449 }
450
FindFunctionOpData(const string & name,const tensorflow::OpRegistrationData ** op_data)451 Status EagerContext::FindFunctionOpData(
452 const string& name, const tensorflow::OpRegistrationData** op_data) {
453 return func_lib_def_.LookUp(name, op_data);
454 }
455
FindFunctionDef(const string & name)456 const FunctionDef* EagerContext::FindFunctionDef(const string& name) {
457 return func_lib_def_.Find(name);
458 }
459
ListRegisteredFunctions()460 std::vector<const FunctionDef*> EagerContext::ListRegisteredFunctions() {
461 std::vector<const FunctionDef*> result;
462 std::vector<string> function_names = func_lib_def_.ListFunctionNames();
463 result.reserve(function_names.size());
464 for (const string& fn : function_names) {
465 result.emplace_back(func_lib_def_.Find(fn));
466 }
467 return result;
468 }
469
ClearRunMetadata()470 void EagerContext::ClearRunMetadata() { run_metadata_.Clear(); }
471
ListDevices(std::vector<tensorflow::DeviceAttributes> * devices)472 void EagerContext::ListDevices(
473 std::vector<tensorflow::DeviceAttributes>* devices) {
474 local_device_mgr()->ListDeviceAttributes(devices);
475 if (remote_device_mgr()) {
476 remote_device_mgr()->ListDeviceAttributes(devices);
477 }
478 }
479
StartStep()480 void EagerContext::StartStep() {
481 mutex_lock ml(metadata_mu_);
482 num_active_steps_++;
483 if (step_container_ == nullptr) {
484 step_container_.reset(
485 new ScopedStepContainer(0, [this](const string& name) {
486 auto local_devices = local_device_mgr()->ListDevices();
487 for (Device* device : local_devices) {
488 device->resource_manager()->Cleanup(name).IgnoreError();
489 }
490 }));
491 }
492 }
493
EndStep()494 void EagerContext::EndStep() {
495 mutex_lock ml(metadata_mu_);
496 num_active_steps_--;
497 if (num_active_steps_ == 0) {
498 step_container_.reset();
499 }
500 }
501
StepContainer()502 ScopedStepContainer* EagerContext::StepContainer() {
503 if (num_active_steps_.load() == 0) {
504 return nullptr;
505 }
506 mutex_lock ml(metadata_mu_);
507 return step_container_.get();
508 }
509
MaybeRegisterFunctionRemotely(const FunctionDef & fdef)510 Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) {
511 // Only client context can register function on remote worker context.
512 if (!remote_device_manager_.Owned()) return Status::OK();
513 #if !defined(IS_MOBILE_PLATFORM)
514 std::shared_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
515 request->set_context_id(GetContextId());
516
517 eager::RegisterFunctionOp* register_function =
518 request->add_queue()->mutable_register_function();
519 *register_function->mutable_function_def() = fdef;
520 StripDefaultAttributes(
521 *OpRegistry::Global(),
522 register_function->mutable_function_def()->mutable_node_def());
523
524 for (const auto& target : remote_contexts_) {
525 core::RefCountPtr<eager::EagerClient> eager_client;
526 TF_RETURN_IF_ERROR(remote_eager_workers_->GetClient(target, &eager_client));
527
528 eager::EnqueueResponse* response = new eager::EnqueueResponse();
529 eager_client->StreamingEnqueueAsync(
530 request.get(), response, [request, response](const Status& status) {
531 if (!status.ok()) {
532 LOG(ERROR) << "Failed to register function remotely due to "
533 << status.error_message()
534 << "\nThis shouldn't happen, please file a bug to "
535 "tensorflow team.";
536 }
537 delete response;
538 });
539 }
540 #endif // !IS_MOBILE_PLATFORM
541 return Status::OK();
542 }
543
RegisterExistingFunctionsOnRemoteWorkers(const std::vector<const FunctionDef * > & function_defs,const std::vector<string> & remote_workers)544 Status EagerContext::RegisterExistingFunctionsOnRemoteWorkers(
545 const std::vector<const FunctionDef*>& function_defs,
546 const std::vector<string>& remote_workers) {
547 #if !defined(IS_MOBILE_PLATFORM)
548 // Register multiple functions on selected remote workers.
549 uint64 context_id = GetContextId();
550 for (int i = 0; i < remote_workers.size(); i++) {
551 core::RefCountPtr<eager::EagerClient> eager_client;
552 Status s =
553 remote_eager_workers_->GetClient(remote_workers[i], &eager_client);
554 if (!s.ok()) {
555 continue;
556 }
557 for (int j = 0; j < function_defs.size(); j++) {
558 auto* request = new eager::EnqueueRequest;
559 request->set_context_id(context_id);
560 eager::RegisterFunctionOp* register_function =
561 request->add_queue()->mutable_register_function();
562 *register_function->mutable_function_def() = *function_defs[j];
563 auto* response = new eager::EnqueueResponse;
564 eager_client->StreamingEnqueueAsync(
565 request, response, [request, response](const Status& s) {
566 if (!s.ok()) {
567 LOG(ERROR) << "Failed to register function remotely due to "
568 << s.error_message()
569 << "\nThis shouldn't happen, please file a bug to "
570 "tensorflow team.";
571 }
572 delete request;
573 delete response;
574 });
575 }
576 }
577 #endif // !IS_MOBILE_PLATFORM
578 return Status::OK();
579 }
580
AddFunctionDef(const FunctionDef & fdef)581 Status EagerContext::AddFunctionDef(const FunctionDef& fdef) {
582 return AddFunctionDef(fdef, FunctionDefLibrary(),
583 /* add_to_local_only=*/false);
584 }
585
AddFunctionDef(const FunctionDef & fdef,const FunctionDefLibrary & library,const bool add_to_local_only)586 Status EagerContext::AddFunctionDef(const FunctionDef& fdef,
587 const FunctionDefLibrary& library,
588 const bool add_to_local_only) {
589 bool is_first_ref = false;
590 {
591 mutex_lock l(cache_mu_);
592 auto* registered_function =
593 gtl::FindPtrOrNull(registered_functions_, fdef.signature().name());
594 if (registered_function == nullptr) {
595 registered_function = new RegisteredFunction;
596 registered_function->cached_kernel_keys =
597 absl::make_unique<std::vector<Fprint128>>();
598 gtl::InsertOrUpdate(®istered_functions_, fdef.signature().name(),
599 registered_function);
600 } else {
601 registered_function->Ref();
602 }
603 is_first_ref = registered_function->RefCountIsOne();
604 }
605 if (is_first_ref) {
606 TF_RETURN_IF_ERROR(func_lib_def_.AddFunctionDef(fdef));
607 TF_RETURN_IF_ERROR(func_lib_def_.AddLibrary(library));
608 if (!add_to_local_only) {
609 return MaybeRegisterFunctionRemotely(fdef);
610 }
611 }
612 return Status::OK();
613 }
614
RemoveFunction(const string & func)615 Status EagerContext::RemoveFunction(const string& func) {
616 bool is_last_ref = false;
617 {
618 mutex_lock l(cache_mu_);
619 auto* registered_function = gtl::FindPtrOrNull(registered_functions_, func);
620 if (registered_function == nullptr) {
621 return errors::InvalidArgument("Tried to remove non-existent function '",
622 func, "'.");
623 }
624 is_last_ref = registered_function->RefCountIsOne();
625 if (is_last_ref) {
626 for (auto& key : *registered_function->cached_kernel_keys) {
627 kernel_cache_.erase(key);
628 }
629 registered_functions_.erase(func);
630 }
631 registered_function->Unref();
632 }
633 if (is_last_ref) {
634 // TODO(fishx): Remove remote function as well.
635 return func_lib_def_.RemoveFunction(func);
636 }
637 return Status::OK();
638 }
639
GetCachedKernel(Fprint128 cache_key)640 core::RefCountPtr<KernelAndDevice> EagerContext::GetCachedKernel(
641 Fprint128 cache_key) {
642 tf_shared_lock l(cache_mu_);
643 auto iter = kernel_cache_.find(cache_key);
644 if (iter == kernel_cache_.end()) {
645 return nullptr;
646 }
647 core::RefCountPtr<KernelAndDevice> new_ref(iter->second.get());
648 new_ref->Ref();
649 return new_ref;
650 }
651
AddKernelToCache(Fprint128 cache_key,KernelAndDevice * kernel)652 void EagerContext::AddKernelToCache(Fprint128 cache_key,
653 KernelAndDevice* kernel) {
654 mutex_lock ml(cache_mu_);
655 core::RefCountPtr<KernelAndDevice> new_ref(kernel);
656 new_ref->Ref();
657 kernel_cache_[cache_key] = std::move(new_ref);
658 auto* registered_function =
659 gtl::FindPtrOrNull(registered_functions_, kernel->name());
660 // The kernel name can be either a primitive op or a function.
661 if (registered_function != nullptr) {
662 registered_function->cached_kernel_keys->emplace_back(cache_key);
663 }
664 }
665
ShouldStoreGraphs()666 bool EagerContext::ShouldStoreGraphs() { return should_store_graphs_.load(); }
667
SetShouldStoreGraphs(bool value)668 void EagerContext::SetShouldStoreGraphs(bool value) {
669 mutex_lock ml(metadata_mu_);
670 should_store_graphs_.store(value);
671 if (!value) {
672 run_metadata_.Clear();
673 }
674 }
675
FindDeviceFromName(const char * device_name,Device ** device) const676 Status EagerContext::FindDeviceFromName(const char* device_name,
677 Device** device) const {
678 *device = HostCPU();
679 if (device_name == nullptr || strlen(device_name) == 0) {
680 return Status::OK();
681 }
682
683 auto status = local_device_mgr()->LookupDevice(device_name, device);
684 if (status.ok()) {
685 return status;
686 }
687
688 if (remote_device_mgr() != nullptr) {
689 return remote_device_mgr()->LookupDevice(device_name, device);
690 }
691
692 return status;
693 }
694
OnSameTask(const Device * first,const Device * second) const695 bool EagerContext::OnSameTask(const Device* first, const Device* second) const {
696 if (first == nullptr) first = HostCPU();
697 if (second == nullptr) second = HostCPU();
698 return first->parsed_name().job == second->parsed_name().job &&
699 first->parsed_name().replica == second->parsed_name().replica &&
700 first->parsed_name().task == second->parsed_name().task;
701 }
702
703 // Gets the CPU device on the task of device.
CPUDeviceOnTask(const Device * device,Device ** cpu_device) const704 Status EagerContext::CPUDeviceOnTask(const Device* device,
705 Device** cpu_device) const {
706 string cpu_device_name;
707 TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
708 device->name(), &cpu_device_name));
709
710 return FindDeviceFromName(cpu_device_name.c_str(), cpu_device);
711 }
712
713 namespace {
GetTaskName(Device * d,string * task_name)714 Status GetTaskName(Device* d, string* task_name) {
715 string ignored;
716 if (!DeviceNameUtils::SplitDeviceName(d->name(), task_name, &ignored)) {
717 return errors::InvalidArgument("Unable to parse device name: ", d->name());
718 }
719
720 return Status::OK();
721 }
722 } // namespace
723
724 #if !defined(IS_MOBILE_PLATFORM)
GetClient(Device * device,core::RefCountPtr<eager::EagerClient> * client)725 Status EagerContext::GetClient(Device* device,
726 core::RefCountPtr<eager::EagerClient>* client) {
727 return GetClient(device->parsed_name(), client);
728 }
729
GetClient(const DeviceNameUtils::ParsedName & device_name,core::RefCountPtr<eager::EagerClient> * client)730 Status EagerContext::GetClient(const DeviceNameUtils::ParsedName& device_name,
731 core::RefCountPtr<eager::EagerClient>* client) {
732 if (remote_eager_workers_ == nullptr) {
733 return errors::Internal(
734 "Haven't set up remote eager worker in this eager context yet.");
735 }
736 string device_task_name;
737 if (!DeviceNameUtils::GetTaskName(device_name, &device_task_name)) {
738 return errors::InvalidArgument(
739 "Task is not fully specified in device name: ",
740 DeviceNameUtils::ParsedNameToString(device_name));
741 }
742
743 TF_RETURN_IF_ERROR(
744 remote_eager_workers_->GetClient(device_task_name, client));
745
746 if (*client == nullptr) {
747 return errors::InvalidArgument(
748 "Unable to find eager client corresponding to device ",
749 DeviceNameUtils::ParsedNameToString(device_name));
750 }
751
752 if (std::find(remote_contexts_.begin(), remote_contexts_.end(),
753 device_task_name) == remote_contexts_.end()) {
754 return errors::Internal("Unable to find a context for handle on task: ",
755 device_task_name, ". This should not be possible");
756 }
757
758 return Status::OK();
759 }
760
GetClient(const string & remote_task,core::RefCountPtr<eager::EagerClient> * client)761 Status EagerContext::GetClient(const string& remote_task,
762 core::RefCountPtr<eager::EagerClient>* client) {
763 if (remote_eager_workers_ == nullptr) {
764 return errors::Internal(
765 "Haven't set up remote eager worker in this eager context yet.");
766 }
767 TF_RETURN_IF_ERROR(remote_eager_workers_->GetClient(remote_task, client));
768
769 if (*client == nullptr) {
770 return errors::InvalidArgument(
771 "Unable to find eager client corresponding to target ", remote_task);
772 }
773 return Status::OK();
774 }
775
GetContextId()776 uint64 EagerContext::GetContextId() {
777 tf_shared_lock l(remote_state_mu_);
778 return context_id_;
779 }
780
GetContextViewId()781 uint64 EagerContext::GetContextViewId() {
782 tf_shared_lock l(remote_state_mu_);
783 return context_view_id_;
784 }
785
IncrementContextViewId()786 void EagerContext::IncrementContextViewId() {
787 mutex_lock l(remote_state_mu_);
788 context_view_id_ += 1;
789 }
790
// Set collective ops related state in the context. Passing nullptr to
// `new_server` will reuse the existing GRPC server in context.
//
// In order, this:
//   * installs `rpc_collective_executor_mgr` as the collective executor
//     manager and `device_mgr` as the local device manager (via Reset()), and
//     refreshes `host_cpu_device_` from the new local device manager;
//   * rebuilds the prioritized device type list, calls
//     ClearCachesAndThreadExecutors(), and clears errors on the default
//     executor and on every thread-local executor;
//   * rebuilds the process function library runtime on the new local devices,
//     reusing the existing PFLR's config (and its optimizer options) when a
//     PFLR exists, otherwise a null config and default OptimizerOptions;
//   * replaces `server_` with `new_server` when one is given.
//
// Always returns OK. In debug builds it DCHECKs that a server is installed
// after the call, so a caller passing nullptr must have stored one before.
//
// NOTE(review): unlike SetMasterContextState, this does not acquire
// `remote_state_mu_`.
Status EagerContext::StoreCollectiveOpsServer(
    std::unique_ptr<ServerInterface> new_server, DeviceMgr* device_mgr,
    CollectiveExecutorMgrInterface* rpc_collective_executor_mgr) {
  collective_executor_mgr_.Reset(rpc_collective_executor_mgr);

  local_device_manager_.Reset(device_mgr);
  host_cpu_device_ = local_device_manager_.Get()->HostCPU();

  InitPrioritizedDeviceTypeList();
  ClearCachesAndThreadExecutors();
  default_executor_.ClearError();
  {
    // Also clear errors on every thread-local executor.
    tensorflow::mutex_lock l(executor_map_mu_);
    for (auto& entry : thread_local_executor_) {
      entry.second->ClearError();
    }
  }

  // `pflr_` may not exist yet; in that case fall back to a null config and
  // default optimizer options.
  const ConfigProto* config = pflr_ ? pflr_->config() : nullptr;
  ResetPFLR(
      local_device_manager_.Get(), env_, /*config=*/config,
      TF_GRAPH_DEF_VERSION, &func_lib_def_,
      /*optimizer_options=*/
      config ? config->graph_options().optimizer_options() : OptimizerOptions(),
      thread_pool_.get());

  if (new_server != nullptr) {
    // Memory leak!
    // The previous server is deliberately released rather than destroyed,
    // because servers don't support clean shutdown.
    if (server_ != nullptr) {
      LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
                      "Servers don't support clean shutdown.";
      server_.release();
    }
    server_ = std::move(new_server);
  }
  DCHECK(server_ != nullptr);

  return Status::OK();
}
832
SetRemoteDeviceFilters(const string & remote_worker,const std::vector<string> & device_filters)833 Status EagerContext::SetRemoteDeviceFilters(
834 const string& remote_worker, const std::vector<string>& device_filters) {
835 // Get fully specified task name for remote worker
836 string remote_worker_task_name;
837 DeviceNameUtils::ParsedName pw;
838 if (!DeviceNameUtils::ParseFullName(remote_worker, &pw)) {
839 return tensorflow::errors::InvalidArgument(
840 "Remote worker task name is invalid ", remote_worker);
841 }
842 // Force set a replica as the key in cluster device filters map. I.e., if the
843 // remote worker is `/job:worker/task:0` it then becomes
844 // `/job:worker/replica:0/task:0`.
845 pw.has_replica = true;
846 if (!DeviceNameUtils::GetTaskName(pw, &remote_worker_task_name)) {
847 return tensorflow::errors::InvalidArgument(
848 "Job name and task index must be specified for worker ", remote_worker);
849 }
850
851 std::vector<DeviceNameUtils::ParsedName> parsed_filters;
852 for (auto& filter : device_filters) {
853 DeviceNameUtils::ParsedName parsed_filter;
854 if (DeviceNameUtils::ParseFullName(filter, &parsed_filter)) {
855 parsed_filters.emplace_back(parsed_filter);
856 } else {
857 return tensorflow::errors::InvalidArgument("Invalid filter: ", filter);
858 }
859 }
860
861 if (VLOG_IS_ON(1)) {
862 VLOG(1) << "Setting device filters for " << remote_worker << ":";
863 for (auto& filter : device_filters) {
864 VLOG(1) << " " << filter;
865 }
866 }
867 mutex_lock l(remote_state_mu_);
868 cluster_device_filters_.emplace(remote_worker_task_name, parsed_filters);
869 return Status::OK();
870 }
871
FilterDevicesForRemoteWorkers(const string & remote_worker,const protobuf::RepeatedPtrField<DeviceAttributes> & device_attrs,std::vector<bool> * filtered_device_mask)872 void EagerContext::FilterDevicesForRemoteWorkers(
873 const string& remote_worker,
874 const protobuf::RepeatedPtrField<DeviceAttributes>& device_attrs,
875 std::vector<bool>* filtered_device_mask) {
876 filtered_device_mask->resize(device_attrs.size());
877 std::fill(filtered_device_mask->begin(), filtered_device_mask->end(), false);
878
879 tf_shared_lock l(remote_state_mu_);
880 auto it = cluster_device_filters_.find(remote_worker);
881 // If no filters were specified, all devices should be visible to the worker
882 if (it == cluster_device_filters_.end() || it->second.empty()) {
883 std::fill(filtered_device_mask->begin(), filtered_device_mask->end(), true);
884 return;
885 }
886
887 const std::vector<DeviceNameUtils::ParsedName>& parsed_filters = it->second;
888 DeviceNameUtils::ParsedName parsed_remote_worker;
889 DeviceNameUtils::ParseFullName(remote_worker, &parsed_remote_worker);
890 for (int i = 0; i < device_attrs.size(); i++) {
891 DeviceNameUtils::ParsedName pn;
892 DeviceNameUtils::ParseFullName(device_attrs[i].name(), &pn);
893 if (DeviceNameUtils::IsSameAddressSpace(parsed_remote_worker, pn)) {
894 // If this device is on the remote worker itself, it should be visible
895 // regardless of device filters
896 filtered_device_mask->at(i) = true;
897 continue;
898 }
899 for (const auto& pf : parsed_filters) {
900 if ((!pn.has_job || !pf.has_job || pn.job == pf.job) &&
901 (!pn.has_replica || !pf.has_replica || pn.replica == pf.replica) &&
902 (!pn.has_task || !pf.has_task || pn.task == pf.task) &&
903 (!pn.has_type || !pf.has_type || pn.type == pf.type) &&
904 (!pn.has_id || !pf.has_id || pn.id == pf.id)) {
905 // Found a match, make it visible, stop processing more device filters
906 filtered_device_mask->at(i) = true;
907 break;
908 }
909 }
910 }
911 }
912
InitializeRemoteMaster(std::unique_ptr<ServerInterface> server,WorkerEnv * worker_env,std::shared_ptr<WorkerSession> worker_session,std::unique_ptr<eager::EagerClientCache> remote_eager_workers,std::unique_ptr<DynamicDeviceMgr> remote_device_manager,const std::vector<string> & remote_contexts,uint64 context_id,Rendezvous * r,DeviceMgr * local_device_mgr,int keep_alive_secs,DistributedFunctionLibraryRuntime * cluster_flr,std::unique_ptr<eager::RemoteMgr,std::function<void (eager::RemoteMgr *)>> remote_mgr)913 Status EagerContext::InitializeRemoteMaster(
914 std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
915 std::shared_ptr<WorkerSession> worker_session,
916 std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
917 std::unique_ptr<DynamicDeviceMgr> remote_device_manager,
918 const std::vector<string>& remote_contexts, uint64 context_id,
919 Rendezvous* r, DeviceMgr* local_device_mgr, int keep_alive_secs,
920 DistributedFunctionLibraryRuntime* cluster_flr,
921 std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
922 remote_mgr) {
923 if (context_id == kInvalidContextId) {
924 return errors::InvalidArgument(
925 "Failed to initialize remote for master context due to invalid ",
926 "context id");
927 }
928
929 if (!remote_contexts_.empty()) {
930 CloseAndClearAllRemoteContexts();
931 }
932 remote_contexts_ = remote_contexts;
933
934 return SetMasterContextState(
935 std::move(server), worker_env, std::move(worker_session),
936 std::move(remote_eager_workers), std::move(remote_device_manager),
937 context_id, 0, r, local_device_mgr, keep_alive_secs, cluster_flr,
938 std::move(remote_mgr));
939 }
940
UpdateRemoteMaster(WorkerEnv * worker_env,std::unique_ptr<eager::EagerClientCache> remote_eager_workers,const std::vector<string> & add_remote_contexts,const std::vector<string> & remove_remote_contexts,uint64 context_id,Rendezvous * r,DeviceMgr * local_device_mgr,int keep_alive_secs,DistributedFunctionLibraryRuntime * cluster_flr)941 Status EagerContext::UpdateRemoteMaster(
942 WorkerEnv* worker_env,
943 std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
944 const std::vector<string>& add_remote_contexts,
945 const std::vector<string>& remove_remote_contexts, uint64 context_id,
946 Rendezvous* r, DeviceMgr* local_device_mgr, int keep_alive_secs,
947 DistributedFunctionLibraryRuntime* cluster_flr) {
948 {
949 tf_shared_lock l(remote_state_mu_);
950 if (context_id != context_id_) {
951 return errors::InvalidArgument(
952 "Failed to update remote remote master context due to invalid ",
953 "context id. Request id = ", context_id,
954 " but current id = ", context_id_);
955 }
956 }
957
958 if (!remove_remote_contexts.empty()) {
959 // N.B. remove_remote_contexts include both removed and replaced workers.
960 // In the case where a worker is replaced by one that resolves to the same
961 // `hostname:port`, it is safe to close context with the current view id,
962 // since the newly created context on the remote worker will be holding
963 // a larger view id and ignores this request.
964 CloseRemoteContexts(remove_remote_contexts, context_id, GetContextViewId());
965 for (const string& remote_context : remove_remote_contexts) {
966 remote_contexts_.erase(
967 std::remove(remote_contexts_.begin(), remote_contexts_.end(),
968 remote_context),
969 remote_contexts_.end());
970 }
971 }
972 if (!add_remote_contexts.empty()) {
973 remote_contexts_.insert(std::end(remote_contexts_),
974 std::begin(add_remote_contexts),
975 std::end(add_remote_contexts));
976 }
977 std::vector<const FunctionDef*> function_defs = ListRegisteredFunctions();
978
979 {
980 mutex_lock l(remote_state_mu_);
981 context_view_id_++;
982
983 worker_env_ = worker_env;
984 if (rendezvous_ != nullptr) rendezvous_->Unref();
985 rendezvous_ = r;
986 remote_eager_workers_ = std::move(remote_eager_workers);
987 ResetClusterFLR(cluster_flr);
988 InitPrioritizedDeviceTypeList();
989
990 default_executor_.ClearError();
991 {
992 tensorflow::mutex_lock l(executor_map_mu_);
993 for (auto& entry : thread_local_executor_) {
994 entry.second->ClearError();
995 }
996 }
997 const auto* config = pflr_->config();
998 ResetPFLR(local_device_manager_.Get(), env_, config, TF_GRAPH_DEF_VERSION,
999 &func_lib_def_, config->graph_options().optimizer_options(),
1000 thread_pool_.get(), cluster_flr_.Get(), custom_kernel_creator_);
1001 }
1002
1003 // Register existing functions to the newly added remote workers. Note that
1004 // this should happen only after updating `remote_contexts_` because new
1005 // functions might be registered while we update the context. When that
1006 // happens, this ordering ensures that `MaybeRegisterFunctionRemotely` will
1007 // register the new functions on all remote workers (including the newly added
1008 // ones), and `RegisterExistingFunctionsOnRemoteWorkers` will take care of
1009 // registering existing functions, where duplicate registrations will be
1010 // ignored by the remote workers.
1011 TF_RETURN_IF_ERROR(RegisterExistingFunctionsOnRemoteWorkers(
1012 function_defs, add_remote_contexts));
1013 return Status::OK();
1014 }
1015
// Set distributed execution related state in the master context.
//
// Runs entirely under an exclusive lock on `remote_state_mu_`. It:
//   * marks this context as the master and records `context_id` /
//     `context_view_id`;
//   * reads TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC (default true) into
//     `use_send_tensor_rpc_`;
//   * installs the local device manager, rendezvous (unref'ing the previous
//     one), server, remote manager, worker env/session, remote eager client
//     cache, remote device manager and cluster FLR;
//   * rebuilds the device type list, calls ClearCachesAndThreadExecutors(),
//     clears executor errors and rebuilds the PFLR from the existing PFLR's
//     config.
//
// `keep_alive_secs` is stored on every call, but the background keep-alive
// thread is started only on the first call. Always returns OK.
Status EagerContext::SetMasterContextState(
    std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
    std::shared_ptr<WorkerSession> worker_session,
    std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
    std::unique_ptr<DynamicDeviceMgr> remote_device_manager, uint64 context_id,
    uint64 context_view_id, Rendezvous* r, DeviceMgr* local_device_mgr,
    int keep_alive_secs, DistributedFunctionLibraryRuntime* cluster_flr,
    std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
        remote_mgr) {
  mutex_lock l(remote_state_mu_);
  is_master_ = true;
  context_id_ = context_id;
  context_view_id_ = context_view_id;

  use_send_tensor_rpc_ =
      ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", true);

  local_device_manager_.Reset(local_device_mgr);
  host_cpu_device_ = local_device_manager_.Get()->HostCPU();

  // Drop this context's reference on the previous rendezvous before replacing
  // it.
  if (rendezvous_ != nullptr) rendezvous_->Unref();
  rendezvous_ = r;

  // Memory leak!
  // The previous server is deliberately released rather than destroyed,
  // because servers don't support clean shutdown.
  if (server_ != nullptr) {
    LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
                    "Servers don't support clean shutdown.";
    server_.release();
  }
  server_ = std::move(server);

  remote_mgr_ = std::move(remote_mgr);
  worker_env_ = worker_env;
  worker_session_ = std::move(worker_session);
  remote_eager_workers_ = std::move(remote_eager_workers);

  remote_device_manager_.Reset(std::move(remote_device_manager));
  ResetClusterFLR(cluster_flr);

  InitPrioritizedDeviceTypeList();

  ClearCachesAndThreadExecutors();
  default_executor_.ClearError();
  {
    tensorflow::mutex_lock l(executor_map_mu_);
    for (auto& entry : thread_local_executor_) {
      entry.second->ClearError();
    }
  }
  // NOTE(review): dereferences `pflr_` unconditionally (StoreCollectiveOpsServer
  // null-checks it); assumes a PFLR was created earlier -- confirm.
  const auto* config = pflr_->config();
  ResetPFLR(local_device_manager_.Get(), env_, config, TF_GRAPH_DEF_VERSION,
            &func_lib_def_, config->graph_options().optimizer_options(),
            thread_pool_.get(), cluster_flr_.Get(), custom_kernel_creator_);

  // The keep-alive thread wakes up every `sleep_for_secs_` seconds, at least
  // 1 second and otherwise half of `keep_alive_secs_`.
  keep_alive_secs_ = keep_alive_secs;
  sleep_for_secs_ = std::max(1, keep_alive_secs_ / 2);
  // Only schedule a single closure.
  // Later calls only update `keep_alive_secs_` / `sleep_for_secs_`, which the
  // already-running thread reads on each iteration.
  if (keep_alive_thread_ == nullptr) {
    keep_alive_thread_.reset(
        env_->StartThread({}, "EagerKeepAliveThread", [this]() {
          while (true) {
            {
              {
                // Sleep up to `sleep_for_secs_` (or until the cv is notified),
                // exiting if shutdown was requested before or after the wait.
                // NOTE(review): `sleep_for_secs_` is written under
                // `remote_state_mu_` but read here under
                // `keep_alive_thread_shutdown_mu_` -- confirm this is benign.
                mutex_lock l(keep_alive_thread_shutdown_mu_);

                if (shutting_down_) {
                  return;
                }

                keep_alive_thread_cv_.wait_for(
                    l, std::chrono::seconds(sleep_for_secs_));

                if (shutting_down_) {
                  return;
                }
              }
              {
                // Ping every tracked remote context with the current context
                // id. A non-positive `keep_alive_secs_` disables the pings
                // without stopping the thread.
                mutex_lock l(remote_state_mu_);
                if (keep_alive_secs_ > 0) {
                  {
                    for (const auto& worker : remote_contexts_) {
                      core::RefCountPtr<eager::EagerClient> client;
                      Status s =
                          remote_eager_workers_->GetClient(worker, &client);

                      // Workers without a client are logged and skipped.
                      if (!s.ok()) {
                        LOG(WARNING) << "Keep-alive thread was unable to find "
                                        "a client for target "
                                     << worker << ". Got error: " << s;
                        continue;
                      }

                      // Heap-allocated because the RPC is asynchronous; the
                      // completion callback frees both and ignores the status.
                      eager::KeepAliveRequest* request =
                          new eager::KeepAliveRequest;
                      eager::KeepAliveResponse* response =
                          new eager::KeepAliveResponse;

                      request->set_context_id(context_id_);
                      client->KeepAliveAsync(
                          request, response,
                          [request, response](const Status& s) {
                            delete request;
                            delete response;
                          });
                    }
                  }
                }
              }
            }
          }
        }));
  }
  return Status::OK();
}
1131
// Initializes this context as a (non-master) worker in a distributed cluster.
//
// Returns InvalidArgument for kInvalidContextId. Returns FailedPrecondition
// if this context already shows master state: an owned remote device manager,
// a server, or a keep-alive thread.
//
// Otherwise, while holding `remote_state_mu_`, it:
//   * records the remote contexts and context id / view id;
//   * installs the rendezvous creator, remote eager client cache, remote
//     manager, cluster FLR and remote device manager;
//   * rebuilds the PFLR and device type list, calls
//     ClearCachesAndThreadExecutors() and clears executor errors;
//   * stores `resource_deallocator`.
//
// Returns OK on success.
Status EagerContext::InitializeRemoteWorker(
    std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
    DynamicDeviceMgr* remote_device_mgr,
    const std::vector<string>& remote_contexts, uint64 context_id,
    uint64 context_view_id,
    std::function<Rendezvous*(const int64)> rendezvous_creator,
    DistributedFunctionLibraryRuntime* cluster_flr,
    std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
        remote_mgr,
    std::function<void()> resource_deallocator) {
  if (context_id == kInvalidContextId) {
    return errors::InvalidArgument(
        "Failed to initialize remote for worker context due to invalid ",
        "context id");
  }
  // Held for the rest of the function.
  mutex_lock l(remote_state_mu_);

  // Any of these indicates the context was already set up as a master.
  if (remote_device_manager_.Owned() || server_ != nullptr ||
      keep_alive_thread_ != nullptr) {
    return errors::FailedPrecondition(
        "EagerContext::InitializeRemoteWorker Failed. ",
        "Already initialized remote as a master context.");
  }
  is_master_ = false;

  remote_contexts_ = remote_contexts;
  context_id_ = context_id;
  context_view_id_ = context_view_id;

  rendezvous_creator_ = std::move(rendezvous_creator);
  remote_eager_workers_ = std::move(remote_eager_workers);
  remote_mgr_ = std::move(remote_mgr);
  ResetClusterFLR(cluster_flr);

  remote_device_manager_.Reset(remote_device_mgr);

  // NOTE(review): dereferences `pflr_` unconditionally; assumes a PFLR was
  // created earlier -- confirm.
  const auto* config = pflr_->config();
  ResetPFLR(local_device_manager_.Get(), env_, config, TF_GRAPH_DEF_VERSION,
            &func_lib_def_, config->graph_options().optimizer_options(),
            thread_pool_.get(), cluster_flr_.Get(), custom_kernel_creator_);
  InitPrioritizedDeviceTypeList();

  ClearCachesAndThreadExecutors();
  default_executor_.ClearError();
  {
    tensorflow::mutex_lock l(executor_map_mu_);
    for (auto& entry : thread_local_executor_) {
      entry.second->ClearError();
    }
  }

  resource_deallocator_ = std::move(resource_deallocator);

  return Status::OK();
}
1187
UpdateRemoteWorker(const DeviceMgr * worker_session_device_mgr,std::unique_ptr<eager::EagerClientCache> remote_eager_workers,DynamicDeviceMgr * remote_device_mgr,const std::vector<string> & remote_contexts,uint64 context_id,DistributedFunctionLibraryRuntime * cluster_flr)1188 Status EagerContext::UpdateRemoteWorker(
1189 const DeviceMgr* worker_session_device_mgr,
1190 std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
1191 DynamicDeviceMgr* remote_device_mgr,
1192 const std::vector<string>& remote_contexts, uint64 context_id,
1193 DistributedFunctionLibraryRuntime* cluster_flr) {
1194 {
1195 mutex_lock l(remote_state_mu_);
1196 if (context_id != context_id_) {
1197 return errors::InvalidArgument(
1198 "Failed to update remote for worker context due to invalid ",
1199 "context id. Request id = ", context_id,
1200 " but current id = ", context_id_);
1201 }
1202 context_view_id_++;
1203 }
1204
1205 remote_contexts_ = remote_contexts;
1206
1207 remote_eager_workers_ = std::move(remote_eager_workers);
1208 ResetClusterFLR(cluster_flr);
1209
1210 remote_device_manager_.Reset(remote_device_mgr);
1211 InitPrioritizedDeviceTypeList();
1212
1213 ClearCachesAndThreadExecutors();
1214 default_executor_.ClearError();
1215 {
1216 tensorflow::mutex_lock l(executor_map_mu_);
1217 for (auto& entry : thread_local_executor_) {
1218 entry.second->ClearError();
1219 }
1220 }
1221
1222 SessionOptions options = SessionOptions();
1223 const auto* config = pflr_->config();
1224 ResetPFLR(worker_session_device_mgr, options.env, config,
1225 TF_GRAPH_DEF_VERSION, FuncLibDef(),
1226 config->graph_options().optimizer_options(), thread_pool_.get(),
1227 cluster_flr_.Get(), custom_kernel_creator_);
1228 return Status::OK();
1229 }
1230 #endif // !IS_MOBILE_PLATFORM
1231
1232 } // namespace tensorflow
1233