/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/eager/context_distributed_manager.h"

#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/device_filters.pb.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/util/device_name_utils.h"

#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
#include "tensorflow/core/distributed_runtime/remote_device.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#endif  // !IS_MOBILE_PLATFORM

namespace tensorflow {
#if !defined(IS_MOBILE_PLATFORM)
namespace {
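// Returns true if the eager context's local devices can be reused by a new
// server built from `server_def`: the job name must match the host CPU
// device's job, and the default session config must be byte-identical to the
// context's session options.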
bool AreLocalDevicesCompatible(const tensorflow::EagerContext* context,
                               const tensorflow::ServerDef& server_def) {
  if (server_def.job_name() != context->HostCPU()->parsed_name().job) {
    return false;
  }
  return server_def.default_session_config().SerializeAsString() ==
         context->session_options().config.SerializeAsString();
}

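// Queries every worker in `added_remote_workers` for its devices (in
// parallel, via NewRemoteDevices) and registers the results with
// `remote_device_mgr`. Blocks until all lookups finish and returns the first
// error encountered, if any.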
tensorflow::Status AddRemoteDevicesToMgr(
    const std::vector<string>& added_remote_workers,
    tensorflow::WorkerCacheInterface* worker_cache,
    tensorflow::DynamicDeviceMgr* remote_device_mgr) {
  std::vector<std::unique_ptr<tensorflow::Device>> remote_devices;
  tensorflow::mutex remote_devices_mu;
  int num_added_workers = added_remote_workers.size();
  tensorflow::BlockingCounter counter(num_added_workers);
  std::vector<tensorflow::Status> statuses(num_added_workers);
  for (int i = 0; i < num_added_workers; i++) {
    tensorflow::NewRemoteDevices(
        tensorflow::Env::Default(), worker_cache, added_remote_workers[i],
        [i, &statuses, &counter, &remote_devices, &remote_devices_mu](
            const tensorflow::Status& s,
            std::vector<tensorflow::Device*>* devices) {
          statuses[i] = s;
          if (s.ok()) {
            tensorflow::mutex_lock l(remote_devices_mu);
            for (tensorflow::Device* d : *devices) {
              remote_devices.emplace_back(d);
            }
          }
          counter.DecrementCount();
        });
  }
  counter.Wait();
  for (int i = 0; i < num_added_workers; i++) {
    TF_RETURN_IF_ERROR(statuses[i]);
  }

  TF_RETURN_IF_ERROR(remote_device_mgr->AddDevices(std::move(remote_devices)));
  return tensorflow::Status::OK();
}

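// Builds a fresh DynamicDeviceMgr populated with the devices of all
// `remote_workers`. Used when a context is created or reset.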
tensorflow::Status GetAllRemoteDevices(
    const std::vector<string>& remote_workers,
    tensorflow::WorkerCacheInterface* worker_cache,
    std::unique_ptr<tensorflow::DynamicDeviceMgr>* device_mgr) {
  auto remote_device_mgr = std::make_unique<tensorflow::DynamicDeviceMgr>();
  TF_RETURN_IF_ERROR(AddRemoteDevicesToMgr(remote_workers, worker_cache,
                                           remote_device_mgr.get()));
  *device_mgr = std::move(remote_device_mgr);
  return tensorflow::Status::OK();
}

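// Removes from `remote_device_mgr` every device that lives in the address
// space of one of `removed_remote_workers`.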
tensorflow::Status RemoveRemoteDevicesFromMgr(
    const std::vector<string>& removed_remote_workers,
    tensorflow::DynamicDeviceMgr* remote_device_mgr) {
  const std::vector<tensorflow::Device*> remote_devices =
      (remote_device_mgr->ListDevices());
  std::vector<tensorflow::Device*> devices_to_remove;
  for (tensorflow::Device* d : remote_devices) {
    for (const string& remote_worker : removed_remote_workers) {
      if (tensorflow::DeviceNameUtils::IsSameAddressSpace(remote_worker,
                                                          d->name())) {
        devices_to_remove.emplace_back(d);
        break;
      }
    }
  }
  TF_RETURN_IF_ERROR(remote_device_mgr->RemoveDevices(devices_to_remove));
  return tensorflow::Status::OK();
}

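// Lists all workers known to `server`'s worker cache, excluding
// `local_worker` itself. Only GrpcServer instances are supported.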
tensorflow::Status ListRemoteWorkers(tensorflow::ServerInterface* server,
                                     const string& local_worker,
                                     std::vector<string>* remote_workers) {
  tensorflow::GrpcServer* grpc_server =
      dynamic_cast<tensorflow::GrpcServer*>(server);
  if (grpc_server == nullptr) {
    return tensorflow::errors::Internal(
        "Currently, TFE_NewContext only supports tensorflow::GrpcServer.");
  }
  grpc_server->master_env()->worker_cache->ListWorkers(remote_workers);
  remote_workers->erase(
      std::remove(remote_workers->begin(), remote_workers->end(), local_worker),
      remote_workers->end());
  return tensorflow::Status::OK();
}

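// Splits a sorted `new_list` of workers against a sorted `current_list` into
// added, removed, and unchanged entries. For example, given
//   current_list = {"/job:w/task:0", "/job:w/task:1"}
//   new_list     = {"/job:w/task:1", "/job:w/task:2"}
// the outputs are added = {"/job:w/task:2"}, removed = {"/job:w/task:0"},
// and existing = {"/job:w/task:1"}.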
void DifferentiateWorkerLists(const std::vector<string>* current_list,
                              const std::vector<string>* new_list,
                              std::vector<string>* added,
                              std::vector<string>* removed,
                              std::vector<string>* existing) {
  // Compute the set difference and set intersection in a single traversal of
  // both lists. As with std::set_difference, the inputs (`current_list` and
  // `new_list`) must be sorted before calling this function.
  added->resize(new_list->size());
  removed->resize(current_list->size());
  existing->resize(current_list->size());
  std::vector<string>::const_iterator curr_it = current_list->begin();
  std::vector<string>::const_iterator new_it = new_list->begin();
  std::vector<string>::iterator added_it = added->begin();
  std::vector<string>::iterator removed_it = removed->begin();
  std::vector<string>::iterator existing_it = existing->begin();
  while (curr_it != current_list->end() && new_it != new_list->end()) {
    if (*curr_it < *new_it) {
      *removed_it++ = *curr_it++;
    } else if (*curr_it > *new_it) {
      *added_it++ = *new_it++;
    } else {
      *existing_it++ = *curr_it++;
      new_it++;
    }
  }
  removed_it = std::copy(curr_it, current_list->end(), removed_it);
  added_it = std::copy(new_it, new_list->end(), added_it);
  added->resize(added_it - added->begin());
  removed->resize(removed_it - removed->begin());
  existing->resize(existing_it - existing->begin());
}

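// Probes each of `existing_workers` with a KeepAlive RPC carrying the current
// context ID. A worker whose RPC fails, or whose reported context view ID
// differs from the local one, is assumed to have been restarted and is
// appended to `replaced_workers`.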
tensorflow::Status GetReplacedFromExistingWorkers(
    const std::vector<string>* existing_workers, tensorflow::uint64 context_id,
    tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
    tensorflow::eager::EagerClientCache* client_cache,
    std::vector<string>* replaced_workers) {
  tensorflow::BlockingCounter counter(existing_workers->size());
  std::vector<tensorflow::Status> statuses(existing_workers->size());
  tensorflow::eager::KeepAliveRequest request;
  request.set_context_id(context_id);
  std::vector<tensorflow::eager::KeepAliveResponse> responses(
      existing_workers->size());
  for (int i = 0; i < existing_workers->size(); i++) {
    tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
    statuses[i] =
        client_cache->GetClient(existing_workers->at(i), &eager_client);
    if (!statuses[i].ok()) {
      counter.DecrementCount();
      continue;
    }
    eager_client->KeepAliveAsync(
        &request, &responses[i],
        [i, &statuses, &counter](const tensorflow::Status& s) {
          statuses[i] = s;
          counter.DecrementCount();
        });
  }
  counter.Wait();
  for (int i = 0; i < existing_workers->size(); i++) {
    // If the RPC fails (indicating that the requested context ID doesn't
    // exist on the remote worker), or the returned view ID differs from the
    // local one (indicating that the remote worker has a stale view of the
    // cluster), treat the worker as replaced.
    if (!statuses[i].ok() ||
        responses[i].context_view_id() != context_view_id) {
      replaced_workers->emplace_back(existing_workers->at(i));
    }
  }
  return tensorflow::Status::OK();
}

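// Sends a CreateContextRequest to every worker in `remote_workers`, seeding
// each request from `base_request` but keeping only the cluster devices that
// pass the per-worker device filters. Errors from individual workers are
// aggregated into a single summary status.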
tensorflow::Status CreateRemoteContexts(
    EagerContext* context, const std::vector<string>& remote_workers,
    tensorflow::uint64 context_id, tensorflow::uint64 context_view_id,
    int keep_alive_secs, const tensorflow::ServerDef& server_def,
    tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
    const tensorflow::eager::CreateContextRequest& base_request) {
  int num_remote_workers = remote_workers.size();
  tensorflow::BlockingCounter counter(num_remote_workers);
  std::vector<tensorflow::Status> statuses(num_remote_workers);
  for (int i = 0; i < num_remote_workers; i++) {
    const string& remote_worker = remote_workers[i];
    tensorflow::DeviceNameUtils::ParsedName parsed_name;
    if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
                                                    &parsed_name)) {
      statuses[i] = tensorflow::errors::InvalidArgument(
          "Unable to parse ", remote_worker, " as a device name");
      counter.DecrementCount();
      continue;
    }

    tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
    statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
    if (eager_client == nullptr) {
      statuses[i] = tensorflow::errors::Internal(
          "Cannot find a client for the given target:", remote_worker);
    }
    if (!statuses[i].ok()) {
      counter.DecrementCount();
      continue;
    }

    tensorflow::eager::CreateContextRequest request;
    tensorflow::eager::CreateContextResponse* response =
        new tensorflow::eager::CreateContextResponse();
    request.set_context_id(context_id);
    request.set_context_view_id(context_view_id);
    *request.mutable_server_def() = server_def;
    request.mutable_server_def()->set_job_name(parsed_name.job);
    request.mutable_server_def()->set_task_index(parsed_name.task);
    request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
        server_def.default_session_config());

    std::vector<bool> filtered_device_mask;
    context->FilterDevicesForRemoteWorkers(
        remote_worker, base_request.cluster_device_attributes(),
        &filtered_device_mask);
    DCHECK_EQ(filtered_device_mask.size(),
              base_request.cluster_device_attributes_size());
    for (int j = 0; j < filtered_device_mask.size(); j++) {
      if (filtered_device_mask[j]) {
        const auto& da = base_request.cluster_device_attributes(j);
        *request.add_cluster_device_attributes() = da;
      }
    }
    request.set_async(async);
    request.set_keep_alive_secs(keep_alive_secs);
    // TODO(b/134094971): deprecate lazy_copy_remote_function_inputs when the
    // server no longer tries to read its value.
    request.set_lazy_copy_remote_function_inputs(true);

    eager_client->CreateContextAsync(
        &request, response,
        [i, &statuses, &counter, response](const tensorflow::Status& s) {
          statuses[i] = s;
          delete response;
          counter.DecrementCount();
        });
  }
  counter.Wait();
  tensorflow::StatusGroup sg;
  for (int i = 0; i < num_remote_workers; i++) {
    if (TF_PREDICT_FALSE(!statuses[i].ok())) {
      sg.Update(statuses[i]);
    }
  }
  return sg.as_summary_status();
}

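// Sends an UpdateContextRequest to every existing worker. A worker receives
// the full server_def and its filtered device set only if one of its visible
// devices belongs to an added or removed worker; otherwise the request merely
// bumps the remote context view ID.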
tensorflow::Status UpdateRemoteContexts(
    EagerContext* context, const std::vector<string>& remote_workers,
    const std::vector<string>& added_workers,
    const std::vector<string>& removed_workers, tensorflow::uint64 context_id,
    tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
    tensorflow::eager::EagerClientCache* remote_eager_workers,
    const tensorflow::eager::CreateContextRequest& base_request) {
  int num_remote_workers = remote_workers.size();
  tensorflow::BlockingCounter counter(num_remote_workers);
  std::vector<tensorflow::Status> statuses(num_remote_workers);

  int cluster_device_count = base_request.cluster_device_attributes_size();
  std::unordered_set<string> added_or_removed(added_workers.begin(),
                                              added_workers.end());
  std::copy(removed_workers.begin(), removed_workers.end(),
            std::inserter(added_or_removed, added_or_removed.end()));
  // Whether each device belongs to an updated (added or removed) worker.
  std::vector<bool> device_added_or_removed(cluster_device_count);
  for (int i = 0; i < base_request.cluster_device_attributes_size(); i++) {
    const auto& da = base_request.cluster_device_attributes().at(i);
    tensorflow::DeviceNameUtils::ParsedName pn;
    tensorflow::DeviceNameUtils::ParseFullName(da.name(), &pn);
    string task_name;
    tensorflow::DeviceNameUtils::GetTaskName(pn, &task_name);
    if (added_or_removed.find(task_name) != added_or_removed.end()) {
      device_added_or_removed[i] = true;
    }
  }

  for (int i = 0; i < num_remote_workers; i++) {
    const string& remote_worker = remote_workers[i];
    tensorflow::DeviceNameUtils::ParsedName parsed_name;
    if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
                                                    &parsed_name)) {
      statuses[i] = tensorflow::errors::InvalidArgument(
          "Unable to parse ", remote_worker, " as a device name");
      counter.DecrementCount();
      continue;
    }

    tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
    statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
    if (eager_client == nullptr) {
      statuses[i] = tensorflow::errors::Internal(
          "Cannot find a client for the given target:", remote_worker);
    }
    if (!statuses[i].ok()) {
      counter.DecrementCount();
      continue;
    }

    std::vector<bool> filtered_device_mask;
    context->FilterDevicesForRemoteWorkers(
        remote_worker, base_request.cluster_device_attributes(),
        &filtered_device_mask);
    DCHECK_EQ(filtered_device_mask.size(), cluster_device_count);

    // If any of the devices that match the device filters are in the set of
    // added or removed workers, we must send a complete UpdateContextRequest.
    // Otherwise, only send a simple request to increment the context view ID.
    std::vector<bool> added_or_removed_filtered_devices(cluster_device_count);
    std::transform(device_added_or_removed.begin(),
                   device_added_or_removed.end(), filtered_device_mask.begin(),
                   added_or_removed_filtered_devices.begin(),
                   std::logical_and<bool>());
    const bool full_update_request =
        std::accumulate(added_or_removed_filtered_devices.begin(),
                        added_or_removed_filtered_devices.end(), false,
                        std::logical_or<bool>());

    tensorflow::eager::UpdateContextRequest request;
    auto* response = new tensorflow::eager::UpdateContextResponse();
    request.set_context_id(context_id);
    request.set_context_view_id(context_view_id);
    if (full_update_request) {
      *request.mutable_server_def() = server_def;
      request.mutable_server_def()->set_job_name(parsed_name.job);
      request.mutable_server_def()->set_task_index(parsed_name.task);
      request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
          server_def.default_session_config());
      for (int j = 0; j < cluster_device_count; j++) {
        if (filtered_device_mask[j]) {
          const auto& da = base_request.cluster_device_attributes(j);
          *request.add_cluster_device_attributes() = da;
        }
      }
    }

    eager_client->UpdateContextAsync(
        &request, response,
        [i, &statuses, &counter, response](const tensorflow::Status& s) {
          statuses[i] = s;
          delete response;
          counter.DecrementCount();
        });
  }
  counter.Wait();
  for (int i = 0; i < num_remote_workers; i++) {
    TF_RETURN_IF_ERROR(statuses[i]);
  }
  return tensorflow::Status::OK();
}

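// Applies `server_def` to the eager context. When `reset_context` is true, a
// brand-new GrpcServer, context ID, and remote device manager are created;
// otherwise the current cluster is diffed against the new server_def and only
// the added, removed, and replaced workers are touched.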
tensorflow::Status UpdateContextWithServerDef(
    EagerContext* context, const tensorflow::ServerDef& server_def,
    bool reset_context, int keep_alive_secs) {
  // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys
  // the server object (which currently CHECK-fails), and we would miss the
  // error. Instead, we log the error and then return, allowing the user to
  // see the message.
#define LOG_AND_RETURN_IF_ERROR(...)                    \
  do {                                                  \
    const ::tensorflow::Status _status = (__VA_ARGS__); \
    if (TF_PREDICT_FALSE(!_status.ok())) {              \
      LOG(ERROR) << _status.error_message();            \
      return _status;                                   \
    }                                                   \
  } while (0)

  string worker_name =
      tensorflow::strings::StrCat("/job:", server_def.job_name(),
                                  "/replica:0/task:", server_def.task_index());

  // List of current remote workers before updating server_def. Unused if
  // resetting the server_def.
  std::vector<string> curr_remote_workers;
  // List of updated remote workers.
  std::vector<string> remote_workers;

  // New server created for new server_def. Unused if updating server_def.
  std::unique_ptr<tensorflow::ServerInterface> new_server;
  tensorflow::GrpcServer* grpc_server;
  if (reset_context) {
    const tensorflow::DeviceMgr* device_mgr =
        AreLocalDevicesCompatible(context, server_def)
            ? context->local_device_mgr()
            : nullptr;
    LOG_AND_RETURN_IF_ERROR(tensorflow::NewServerWithOptions(
        server_def, {device_mgr}, &new_server));
    grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
    LOG_AND_RETURN_IF_ERROR(
        ListRemoteWorkers(new_server.get(), worker_name, &remote_workers));
  } else {
    LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name,
                                              &curr_remote_workers));
    // No need to check the cast here, since `ListRemoteWorkers` already
    // verifies that the server is a gRPC server.
    grpc_server = dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
    LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
    LOG_AND_RETURN_IF_ERROR(
        ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
  }

  tensorflow::uint64 context_id = context->GetContextId();
  tensorflow::uint64 context_view_id = context->GetContextViewId();
  if (reset_context) {
    context_id = tensorflow::EagerContext::NewContextId();
    context_view_id = 0;
    // Make the master eager context accessible to the local eager service,
    // which might receive send-tensor requests from remote workers.
    LOG_AND_RETURN_IF_ERROR(
        grpc_server->AddMasterEagerContextToEagerService(context_id, context));
  }

  std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
  LOG_AND_RETURN_IF_ERROR(
      grpc_server->master_env()->worker_cache->GetEagerClientCache(
          &remote_eager_workers));

  // For a cluster update, use a status group to aggregate statuses from
  // * adding and removing remote devices
  // * creating remote contexts on newly added workers
  // * updating remote contexts on existing workers
  // * updating the master context
  // Note that we should not return immediately on errors in the middle of
  // these updates, to prevent the cluster from having inconsistent context
  // views.
  //
  // Unused if `reset_context` is true.
  tensorflow::StatusGroup sg;

  // When updating an existing context, populate the following lists with:
  // * added_workers: set(remote_workers) - set(curr_remote_workers)
  // * removed_workers: set(curr_remote_workers) - set(remote_workers)
  // * existing_workers: set(curr_remote_workers) intersect set(remote_workers)
  // * replaced_workers: workers with the same task names, and potentially the
  //   same `hostname:port`s, but replaced by different processes
  std::vector<string> added_workers;
  std::vector<string> removed_workers;
  std::vector<string> existing_workers;
  std::vector<string> replaced_workers;

  // New remote device manager created for the new server_def. Unused if
  // updating an existing server_def.
  std::unique_ptr<tensorflow::DynamicDeviceMgr> new_remote_device_mgr;
  tensorflow::DynamicDeviceMgr* remote_device_mgr = nullptr;
  if (reset_context) {
    LOG_AND_RETURN_IF_ERROR(GetAllRemoteDevices(
        remote_workers, grpc_server->master_env()->worker_cache,
        &new_remote_device_mgr));
    remote_device_mgr = new_remote_device_mgr.get();
  } else {
    context->ClearCachesAndDefaultExecutor();
    // TODO(b/143914772): Potential memory leak if rendezvous has pending
    // tensors for removed / replaced workers.

    remote_device_mgr = context->GetOwnedRemoteDeviceMgr();
    if (remote_device_mgr == nullptr) {
      LOG_AND_RETURN_IF_ERROR(tensorflow::errors::InvalidArgument(
          "Updating context with an invalid set of remote devices."));
    }
    std::sort(curr_remote_workers.begin(), curr_remote_workers.end());
    std::sort(remote_workers.begin(), remote_workers.end());
    DifferentiateWorkerLists(&curr_remote_workers, &remote_workers,
                             &added_workers, &removed_workers,
                             &existing_workers);
    sg.Update(GetReplacedFromExistingWorkers(
        &existing_workers, context_id, context->GetContextViewId(), server_def,
        remote_eager_workers.get(), &replaced_workers));
    if (VLOG_IS_ON(1)) {
      VLOG(1) << "Updating cluster with the following changes";
      for (const string& w : added_workers) VLOG(1) << "  Added worker " << w;
      for (const string& w : removed_workers)
        VLOG(1) << "  Removed worker " << w;
      for (const string& w : replaced_workers)
        VLOG(1) << "  Replaced worker " << w;
    }
    if (!replaced_workers.empty()) {
      // Treat replaced workers as removed and then added back, so that we
      // recreate remote devices and contexts, and re-register functions on
      // those workers.
      removed_workers.insert(removed_workers.end(), replaced_workers.begin(),
                             replaced_workers.end());
      added_workers.insert(added_workers.end(), replaced_workers.begin(),
                           replaced_workers.end());
      for (const string& w : replaced_workers) {
        existing_workers.erase(
            std::remove(existing_workers.begin(), existing_workers.end(), w),
            existing_workers.end());
      }
    }
    sg.Update(RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr));
    sg.Update(AddRemoteDevicesToMgr(added_workers,
                                    grpc_server->master_env()->worker_cache,
                                    remote_device_mgr));
  }

  std::vector<tensorflow::DeviceAttributes> cluster_device_attributes;
  remote_device_mgr->ListDeviceAttributes(&cluster_device_attributes);

  std::vector<tensorflow::DeviceAttributes> local_device_attributes;
  grpc_server->worker_env()->device_mgr->ListDeviceAttributes(
      &local_device_attributes);

  // This request makes sure that we can create a rendezvous properly between
  // the local and remote contexts.
  tensorflow::eager::CreateContextRequest base_request;
  for (const auto& da : cluster_device_attributes) {
    *base_request.add_cluster_device_attributes() = da;
  }
  for (const auto& da : local_device_attributes) {
    *base_request.add_cluster_device_attributes() = da;
  }

  // Initialize remote eager workers.
  if (reset_context) {
    const tensorflow::Status s = CreateRemoteContexts(
        context, remote_workers, context_id, context_view_id, keep_alive_secs,
        server_def, remote_eager_workers.get(), context->Executor().Async(),
        base_request);
    // NOTE: the remote tasks could fail after `GetAllRemoteDevices` and cause
    // CreateRemoteContexts to fail. We currently only log the error instead
    // of directly returning it, since returning here would cause the server
    // object to be destroyed (which currently CHECK-fails). The client will
    // see additional errors if ops are subsequently sent to the failed
    // workers.
    if (TF_PREDICT_FALSE(!s.ok())) {
      LOG(ERROR) << "Error when creating contexts on remote targets: "
                 << s.error_message()
                 << "\nExecuting remote ops or functions on these remote "
                    "targets will fail.";
    }
  } else {
    if (sg.ok()) {
      // Create remote contexts on the newly added workers only if the master
      // has collected all device information from them (i.e., the
      // GetAllRemoteDevices call returned successfully). Note that in rare
      // cases GetAllRemoteDevices can still fail even with RPCs configured to
      // wait until the remote workers become alive. If the master creates
      // remote contexts on workers whose devices are still not collected,
      // those workers will be treated as existing workers subsequently, so
      // the master will never get devices from them, even when
      // UpdateServerDef is retried.
      sg.Update(CreateRemoteContexts(
          context, added_workers, context_id, context_view_id + 1,
          keep_alive_secs, server_def, remote_eager_workers.get(),
          context->Executor().Async(), base_request));
    }
    if (!existing_workers.empty()) {
      if (VLOG_IS_ON(1)) {
        for (const string& w : existing_workers) {
          VLOG(1) << "Updating cluster with existing worker " << w;
        }
      }
      // The master's context_view_id will be incremented by one in the
      // UpdateRemoteMaster call later. We want existing workers to also have
      // the updated context_view_id, so we must set their context_view_id to
      // the master's current context_view_id + 1.
      sg.Update(UpdateRemoteContexts(context, existing_workers, added_workers,
                                     removed_workers, context_id,
                                     context_view_id + 1, server_def,
                                     remote_eager_workers.get(), base_request));
    }
  }

  auto session_name = tensorflow::strings::StrCat("eager_", context_id);
  if (reset_context) {
    tensorflow::RemoteRendezvous* r =
        grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
    auto* device_mgr = grpc_server->worker_env()->device_mgr;
    std::shared_ptr<tensorflow::WorkerSession> worker_session;
    LOG_AND_RETURN_IF_ERROR(
        grpc_server->worker_env()->session_mgr->CreateSession(
            session_name, server_def, base_request.cluster_device_attributes(),
            /*isolate_session_state=*/true));
    LOG_AND_RETURN_IF_ERROR(
        grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
            session_name, &worker_session));

    // Initialize remote tensor communication based on the worker session.
    LOG_AND_RETURN_IF_ERROR(r->Initialize(worker_session.get()));

    tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
        tensorflow::eager::CreateClusterFLR(context_id, context,
                                            worker_session.get());
    auto remote_mgr = std::make_unique<tensorflow::eager::RemoteMgr>(
        /*is_master=*/true, context);

    LOG_AND_RETURN_IF_ERROR(context->InitializeRemoteMaster(
        std::move(new_server), grpc_server->worker_env(), worker_session,
        std::move(remote_eager_workers), std::move(new_remote_device_mgr),
        remote_workers, context_id, r, device_mgr, keep_alive_secs, cluster_flr,
        std::move(remote_mgr)));

    // NOTE: We start the server after all other initialization, because the
    // GrpcServer cannot be destroyed after it is started.
    LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
  } else {
    sg.Update(grpc_server->worker_env()->session_mgr->UpdateSession(
        session_name, server_def, base_request.cluster_device_attributes(),
        /*isolate_session_state=*/true));
    sg.Update(context->UpdateRemoteMaster(context_id,
                                          std::move(remote_eager_workers),
                                          added_workers, removed_workers));
    LOG_AND_RETURN_IF_ERROR(sg.as_summary_status());
  }
#undef LOG_AND_RETURN_IF_ERROR

  return tensorflow::Status::OK();
}
}  // namespace

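// Entry point used to set (reset_context=true) or update (reset_context=false)
// the cluster definition of an eager context. Cluster device filters are only
// honored when the context is being reset.
//
// A hypothetical ServerDef fragment illustrating device filters (field values
// are for illustration only; the message layout follows device_filters.proto):
//
//   cluster_device_filters {
//     jobs {
//       name: "worker"
//       tasks {
//         key: 1
//         value { device_filters: "/job:ps/task:0" }
//       }
//     }
//   }
//
// With this filter, the loop below records (via SetRemoteDeviceFilters) that
// /job:worker/task:1 should only be exposed to devices on /job:ps/task:0.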
Status EagerContextDistributedManager::SetOrUpdateServerDef(
    const ServerDef& server_def, bool reset_context, int keep_alive_secs) {
  if (server_def.has_cluster_device_filters()) {
    if (reset_context) {
      const auto& cdf = server_def.cluster_device_filters();
      for (const auto& jdf : cdf.jobs()) {
        const string remote_prefix = "/job:" + jdf.name() + "/task:";
        for (const auto& tdf : jdf.tasks()) {
          const int32_t task_index = tdf.first;
          std::vector<string> device_filters(tdf.second.device_filters_size());
          for (int i = 0; i < tdf.second.device_filters_size(); i++) {
            device_filters[i] = tdf.second.device_filters(i);
          }
          const string remote_worker =
              strings::StrCat(remote_prefix, task_index);
          TF_RETURN_IF_ERROR(
              context_->SetRemoteDeviceFilters(remote_worker, device_filters));
        }
      }
    } else {
      LOG(WARNING) << "Device filters can only be specified when initializing "
                      "the cluster. Any changes in device filters are ignored "
                      "when updating the server def.";
    }
  }
  return UpdateContextWithServerDef(context_, server_def, reset_context,
                                    keep_alive_secs);
}

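// Enables collective ops on this context by starting a new gRPC server (or
// updating the existing one) and storing it, together with its device manager
// and collective executor manager, on the eager context.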
Status EagerContextDistributedManager::EnableCollectiveOps(
    const ServerDef& server_def) {
  // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys
  // the server object (which currently CHECK-fails), and we would miss the
  // error. Instead, we log the error and then return, allowing the user to
  // see the message.
#define LOG_AND_RETURN_IF_ERROR(...)                    \
  do {                                                  \
    const ::tensorflow::Status _status = (__VA_ARGS__); \
    if (TF_PREDICT_FALSE(!_status.ok())) {              \
      LOG(ERROR) << _status.error_message();            \
      return _status;                                   \
    }                                                   \
  } while (0)

  tensorflow::GrpcServer* grpc_server =
      dynamic_cast<tensorflow::GrpcServer*>(context_->GetServer());
  if (grpc_server == nullptr) {
    std::unique_ptr<tensorflow::ServerInterface> new_server;
    LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
    grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
    if (grpc_server == nullptr) {
      LOG_AND_RETURN_IF_ERROR(tensorflow::errors::Internal(
          "Currently, TF eager runtime only supports tensorflow::GrpcServer."));
    }
    LOG_AND_RETURN_IF_ERROR(grpc_server->Start());

    LOG_AND_RETURN_IF_ERROR(context_->StoreCollectiveOpsServer(
        std::move(new_server), grpc_server->worker_env()->device_mgr,
        grpc_server->worker_env()->collective_executor_mgr.get()));
  } else {
    LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
    LOG_AND_RETURN_IF_ERROR(context_->StoreCollectiveOpsServer(
        /*new_server=*/nullptr, grpc_server->worker_env()->device_mgr,
        grpc_server->worker_env()->collective_executor_mgr.get()));
  }
#undef LOG_AND_RETURN_IF_ERROR
  return Status::OK();
}

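// Checks the liveness of a remote task by issuing a blocking, fail-fast
// GetStatus RPC to it. Sets `*is_alive` accordingly; the returned status only
// reflects local failures, such as a missing gRPC server or worker interface.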
Status EagerContextDistributedManager::CheckRemoteAlive(
    const std::string& remote_task_name, bool* is_alive) {
  *is_alive = false;
  GrpcServer* grpc_server = dynamic_cast<GrpcServer*>(context_->GetServer());
  if (grpc_server == nullptr) {
    return errors::Internal("Failed to get eager-compatible server instance.");
  }
  WorkerInterface* wi =
      grpc_server->master_env()->worker_cache->GetOrCreateWorker(
          remote_task_name);
  if (wi == nullptr) {
    return errors::InvalidArgument(
        "Unable to find worker interface corresponding to task ",
        remote_task_name);
  }

  GetStatusRequest request;
  GetStatusResponse response;
  Status remote_status;
  Notification done;
  wi->GetStatusAsync(/*opts_=*/nullptr, &request, &response, /*fail_fast=*/true,
                     [&remote_status, &done](const Status& s) {
                       remote_status = s;
                       done.Notify();
                     });
  done.WaitForNotification();

  if (remote_status.ok()) {
    *is_alive = true;
  } else {
    // Only log when the worker is actually unreachable; a healthy worker
    // should not produce a "not alive" message.
    LOG(INFO) << "Remote worker " << remote_task_name
              << " is not alive: " << remote_status.error_message();
  }
  return Status::OK();
}
#endif  // !IS_MOBILE_PLATFORM
}  // namespace tensorflow