• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/eager/context_distributed_manager.h"
17 
18 #include "tensorflow/core/common_runtime/copy_tensor.h"
19 #include "tensorflow/core/common_runtime/device.h"
20 #include "tensorflow/core/common_runtime/device_mgr.h"
21 #include "tensorflow/core/common_runtime/eager/context.h"
22 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
23 #include "tensorflow/core/framework/device_attributes.pb.h"
24 #include "tensorflow/core/framework/node_def_util.h"
25 #include "tensorflow/core/framework/rendezvous.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/platform/blocking_counter.h"
28 #include "tensorflow/core/platform/casts.h"
29 #include "tensorflow/core/platform/errors.h"
30 #include "tensorflow/core/platform/mutex.h"
31 #include "tensorflow/core/platform/notification.h"
32 #include "tensorflow/core/platform/platform.h"
33 #include "tensorflow/core/platform/refcount.h"
34 #include "tensorflow/core/platform/status.h"
35 #include "tensorflow/core/protobuf/device_filters.pb.h"
36 #include "tensorflow/core/protobuf/error_codes.pb.h"
37 #include "tensorflow/core/util/device_name_utils.h"
38 
39 #if !defined(IS_MOBILE_PLATFORM)
40 #include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
41 #include "tensorflow/core/distributed_runtime/eager/eager_client.h"
42 #include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
43 #include "tensorflow/core/distributed_runtime/remote_device.h"
44 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
45 #include "tensorflow/core/distributed_runtime/server_lib.h"
46 #include "tensorflow/core/distributed_runtime/worker_env.h"
47 #include "tensorflow/core/distributed_runtime/worker_interface.h"
48 #endif  // !IS_MOBILE_PLATFORM
49 
50 namespace tensorflow {
51 #if !defined(IS_MOBILE_PLATFORM)
52 namespace {
AreLocalDevicesCompatible(const tensorflow::EagerContext * context,const tensorflow::ServerDef & server_def)53 bool AreLocalDevicesCompatible(const tensorflow::EagerContext* context,
54                                const tensorflow::ServerDef& server_def) {
55   if (server_def.job_name() != context->HostCPU()->parsed_name().job) {
56     return false;
57   }
58   return server_def.default_session_config().SerializeAsString() ==
59          context->session_options().config.SerializeAsString();
60 }
61 
AddRemoteDevicesToMgr(const std::vector<string> & added_remote_workers,tensorflow::WorkerCacheInterface * worker_cache,tensorflow::DynamicDeviceMgr * remote_device_mgr)62 tensorflow::Status AddRemoteDevicesToMgr(
63     const std::vector<string>& added_remote_workers,
64     tensorflow::WorkerCacheInterface* worker_cache,
65     tensorflow::DynamicDeviceMgr* remote_device_mgr) {
66   std::vector<std::unique_ptr<tensorflow::Device>> remote_devices;
67   tensorflow::mutex remote_devices_mu;
68   int num_added_workers = added_remote_workers.size();
69   tensorflow::BlockingCounter counter(num_added_workers);
70   std::vector<tensorflow::Status> statuses(num_added_workers);
71   for (int i = 0; i < num_added_workers; i++) {
72     tensorflow::NewRemoteDevices(
73         tensorflow::Env::Default(), worker_cache, added_remote_workers[i],
74         [i, &statuses, &counter, &remote_devices, &remote_devices_mu](
75             const tensorflow::Status& s,
76             std::vector<tensorflow::Device*>* devices) {
77           statuses[i] = s;
78           if (s.ok()) {
79             tensorflow::mutex_lock l(remote_devices_mu);
80             for (tensorflow::Device* d : *devices) {
81               remote_devices.emplace_back(d);
82             }
83           }
84           counter.DecrementCount();
85         });
86   }
87   counter.Wait();
88   for (int i = 0; i < num_added_workers; i++) {
89     TF_RETURN_IF_ERROR(statuses[i]);
90   }
91 
92   TF_RETURN_IF_ERROR(remote_device_mgr->AddDevices(std::move(remote_devices)));
93   return tensorflow::Status::OK();
94 }
95 
GetAllRemoteDevices(const std::vector<string> & remote_workers,tensorflow::WorkerCacheInterface * worker_cache,std::unique_ptr<tensorflow::DynamicDeviceMgr> * device_mgr)96 tensorflow::Status GetAllRemoteDevices(
97     const std::vector<string>& remote_workers,
98     tensorflow::WorkerCacheInterface* worker_cache,
99     std::unique_ptr<tensorflow::DynamicDeviceMgr>* device_mgr) {
100   auto remote_device_mgr = std::make_unique<tensorflow::DynamicDeviceMgr>();
101   TF_RETURN_IF_ERROR(AddRemoteDevicesToMgr(remote_workers, worker_cache,
102                                            remote_device_mgr.get()));
103   *device_mgr = std::move(remote_device_mgr);
104   return tensorflow::Status::OK();
105 }
106 
RemoveRemoteDevicesFromMgr(const std::vector<string> & removed_remote_workers,tensorflow::DynamicDeviceMgr * remote_device_mgr)107 tensorflow::Status RemoveRemoteDevicesFromMgr(
108     const std::vector<string>& removed_remote_workers,
109     tensorflow::DynamicDeviceMgr* remote_device_mgr) {
110   const std::vector<tensorflow::Device*> remote_devices =
111       (remote_device_mgr->ListDevices());
112   std::vector<tensorflow::Device*> devices_to_remove;
113   for (tensorflow::Device* d : remote_devices) {
114     for (const string& remote_worker : removed_remote_workers) {
115       if (tensorflow::DeviceNameUtils::IsSameAddressSpace(remote_worker,
116                                                           d->name())) {
117         devices_to_remove.emplace_back(d);
118         break;
119       }
120     }
121   }
122   TF_RETURN_IF_ERROR(remote_device_mgr->RemoveDevices(devices_to_remove));
123   return tensorflow::Status::OK();
124 }
125 
ListRemoteWorkers(tensorflow::ServerInterface * server,const string & local_worker,std::vector<string> * remote_workers)126 tensorflow::Status ListRemoteWorkers(tensorflow::ServerInterface* server,
127                                      const string& local_worker,
128                                      std::vector<string>* remote_workers) {
129   tensorflow::GrpcServer* grpc_server =
130       dynamic_cast<tensorflow::GrpcServer*>(server);
131   if (grpc_server == nullptr) {
132     return tensorflow::errors::Internal(
133         "Currently, TFE_NewContext only supports tensorflow::GrpcServer.");
134   }
135   grpc_server->master_env()->worker_cache->ListWorkers(remote_workers);
136   remote_workers->erase(
137       std::remove(remote_workers->begin(), remote_workers->end(), local_worker),
138       remote_workers->end());
139   return tensorflow::Status::OK();
140 }
141 
// Computes, in a single merge-style traversal of two sorted lists:
//   * `added`:    entries in `new_list` but not in `current_list`
//   * `removed`:  entries in `current_list` but not in `new_list`
//   * `existing`: entries present in both lists
//
// Like std::set_difference / std::set_intersection, both `current_list` and
// `new_list` must already be sorted. Any previous contents of the three output
// vectors are discarded. The output vectors must not alias the input lists.
void DifferentiateWorkerLists(const std::vector<string>* current_list,
                              const std::vector<string>* new_list,
                              std::vector<string>* added,
                              std::vector<string>* removed,
                              std::vector<string>* existing) {
  // Reserve the upper bounds up front and append, instead of resizing to the
  // upper bound (which default-constructs every slot) and truncating later.
  added->clear();
  removed->clear();
  existing->clear();
  added->reserve(new_list->size());
  removed->reserve(current_list->size());
  existing->reserve(std::min(current_list->size(), new_list->size()));

  auto curr_it = current_list->begin();
  auto new_it = new_list->begin();
  const auto curr_end = current_list->end();
  const auto new_end = new_list->end();
  while (curr_it != curr_end && new_it != new_end) {
    if (*curr_it < *new_it) {
      // Only in the current list.
      removed->push_back(*curr_it);
      ++curr_it;
    } else if (*curr_it > *new_it) {
      // Only in the new list.
      added->push_back(*new_it);
      ++new_it;
    } else {
      // In both lists.
      existing->push_back(*curr_it);
      ++curr_it;
      ++new_it;
    }
  }
  // Whatever is left over in either list has no counterpart in the other.
  removed->insert(removed->end(), curr_it, curr_end);
  added->insert(added->end(), new_it, new_end);
}
174 
GetReplacedFromExistingWorkers(const std::vector<string> * existing_workers,tensorflow::uint64 context_id,tensorflow::uint64 context_view_id,const tensorflow::ServerDef & server_def,tensorflow::eager::EagerClientCache * client_cache,std::vector<string> * replaced_workers)175 tensorflow::Status GetReplacedFromExistingWorkers(
176     const std::vector<string>* existing_workers, tensorflow::uint64 context_id,
177     tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
178     tensorflow::eager::EagerClientCache* client_cache,
179     std::vector<string>* replaced_workers) {
180   tensorflow::BlockingCounter counter(existing_workers->size());
181   std::vector<tensorflow::Status> statuses(existing_workers->size());
182   tensorflow::eager::KeepAliveRequest request;
183   request.set_context_id(context_id);
184   std::vector<tensorflow::eager::KeepAliveResponse> responses(
185       existing_workers->size());
186   for (int i = 0; i < existing_workers->size(); i++) {
187     tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
188     statuses[i] =
189         client_cache->GetClient(existing_workers->at(i), &eager_client);
190     if (!statuses[i].ok()) {
191       counter.DecrementCount();
192       continue;
193     }
194     eager_client->KeepAliveAsync(
195         &request, &responses[i],
196         [i, &statuses, &counter](const tensorflow::Status& s) {
197           statuses[i] = s;
198           counter.DecrementCount();
199         });
200   }
201   counter.Wait();
202   for (int i = 0; i < existing_workers->size(); i++) {
203     // If the RPC fails (indicating that the requested ID doesn't exist on
204     // remote), or the returned view ID is not equal to the local one
205     // (indicating that the remote worker has a stale view of cluster), treat
206     // the worker as replaced.
207     if (!statuses[i].ok() ||
208         responses[i].context_view_id() != context_view_id) {
209       replaced_workers->emplace_back(existing_workers->at(i));
210     }
211   }
212   return tensorflow::Status::OK();
213 }
214 
CreateRemoteContexts(EagerContext * context,const std::vector<string> & remote_workers,tensorflow::uint64 context_id,tensorflow::uint64 context_view_id,int keep_alive_secs,const tensorflow::ServerDef & server_def,tensorflow::eager::EagerClientCache * remote_eager_workers,bool async,const tensorflow::eager::CreateContextRequest & base_request)215 tensorflow::Status CreateRemoteContexts(
216     EagerContext* context, const std::vector<string>& remote_workers,
217     tensorflow::uint64 context_id, tensorflow::uint64 context_view_id,
218     int keep_alive_secs, const tensorflow::ServerDef& server_def,
219     tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
220     const tensorflow::eager::CreateContextRequest& base_request) {
221   int num_remote_workers = remote_workers.size();
222   tensorflow::BlockingCounter counter(num_remote_workers);
223   std::vector<tensorflow::Status> statuses(num_remote_workers);
224   for (int i = 0; i < num_remote_workers; i++) {
225     const string& remote_worker = remote_workers[i];
226     tensorflow::DeviceNameUtils::ParsedName parsed_name;
227     if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
228                                                     &parsed_name)) {
229       statuses[i] = tensorflow::errors::InvalidArgument(
230           "Unable to parse ", remote_worker, " as a device name");
231       counter.DecrementCount();
232       continue;
233     }
234 
235     tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
236     statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
237     if (eager_client == nullptr) {
238       statuses[i] = tensorflow::errors::Internal(
239           "Cannot find a client for the given target:", remote_worker);
240     }
241     if (!statuses[i].ok()) {
242       counter.DecrementCount();
243       continue;
244     }
245 
246     tensorflow::eager::CreateContextRequest request;
247     tensorflow::eager::CreateContextResponse* response =
248         new tensorflow::eager::CreateContextResponse();
249     request.set_context_id(context_id);
250     request.set_context_view_id(context_view_id);
251     *request.mutable_server_def() = server_def;
252     request.mutable_server_def()->set_job_name(parsed_name.job);
253     request.mutable_server_def()->set_task_index(parsed_name.task);
254     request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
255         server_def.default_session_config());
256 
257     std::vector<bool> filtered_device_mask;
258     context->FilterDevicesForRemoteWorkers(
259         remote_worker, base_request.cluster_device_attributes(),
260         &filtered_device_mask);
261     DCHECK_EQ(filtered_device_mask.size(),
262               base_request.cluster_device_attributes_size());
263     for (int i = 0; i < filtered_device_mask.size(); i++) {
264       if (filtered_device_mask[i]) {
265         const auto& da = base_request.cluster_device_attributes(i);
266         *request.add_cluster_device_attributes() = da;
267       }
268     }
269     request.set_async(async);
270     request.set_keep_alive_secs(keep_alive_secs);
271     // TODO(b/134094971): deprecate lazy_copy_remote_function_inputs when server
272     // doesn't try to get the value of lazy_copy_remote_function_inputs.
273     request.set_lazy_copy_remote_function_inputs(true);
274 
275     eager_client->CreateContextAsync(
276         &request, response,
277         [i, &statuses, &counter, response](const tensorflow::Status& s) {
278           statuses[i] = s;
279           delete response;
280           counter.DecrementCount();
281         });
282   }
283   counter.Wait();
284   tensorflow::StatusGroup sg;
285   for (int i = 0; i < num_remote_workers; i++) {
286     if (TF_PREDICT_FALSE(!statuses[i].ok())) {
287       sg.Update(statuses[i]);
288     }
289   }
290   return sg.as_summary_status();
291 }
292 
UpdateRemoteContexts(EagerContext * context,const std::vector<string> & remote_workers,const std::vector<string> & added_workers,const std::vector<string> & removed_workers,tensorflow::uint64 context_id,tensorflow::uint64 context_view_id,const tensorflow::ServerDef & server_def,tensorflow::eager::EagerClientCache * remote_eager_workers,const tensorflow::eager::CreateContextRequest & base_request)293 tensorflow::Status UpdateRemoteContexts(
294     EagerContext* context, const std::vector<string>& remote_workers,
295     const std::vector<string>& added_workers,
296     const std::vector<string>& removed_workers, tensorflow::uint64 context_id,
297     tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
298     tensorflow::eager::EagerClientCache* remote_eager_workers,
299     const tensorflow::eager::CreateContextRequest& base_request) {
300   int num_remote_workers = remote_workers.size();
301   tensorflow::BlockingCounter counter(num_remote_workers);
302   std::vector<tensorflow::Status> statuses(num_remote_workers);
303 
304   int cluster_device_count = base_request.cluster_device_attributes_size();
305   std::unordered_set<string> added_or_removed(added_workers.begin(),
306                                               added_workers.end());
307   std::copy(removed_workers.begin(), removed_workers.end(),
308             std::inserter(added_or_removed, added_or_removed.end()));
309   // Whether each device is in the updated (added or removed) workers
310   std::vector<bool> device_added_or_removed(cluster_device_count);
311   for (int i = 0; i < base_request.cluster_device_attributes_size(); i++) {
312     const auto& da = base_request.cluster_device_attributes().at(i);
313     tensorflow::DeviceNameUtils::ParsedName pn;
314     tensorflow::DeviceNameUtils::ParseFullName(da.name(), &pn);
315     string task_name;
316     tensorflow::DeviceNameUtils::GetTaskName(pn, &task_name);
317     if (added_or_removed.find(task_name) != added_or_removed.end()) {
318       device_added_or_removed[i] = true;
319     }
320   }
321 
322   for (int i = 0; i < num_remote_workers; i++) {
323     const string& remote_worker = remote_workers[i];
324     tensorflow::DeviceNameUtils::ParsedName parsed_name;
325     if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
326                                                     &parsed_name)) {
327       statuses[i] = tensorflow::errors::InvalidArgument(
328           "Unable to parse ", remote_worker, " as a device name");
329       counter.DecrementCount();
330       continue;
331     }
332 
333     tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
334     statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
335     if (eager_client == nullptr) {
336       statuses[i] = tensorflow::errors::Internal(
337           "Cannot find a client for the given target:", remote_worker);
338     }
339     if (!statuses[i].ok()) {
340       counter.DecrementCount();
341       continue;
342     }
343 
344     std::vector<bool> filtered_device_mask;
345     context->FilterDevicesForRemoteWorkers(
346         remote_worker, base_request.cluster_device_attributes(),
347         &filtered_device_mask);
348     DCHECK_EQ(filtered_device_mask.size(), cluster_device_count);
349 
350     // If any of the devices that match the device filters are in the set of
351     // added or removed workers, we must send a complete UpdateContextRequest.
352     // Otherwise, only send a simple request to increment context view ID.
353     std::vector<bool> added_or_removed_filtered_devices(cluster_device_count);
354     std::transform(device_added_or_removed.begin(),
355                    device_added_or_removed.end(), filtered_device_mask.begin(),
356                    added_or_removed_filtered_devices.begin(),
357                    std::logical_and<bool>());
358     const bool full_update_request =
359         std::accumulate(added_or_removed_filtered_devices.begin(),
360                         added_or_removed_filtered_devices.end(), false,
361                         std::logical_or<bool>());
362 
363     tensorflow::eager::UpdateContextRequest request;
364     auto* response = new tensorflow::eager::UpdateContextResponse();
365     request.set_context_id(context_id);
366     request.set_context_view_id(context_view_id);
367     if (full_update_request) {
368       *request.mutable_server_def() = server_def;
369       request.mutable_server_def()->set_job_name(parsed_name.job);
370       request.mutable_server_def()->set_task_index(parsed_name.task);
371       request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
372           server_def.default_session_config());
373       for (int i = 0; i < cluster_device_count; i++) {
374         if (filtered_device_mask[i]) {
375           const auto& da = base_request.cluster_device_attributes(i);
376           *request.add_cluster_device_attributes() = da;
377         }
378       }
379     }
380 
381     eager_client->UpdateContextAsync(
382         &request, response,
383         [i, &statuses, &counter, response](const tensorflow::Status& s) {
384           statuses[i] = s;
385           delete response;
386           counter.DecrementCount();
387         });
388   }
389   counter.Wait();
390   for (int i = 0; i < num_remote_workers; i++) {
391     TF_RETURN_IF_ERROR(statuses[i]);
392   }
393   return tensorflow::Status::OK();
394 }
395 
UpdateContextWithServerDef(EagerContext * context,const tensorflow::ServerDef & server_def,bool reset_context,int keep_alive_secs)396 tensorflow::Status UpdateContextWithServerDef(
397     EagerContext* context, const tensorflow::ServerDef& server_def,
398     bool reset_context, int keep_alive_secs) {
399   // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
400   // server object (which currently CHECK-fails) and we miss the error, instead,
401   // we log the error, and then return to allow the user to see the error
402   // message.
403 #define LOG_AND_RETURN_IF_ERROR(...)                    \
404   do {                                                  \
405     const ::tensorflow::Status _status = (__VA_ARGS__); \
406     if (TF_PREDICT_FALSE(!_status.ok())) {              \
407       LOG(ERROR) << _status.error_message();            \
408       return _status;                                   \
409     }                                                   \
410   } while (0);
411 
412   string worker_name =
413       tensorflow::strings::StrCat("/job:", server_def.job_name(),
414                                   "/replica:0/task:", server_def.task_index());
415 
416   // List of current remote workers before updating server_def. Unused if
417   // resetting the server_def.
418   std::vector<string> curr_remote_workers;
419   // List of updated remote workers.
420   std::vector<string> remote_workers;
421 
422   // New server created for new server_def. Unused if updating server_def.
423   std::unique_ptr<tensorflow::ServerInterface> new_server;
424   tensorflow::GrpcServer* grpc_server;
425   if (reset_context) {
426     const tensorflow::DeviceMgr* device_mgr =
427         AreLocalDevicesCompatible(context, server_def)
428             ? context->local_device_mgr()
429             : nullptr;
430     LOG_AND_RETURN_IF_ERROR(tensorflow::NewServerWithOptions(
431         server_def, {device_mgr}, &new_server));
432     grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
433     LOG_AND_RETURN_IF_ERROR(
434         ListRemoteWorkers(new_server.get(), worker_name, &remote_workers));
435   } else {
436     LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name,
437                                               &curr_remote_workers));
438     // No need to check the cast here, since `ListRemoteWorkers` already checks
439     // if the server is a GRPC server or not.
440     grpc_server = dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
441     LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
442     LOG_AND_RETURN_IF_ERROR(
443         ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
444   }
445 
446   tensorflow::uint64 context_id = context->GetContextId();
447   tensorflow::uint64 context_view_id = context->GetContextViewId();
448   if (reset_context) {
449     context_id = tensorflow::EagerContext::NewContextId();
450     context_view_id = 0;
451     // Make master eager context accessible by local eager service, which might
452     // receive send tensor requests from remote workers.
453     LOG_AND_RETURN_IF_ERROR(
454         grpc_server->AddMasterEagerContextToEagerService(context_id, context));
455   }
456 
457   std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
458   LOG_AND_RETURN_IF_ERROR(
459       grpc_server->master_env()->worker_cache->GetEagerClientCache(
460           &remote_eager_workers));
461 
462   // For cluster update, use a status group to aggregate statuses from
463   //   * adding and removing remote devices
464   //   * creating remote contexts on newly added workers
465   //   * updating remote contexts on existing workers
466   //   * updating the master context
467   // Note that we should not return immediately on errors in the middle of these
468   // updates to prevent cluster from having inconsistent context views.
469   //
470   // Unused if `reset_context` is True.
471   tensorflow::StatusGroup sg;
472 
473   // When updating an existing context, populate the following lists with:
474   // * added_workers: set(remote_workers) - set(curr_remote_workers)
475   // * removed_workers: set(curr_remote_workers) - set(remote_workers)
476   // * existing_workers: set(curr_remote_workers) intersect set(remote_workers)
477   // * replaced_workers: workers with the same task names and potentially the
478   //     same `hostname:port`s, but replaced by different processes
479   std::vector<string> added_workers;
480   std::vector<string> removed_workers;
481   std::vector<string> existing_workers;
482   std::vector<string> replaced_workers;
483 
484   // New remote device manager created for new server_def. Unused if updating
485   // server_def.
486   std::unique_ptr<tensorflow::DynamicDeviceMgr> new_remote_device_mgr;
487   tensorflow::DynamicDeviceMgr* remote_device_mgr = nullptr;
488   if (reset_context) {
489     LOG_AND_RETURN_IF_ERROR(GetAllRemoteDevices(
490         remote_workers, grpc_server->master_env()->worker_cache,
491         &new_remote_device_mgr));
492     remote_device_mgr = new_remote_device_mgr.get();
493   } else {
494     context->ClearCachesAndDefaultExecutor();
495     // TODO(b/143914772): Potential memory leak if rendezvous has pending
496     // tensors for removed / replaced workers.
497 
498     remote_device_mgr = context->GetOwnedRemoteDeviceMgr();
499     if (remote_device_mgr == nullptr) {
500       LOG_AND_RETURN_IF_ERROR(tensorflow::errors::InvalidArgument(
501           "Updating context with an invalid set of remote devices."));
502     }
503     std::sort(curr_remote_workers.begin(), curr_remote_workers.end());
504     std::sort(remote_workers.begin(), remote_workers.end());
505     DifferentiateWorkerLists(&curr_remote_workers, &remote_workers,
506                              &added_workers, &removed_workers,
507                              &existing_workers);
508     sg.Update(GetReplacedFromExistingWorkers(
509         &existing_workers, context_id, context->GetContextViewId(), server_def,
510         remote_eager_workers.get(), &replaced_workers));
511     if (VLOG_IS_ON(1)) {
512       VLOG(1) << "Updating cluster with following changes";
513       for (const string& w : added_workers) VLOG(1) << "  Added worker " << w;
514       for (const string& w : removed_workers)
515         VLOG(1) << "  Removed worker " << w;
516       for (const string& w : replaced_workers)
517         VLOG(1) << "  Replaced worker " << w;
518     }
519     if (!replaced_workers.empty()) {
520       // Treat replaced workers as removed then added back, so that we recreate
521       // remote devices and contexts, and re-register functions on those workers
522       removed_workers.insert(removed_workers.end(), replaced_workers.begin(),
523                              replaced_workers.end());
524       added_workers.insert(added_workers.end(), replaced_workers.begin(),
525                            replaced_workers.end());
526       for (const string& w : replaced_workers) {
527         existing_workers.erase(
528             std::remove(existing_workers.begin(), existing_workers.end(), w),
529             existing_workers.end());
530       }
531     }
532     sg.Update(RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr));
533     sg.Update(AddRemoteDevicesToMgr(added_workers,
534                                     grpc_server->master_env()->worker_cache,
535                                     remote_device_mgr));
536   }
537 
538   std::vector<tensorflow::DeviceAttributes> cluster_device_attributes;
539   remote_device_mgr->ListDeviceAttributes(&cluster_device_attributes);
540 
541   std::vector<tensorflow::DeviceAttributes> local_device_attributes;
542   grpc_server->worker_env()->device_mgr->ListDeviceAttributes(
543       &local_device_attributes);
544 
545   // This request make sure that we can create Rendezvous properly between
546   // Local and Remote context.
547   tensorflow::eager::CreateContextRequest base_request;
548   for (const auto& da : cluster_device_attributes) {
549     *base_request.add_cluster_device_attributes() = da;
550   }
551   for (const auto& da : local_device_attributes) {
552     *base_request.add_cluster_device_attributes() = da;
553   }
554 
555   // Initialize remote eager workers.
556   if (reset_context) {
557     const tensorflow::Status s = CreateRemoteContexts(
558         context, remote_workers, context_id, context_view_id, keep_alive_secs,
559         server_def, remote_eager_workers.get(), context->Executor().Async(),
560         base_request);
561     // NOTE: the remote tasks could fail after `GetAllRemoteDevices` and cause
562     // the CreateRemoteContexts to fail. We currently only log instead of
563     // directly returning the error, since returning here will cause the server
564     // object to be destroyed (which currently CHECK-fails). The client will
565     // see additional errors if ops are subsequently sent to the failed workers.
566     if (TF_PREDICT_FALSE(!s.ok())) {
567       LOG(ERROR) << "Error when creating contexts on remote targets: "
568                  << s.error_message()
569                  << "\nExecuting remote ops or functions on these remote "
570                     "targets will fail.";
571     }
572   } else {
573     if (sg.ok()) {
574       // Create remote contexts on the newly added workers only if the master
575       // has collected all device information from them (i.e., the
576       // GetAllRemoteDevices call returns succussfully). Note that in rare cases
577       // GetAllRemoteDevices can still fail even with RPCs configured to wait
578       // until the remote workers to become alive. If the master creates remote
579       // contexts on the workers whose devices are still not collected, those
580       // workers will be treated as existing workers subsequently, so the master
581       // will never get devices from them even with retrying UpdateServerDef.
582       sg.Update(CreateRemoteContexts(
583           context, added_workers, context_id, context_view_id + 1,
584           keep_alive_secs, server_def, remote_eager_workers.get(),
585           context->Executor().Async(), base_request));
586     }
587     if (!existing_workers.empty()) {
588       if (VLOG_IS_ON(1)) {
589         for (const string& w : existing_workers) {
590           VLOG(1) << "Updating cluster with existing worker " << w;
591         }
592       }
593       // The master's context_view_id will be incremented by one in the
594       // UpdateRemoteMaster call later. We want existing workers to also have
595       // the updated context_view_id, so we must set their context_view_id to
596       // the master's current context_view_id + 1.
597       sg.Update(UpdateRemoteContexts(context, existing_workers, added_workers,
598                                      removed_workers, context_id,
599                                      context_view_id + 1, server_def,
600                                      remote_eager_workers.get(), base_request));
601     }
602   }
603 
604   auto session_name = tensorflow::strings::StrCat("eager_", context_id);
605   if (reset_context) {
606     tensorflow::RemoteRendezvous* r =
607         grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
608     auto* device_mgr = grpc_server->worker_env()->device_mgr;
609     std::shared_ptr<tensorflow::WorkerSession> worker_session;
610     LOG_AND_RETURN_IF_ERROR(
611         grpc_server->worker_env()->session_mgr->CreateSession(
612             session_name, server_def, base_request.cluster_device_attributes(),
613             true));
614     LOG_AND_RETURN_IF_ERROR(
615         grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
616             session_name, &worker_session));
617 
618     // Initialize remote tensor communication based on worker session.
619     LOG_AND_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
620 
621     tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
622         tensorflow::eager::CreateClusterFLR(context_id, context,
623                                             worker_session.get());
624     auto remote_mgr = std::make_unique<tensorflow::eager::RemoteMgr>(
625         /*is_master=*/true, context);
626 
627     LOG_AND_RETURN_IF_ERROR(context->InitializeRemoteMaster(
628         std::move(new_server), grpc_server->worker_env(), worker_session,
629         std::move(remote_eager_workers), std::move(new_remote_device_mgr),
630         remote_workers, context_id, r, device_mgr, keep_alive_secs, cluster_flr,
631         std::move(remote_mgr)));
632 
633     // NOTE: We start the server after all other initialization, because the
634     // GrpcServer cannot be destroyed after it is started.
635     LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
636   } else {
637     sg.Update(grpc_server->worker_env()->session_mgr->UpdateSession(
638         session_name, server_def, base_request.cluster_device_attributes(),
639         /*isolate_session_state=*/true));
640     sg.Update(context->UpdateRemoteMaster(context_id,
641                                           std::move(remote_eager_workers),
642                                           added_workers, removed_workers));
643     LOG_AND_RETURN_IF_ERROR(sg.as_summary_status());
644   }
645 #undef LOG_AND_RETURN_IF_ERROR
646 
647   return tensorflow::Status::OK();
648 }
649 }  // namespace
650 
SetOrUpdateServerDef(const ServerDef & server_def,bool reset_context,int keep_alive_secs)651 Status EagerContextDistributedManager::SetOrUpdateServerDef(
652     const ServerDef& server_def, bool reset_context, int keep_alive_secs) {
653   if (server_def.has_cluster_device_filters()) {
654     if (reset_context) {
655       const auto& cdf = server_def.cluster_device_filters();
656       for (const auto& jdf : cdf.jobs()) {
657         const string remote_prefix = "/job:" + jdf.name() + "/task:";
658         for (const auto& tdf : jdf.tasks()) {
659           const int32_t task_index = tdf.first;
660           std::vector<string> device_filters(tdf.second.device_filters_size());
661           for (int i = 0; i < tdf.second.device_filters_size(); i++) {
662             device_filters[i] = tdf.second.device_filters(i);
663           }
664           const string remote_worker =
665               strings::StrCat(remote_prefix, task_index);
666           TF_RETURN_IF_ERROR(
667               context_->SetRemoteDeviceFilters(remote_worker, device_filters));
668         }
669       }
670     } else {
671       LOG(WARNING) << "Device filters can only be specified when initializing "
672                       "the cluster. Any changes in device filters are ignored "
673                       "when updating the server def.";
674     }
675   }
676   return UpdateContextWithServerDef(context_, server_def, reset_context,
677                                     keep_alive_secs);
678 }
679 
EnableCollectiveOps(const ServerDef & server_def)680 Status EagerContextDistributedManager::EnableCollectiveOps(
681     const ServerDef& server_def) {
682   // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
683   // server object (which currently CHECK-fails) and we miss the error, instead,
684   // we log the error, and then return to allow the user to see the error
685   // message.
686 #define LOG_AND_RETURN_IF_ERROR(...)                    \
687   do {                                                  \
688     const ::tensorflow::Status _status = (__VA_ARGS__); \
689     if (TF_PREDICT_FALSE(!_status.ok())) {              \
690       LOG(ERROR) << _status.error_message();            \
691       return _status;                                   \
692     }                                                   \
693   } while (0);
694 
695   tensorflow::GrpcServer* grpc_server =
696       dynamic_cast<tensorflow::GrpcServer*>(context_->GetServer());
697   if (grpc_server == nullptr) {
698     std::unique_ptr<tensorflow::ServerInterface> new_server;
699     LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
700     grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
701     if (grpc_server == nullptr) {
702       LOG_AND_RETURN_IF_ERROR(tensorflow::errors::Internal(
703           "Currently, TF eager runtime only supports tensorflow::GrpcServer."));
704     }
705     LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
706 
707     LOG_AND_RETURN_IF_ERROR(context_->StoreCollectiveOpsServer(
708         std::move(new_server), grpc_server->worker_env()->device_mgr,
709         grpc_server->worker_env()->collective_executor_mgr.get()));
710   } else {
711     LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
712     LOG_AND_RETURN_IF_ERROR(context_->StoreCollectiveOpsServer(
713         /*new_server=*/nullptr, grpc_server->worker_env()->device_mgr,
714         grpc_server->worker_env()->collective_executor_mgr.get()));
715   }
716 #undef LOG_AND_RETURN_IF_ERROR
717   return Status::OK();
718 }
719 
CheckRemoteAlive(const std::string & remote_task_name,bool * is_alive)720 Status EagerContextDistributedManager::CheckRemoteAlive(
721     const std::string& remote_task_name, bool* is_alive) {
722   *is_alive = false;
723   GrpcServer* grpc_server = dynamic_cast<GrpcServer*>(context_->GetServer());
724   if (grpc_server == nullptr) {
725     return errors::Internal("Failed to get eager-compatible server instance.");
726   }
727   WorkerInterface* wi =
728       grpc_server->master_env()->worker_cache->GetOrCreateWorker(
729           remote_task_name);
730   if (wi == nullptr) {
731     return errors::InvalidArgument(
732         "Unable to find worker interface corresponding to task ",
733         remote_task_name);
734   }
735 
736   GetStatusRequest request;
737   GetStatusResponse response;
738   Status remote_status;
739   Notification done;
740   wi->GetStatusAsync(/*opts_=*/nullptr, &request, &response, /*fail_fast=*/true,
741                      [&remote_status, &done](const Status& s) {
742                        remote_status = s;
743                        done.Notify();
744                      });
745   done.WaitForNotification();
746 
747   if (remote_status.ok()) {
748     *is_alive = true;
749   }
750   LOG(INFO) << "Remote worker " << remote_task_name
751             << " is not alive: " << remote_status.error_message();
752   return Status::OK();
753 }
754 #endif  // !IS_MOBILE_PLATFORM
755 }  // namespace tensorflow
756