• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Master implements the service MasterService.
17 //
18 // A Master maintains the state of live graph computation
19 // sessions, each session orchestrates both local and remote devices
20 // to carry out the graph computation.
21 //
22 // A Master knows ahead of time local devices available as
23 // client devices.
24 //
25 // A Master discovers remote devices on-demand and keeps track of
26 // statistics of those remote devices.
27 //
28 // Each session analyzes the graph, places nodes across available
29 // devices, and ultimately drives the graph computation by initiating
30 // RunGraph on the workers.
31 
32 #include "tensorflow/core/distributed_runtime/master.h"
33 
34 #include <unordered_set>
35 #include <vector>
36 
37 #include "tensorflow/core/common_runtime/device_set.h"
38 #include "tensorflow/core/common_runtime/process_util.h"
39 #include "tensorflow/core/distributed_runtime/remote_device.h"
40 #include "tensorflow/core/distributed_runtime/worker_cache.h"
41 #include "tensorflow/core/distributed_runtime/worker_interface.h"
42 #include "tensorflow/core/framework/graph_def_util.h"
43 #include "tensorflow/core/lib/core/errors.h"
44 #include "tensorflow/core/lib/core/notification.h"
45 #include "tensorflow/core/lib/gtl/array_slice.h"
46 #include "tensorflow/core/lib/gtl/cleanup.h"
47 #include "tensorflow/core/lib/gtl/map_util.h"
48 #include "tensorflow/core/lib/strings/str_util.h"
49 #include "tensorflow/core/platform/macros.h"
50 #include "tensorflow/core/platform/mutex.h"
51 #include "tensorflow/core/platform/types.h"
52 #include "tensorflow/core/protobuf/cluster.pb.h"
53 #include "tensorflow/core/protobuf/master.pb.h"
54 #include "tensorflow/core/protobuf/worker.pb.h"
55 #include "tensorflow/core/public/session_options.h"
56 #include "tensorflow/core/util/device_name_utils.h"
57 
58 namespace tensorflow {
59 
namespace {
// URI scheme that may prefix a session target. It is stripped before the
// target is compared with the task addresses of a client-supplied ClusterDef.
constexpr char kGrpcProtocol[] = "grpc://";
}  // namespace
63 
Master(MasterEnv * env,double session_gc_seconds)64 Master::Master(MasterEnv* env, double session_gc_seconds)
65     : env_(env),
66       last_1000_steps_(1000),
67       step_count_(0),
68       session_gc_seconds_(session_gc_seconds),
69       recent_request_ids_(10000) {
70   // Right now, a master service must be co-located with a device.
71   // Otherwise, fetches do not work.
72   CHECK(!env->local_devices.empty());
73 
74   if (session_gc_seconds_ > 0.0) {
75     gc_thread_ = env_->env->StartThread(ThreadOptions(), "TF_master_GC",
76                                         [this]() { GC(); });
77   } else {
78     gc_thread_ = nullptr;
79   }
80 }
81 
~Master()82 Master::~Master() {
83   if (gc_thread_) {
84     mutex_lock l(mu_);
85     shutdown_ = true;
86     shutdown_cv_.notify_all();
87     delete gc_thread_;
88   }
89 }
90 
GC()91 void Master::GC() {
92   Env* env = Env::Default();
93   while (true) {
94     mutex_lock l(mu_);
95     const int kTimeoutMilliseconds = 10 * 1000;  // 10 seconds.
96     WaitForMilliseconds(&l, &shutdown_cv_, kTimeoutMilliseconds);
97     if (shutdown_) {
98       break;
99     }
100     std::vector<string> handles;
101     const int64 num_micros = static_cast<int64>(session_gc_seconds_ * 1000000);
102     for (const auto& entry : sessions_) {
103       int64 lat = entry.second->last_access_time_usec();
104       if (static_cast<int64>(env->NowMicros()) - lat > num_micros) {
105         handles.push_back(entry.first);
106         auto* sess = entry.second;
107         SchedClosure([this, sess]() {
108           LOG(WARNING) << "GC session " << sess->handle() << " after "
109                        << session_gc_seconds_ << " seconds.  "
110                        << "Note that if you are starting multiple replicas "
111                        << "on a staggered delay, session_gc_seconds may need "
112                        << "to be raised.";
113           sess->GarbageCollect();
114         });
115       }
116     }
117     for (const auto& handle : handles) sessions_.erase(handle);
118   }
119 }
120 
FindMasterSession(const string & handle)121 MasterSession* Master::FindMasterSession(const string& handle) {
122   MasterSession* session = nullptr;
123   {
124     mutex_lock l(mu_);
125     session = gtl::FindPtrOrNull(sessions_, handle);
126     if (session != nullptr) {
127       session->Ref();
128     }
129   }
130   return session;
131 }
132 
133 class DeviceFinder {
134  public:
GetRemoteDevices(const protobuf::RepeatedPtrField<string> & device_filters,MasterEnv * env,WorkerCacheInterface * worker_cache,std::vector<std::unique_ptr<Device>> * out_remote)135   static Status GetRemoteDevices(
136       const protobuf::RepeatedPtrField<string>& device_filters, MasterEnv* env,
137       WorkerCacheInterface* worker_cache,
138       std::vector<std::unique_ptr<Device>>* out_remote) {
139     DeviceFinder finder(device_filters, env, worker_cache);
140     finder.Start();
141     TF_RETURN_IF_ERROR(finder.Wait());
142     finder.GetRemoteDevices(env->local_devices, out_remote);
143     return Status::OK();
144   }
145 
GetRemoteWorkers(const protobuf::RepeatedPtrField<string> & device_filters,MasterEnv * env,WorkerCacheInterface * worker_cache,std::vector<string> * workers)146   static void GetRemoteWorkers(
147       const protobuf::RepeatedPtrField<string>& device_filters, MasterEnv* env,
148       WorkerCacheInterface* worker_cache, std::vector<string>* workers) {
149     DeviceFinder finder(device_filters, env, worker_cache);
150     *workers = finder.targets_;
151   }
152 
153  private:
DeviceFinder(const protobuf::RepeatedPtrField<string> & device_filters,MasterEnv * env,WorkerCacheInterface * worker_cache)154   explicit DeviceFinder(
155       const protobuf::RepeatedPtrField<string>& device_filters, MasterEnv* env,
156       WorkerCacheInterface* worker_cache)
157       : env_(env), worker_cache_(worker_cache) {
158     CHECK(worker_cache) << "Worker cache was null!";
159     auto process_filter = [this](const string& filter) {
160       DeviceNameUtils::ParsedName parsed;
161       if (DeviceNameUtils::ParseFullName(filter, &parsed)) {
162         filters_.push_back(parsed);
163       } else {
164         LOG(FATAL) << "Skipping invalid filter: " << filter;
165       }
166     };
167     for (const string& filter : device_filters) {
168       process_filter(filter);
169     }
170     // Enumerates all known workers' target. A target name is a
171     // prefix of a device name. E.g., /job:mnist/replica:0/task:10.
172     if (filters_.empty()) {
173       // If no filters were specified, we list all known workers in
174       // `worker_cache`.
175       std::vector<string> workers;
176       worker_cache->ListWorkers(&workers);
177       std::swap(workers, targets_);
178     } else {
179       // When applying filters, we must include the local worker, even if it
180       // does not match any of the filters.
181       CHECK_GT(env_->local_devices.size(), 0) << "No local devices provided.";
182       const string& local_device_name = env_->local_devices[0]->name();
183       DeviceNameUtils::ParsedName local_parsed_name;
184       CHECK(DeviceNameUtils::ParseFullName(local_device_name,
185                                            &local_parsed_name));
186       bool all_filters_have_job = true;
187       std::unordered_set<string> filter_job_names({local_parsed_name.job});
188       for (const DeviceNameUtils::ParsedName& filter : filters_) {
189         all_filters_have_job = all_filters_have_job && filter.has_job;
190         if (filter.has_job) {
191           filter_job_names.insert(filter.job);
192         }
193       }
194 
195       std::vector<string> workers;
196       if (all_filters_have_job) {
197         // If all of the device filters have a job specified, then we only need
198         // to list the workers in the jobs named in the filter, because a worker
199         // in any other job would not match any filter.
200         for (const string& job_name : filter_job_names) {
201           VLOG(2) << "Selectively listing workers in job: " << job_name;
202           std::vector<string> workers_in_job;
203           worker_cache->ListWorkersInJob(job_name, &workers_in_job);
204           workers.insert(workers.end(), workers_in_job.begin(),
205                          workers_in_job.end());
206         }
207       } else {
208         // If any of the device filters does not have a job specified, then we
209         // must list the workers from all jobs.
210         VLOG(2) << "Listing workers in all jobs because some device "
211                 << "filter has no job specified. Filters were:";
212         if (device_filters.empty()) {
213           VLOG(2) << "- <NO FILTERS>";
214         } else {
215           for (const string& filter : device_filters) {
216             VLOG(2) << "- " << filter;
217           }
218         }
219         worker_cache->ListWorkers(&workers);
220       }
221       for (const string& name : workers) {
222         if (MatchFilters(name) ||
223             DeviceNameUtils::IsSameAddressSpace(name, local_device_name)) {
224           targets_.push_back(name);
225         }
226       }
227     }
228     seen_targets_.assign(targets_.size(), false);
229   }
230 
~DeviceFinder()231   ~DeviceFinder() {
232     for (Device* dev : found_) delete dev;
233   }
234 
Start()235   void Start() {
236     {
237       mutex_lock l(mu_);
238       num_pending_ = targets_.size();
239       if (num_pending_ == 0) {
240         pending_zero_.notify_all();
241       }
242     }
243     // Talk to all workers to get the list of available devices.
244     using std::placeholders::_1;
245     using std::placeholders::_2;
246     for (size_t i = 0; i < targets_.size(); ++i) {
247       // TODO(mrry): Propagate a timeout here, since `this->WhenFound()` may
248       // never be called.
249       NewRemoteDevices(env_->env, worker_cache_, targets_[i],
250                        std::bind(&ME::WhenFound, this, i, _1, _2));
251     }
252   }
253 
254   // Every `kLoggingPeriodMs`, while the DeviceFinder is still waiting
255   // to hear from workers, log a list of the workers who have not
256   // responded.
257   const int32 kLoggingPeriodMs = 10 * 1000;
258 
Wait()259   Status Wait() {
260     mutex_lock l(mu_);
261     // TODO(mrry): Propagate a timeout here, since `num_pending_` may
262     // never become zero.
263     while (num_pending_ != 0) {
264       pending_zero_.wait_for(l, std::chrono::milliseconds(kLoggingPeriodMs));
265       if (num_pending_ != 0) {
266         for (size_t i = 0; i < targets_.size(); ++i) {
267           if (!seen_targets_[i]) {
268             LOG(INFO)
269                 << "CreateSession still waiting for response from worker: "
270                 << targets_[i];
271           }
272         }
273       }
274     }
275     return status_;
276   }
277 
278   // The caller takes the ownership of returned remote devices.
GetRemoteDevices(const std::vector<Device * > & local,std::vector<std::unique_ptr<Device>> * remote)279   void GetRemoteDevices(const std::vector<Device*>& local,
280                         std::vector<std::unique_ptr<Device>>* remote) {
281     std::unordered_set<string> names(local.size());
282     for (Device* dev : local) names.insert(dev->name());
283     mutex_lock l(mu_);
284     for (Device* dev : found_) {
285       const string& name = dev->name();
286       if (names.insert(name).second && MatchFilters(name)) {
287         remote->push_back(std::unique_ptr<Device>(dev));
288       } else {
289         delete dev;
290       }
291     }
292     found_.clear();
293   }
294 
295   typedef DeviceFinder ME;
296   const MasterEnv* env_;
297   WorkerCacheInterface* worker_cache_;
298   std::vector<DeviceNameUtils::ParsedName> filters_;
299 
300   mutex mu_;
301   int num_pending_ GUARDED_BY(mu_);
302   condition_variable pending_zero_;
303   std::vector<Device*> found_ GUARDED_BY(mu_);
304   // List of targets to be contacted by this DeviceFinder. The
305   // respective `bool` in `seen_targets_` indicates whether we have
306   // heard from this target or not.
307   std::vector<string> targets_;
308   std::vector<bool> seen_targets_ GUARDED_BY(mu_);
309   Status status_;
310 
WhenFound(int target_index,const Status & s,std::vector<Device * > * devices)311   void WhenFound(int target_index, const Status& s,
312                  std::vector<Device*>* devices) {
313     mutex_lock l(mu_);
314     seen_targets_[target_index] = true;
315     if (!s.ok()) {
316       LOG(ERROR) << "CreateSession failed because worker "
317                  << targets_[target_index] << " returned error: " << s;
318       status_.Update(s);
319     } else {
320       found_.insert(found_.end(), devices->begin(), devices->end());
321       devices->clear();
322     }
323     --num_pending_;
324     if (num_pending_ == 0) {
325       pending_zero_.notify_all();
326     }
327   }
328 
329   // Returns true iff the set of devices allowed by 'x' intersects
330   // with the set of devices allowed by 'y'.
Intersects(const DeviceNameUtils::ParsedName & x,const DeviceNameUtils::ParsedName & y)331   bool Intersects(const DeviceNameUtils::ParsedName& x,
332                   const DeviceNameUtils::ParsedName& y) {
333     return (!x.has_job || !y.has_job || x.job == y.job) &&
334            (!x.has_replica || !y.has_replica || x.replica == y.replica) &&
335            (!x.has_task || !y.has_task || x.task == y.task) &&
336            (!x.has_type || !y.has_type || x.type == y.type) &&
337            (!x.has_id || !y.has_id || x.id == y.id);
338   }
339 
340   // Returns true iff 'name' matches one of the filters_.
MatchFilters(const string & name)341   bool MatchFilters(const string& name) {
342     if (filters_.empty()) return true;
343     DeviceNameUtils::ParsedName x;
344     if (DeviceNameUtils::ParseFullName(name, &x)) {
345       for (const auto& filter : filters_) {
346         if (Intersects(x, filter)) return true;
347       }
348     }
349     return false;
350   }
351 
352   TF_DISALLOW_COPY_AND_ASSIGN(DeviceFinder);
353 };
354 
CreateSession(const CreateSessionRequest * req,CreateSessionResponse * resp,MyClosure done)355 void Master::CreateSession(const CreateSessionRequest* req,
356                            CreateSessionResponse* resp, MyClosure done) {
357   SchedClosure([this, req, resp, done]() {
358     Status status;
359     WorkerCacheFactoryOptions worker_cache_factory_options;
360     string grpc_protocol("grpc");
361     worker_cache_factory_options.protocol = &grpc_protocol;
362     auto call_done = gtl::MakeCleanup([&status, &done] { done(status); });
363     status = ValidateExternalGraphDefSyntax(req->graph_def());
364     if (!status.ok()) return;
365 
366     // The following 4 variables are set differently, depending on whether this
367     // session uses a client-provided clusterspec or not.
368     WorkerCacheInterface* worker_cache = nullptr;
369     // Note: worker_cache_ptr will be null except if this session is using a
370     // client-supplied ClusterDef (ClusterSpec propagation).
371     std::unique_ptr<WorkerCacheInterface> worker_cache_ptr;
372     std::unique_ptr<DeviceSet> device_set;
373     // TODO(saeta): Convert to std::make_unique when available.
374     std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devices(
375         new std::vector<std::unique_ptr<Device>>());
376 
377     if (req->config().has_cluster_def()) {
378       worker_cache_factory_options.cluster_def = &req->config().cluster_def();
379 
380       // Set the server_def's job_name and task_index fields.
381       string normalized_string;
382       string grpc_protocol(kGrpcProtocol);
383       if (req->target().compare(0, grpc_protocol.length(), grpc_protocol) ==
384           0) {
385         normalized_string =
386             req->target().substr(grpc_protocol.length(), string::npos);
387       } else {
388         normalized_string = req->target();
389       }
390       for (auto&& job : req->config().cluster_def().job()) {
391         for (auto&& task : job.tasks()) {
392           if (task.second == normalized_string) {
393             if (worker_cache_factory_options.job_name != nullptr) {
394               status = errors::InvalidArgument(
395                   "Found multiple matching tasks that correspond to "
396                   "to the master. Master target: '",
397                   req->target(), "'. ClusterDef: ",
398                   req->config().cluster_def().ShortDebugString());
399               LOG(ERROR) << status;
400               return;
401             }
402             if (env_->local_devices[0]->parsed_name().job == job.name() &&
403                 env_->local_devices[0]->parsed_name().task == task.first) {
404               // TODO(b/37868888): Remove this limitation when resolved
405               status = errors::InvalidArgument(
406                   "The ClusterSpec names the job and task index to be the same "
407                   "names that were provided when the server booted. This is "
408                   "currently not allowed. Job: ",
409                   job.name(), ", task index: ", task.first);
410               return;
411             }
412             worker_cache_factory_options.job_name = &job.name();
413             worker_cache_factory_options.task_index = task.first;
414           }
415         }
416       }
417 
418       // Create the worker cache from the computed server_def.
419       status = env_->worker_cache_factory(worker_cache_factory_options,
420                                           &worker_cache);
421       if (!status.ok()) return;
422       worker_cache_ptr = std::unique_ptr<WorkerCacheInterface>(worker_cache);
423       // Ping all the workers and build the list of devices that the
424       // session will use.
425       status =
426           DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_,
427                                          worker_cache, remote_devices.get());
428       if (!status.ok()) return;
429       device_set.reset(new DeviceSet);
430       for (auto&& d : *remote_devices) {
431         device_set->AddDevice(d.get());
432         DeviceNameUtils::ParsedName name = d->parsed_name();
433         if (name.job == *worker_cache_factory_options.job_name &&
434             name.task == worker_cache_factory_options.task_index &&
435             name.type == "CPU" && name.id == 0) {
436           device_set->set_client_device(d.get());
437         }
438       }
439     } else {
440       worker_cache = env_->worker_cache;
441       // Ping all the workers and build the list of devices that the
442       // session will use.
443       status =
444           DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_,
445                                          worker_cache, remote_devices.get());
446       if (!status.ok()) return;
447       device_set.reset(new DeviceSet);
448       for (auto&& d : *remote_devices) {
449         device_set->AddDevice(d.get());
450       }
451       int num_local_devices = 0;
452       for (Device* d : env_->local_devices) {
453         device_set->AddDevice(d);
454         if (num_local_devices == 0) {
455           // Uses the first local device as the client device.
456           device_set->set_client_device(d);
457         }
458         num_local_devices++;
459       }
460     }
461 
462     CHECK(device_set->client_device()) << "No client device found. Missing "
463                                        << "CPU:0 device?";
464 
465     SessionOptions options;
466     options.config = req->config();
467 
468     std::vector<string> filtered_worker_list;
469     DeviceFinder::GetRemoteWorkers(req->config().device_filters(), env_,
470                                    worker_cache, &filtered_worker_list);
471 
472     MasterSession* session = env_->master_session_factory(
473         options, env_, std::move(remote_devices), std::move(worker_cache_ptr),
474         std::move(device_set), std::move(filtered_worker_list));
475 
476     GraphDef* gdef =
477         const_cast<CreateSessionRequest*>(req)->mutable_graph_def();
478 
479     status = session->Create(gdef, worker_cache_factory_options);
480     if (!status.ok()) {
481       session->Close().IgnoreError();
482       session->Unref();
483       return;
484     }
485     resp->set_session_handle(session->handle());
486     // Insert into the session map, which takes ownership of the session.
487     {
488       mutex_lock l(mu_);
489       CHECK(sessions_.insert({session->handle(), session}).second);
490     }
491   });
492 }
493 
ExtendSession(const ExtendSessionRequest * req,ExtendSessionResponse * resp,MyClosure done)494 void Master::ExtendSession(const ExtendSessionRequest* req,
495                            ExtendSessionResponse* resp, MyClosure done) {
496   auto session = FindMasterSession(req->session_handle());
497   if (session == nullptr) {
498     done(errors::Aborted("Session ", req->session_handle(), " is not found."));
499     return;
500   }
501 
502   SchedClosure([session, req, resp, done]() {
503     Status status = ValidateExternalGraphDefSyntax(req->graph_def());
504     if (status.ok()) {
505       status = session->Extend(req, resp);
506     }
507     session->Unref();
508     done(status);
509   });
510 }
511 
PartialRunSetup(const PartialRunSetupRequest * req,PartialRunSetupResponse * resp,MyClosure done)512 void Master::PartialRunSetup(const PartialRunSetupRequest* req,
513                              PartialRunSetupResponse* resp, MyClosure done) {
514   Status s = recent_request_ids_.TrackUnique(req->request_id(),
515                                              "PartialRunSetup (Master)", *req);
516   if (!s.ok()) {
517     done(s);
518     return;
519   }
520   auto session = FindMasterSession(req->session_handle());
521   if (session == nullptr) {
522     done(errors::Aborted("Session ", req->session_handle(), " is not found."));
523     return;
524   }
525 
526   SchedClosure([session, req, resp, done]() {
527     Status s = session->PartialRunSetup(req, resp);
528     session->Unref();
529     done(s);
530   });
531 }
532 
RunStep(CallOptions * opts,const RunStepRequestWrapper * req,MutableRunStepResponseWrapper * resp,MyClosure done)533 void Master::RunStep(CallOptions* opts, const RunStepRequestWrapper* req,
534                      MutableRunStepResponseWrapper* resp, MyClosure done) {
535   Status s = recent_request_ids_.TrackUnique(req->request_id(),
536                                              "RunStep (Master)", req);
537   if (!s.ok()) {
538     done(s);
539     return;
540   }
541   auto start_time = env_->env->NowMicros();
542   auto session = FindMasterSession(req->session_handle());
543   if (session == nullptr) {
544     done(errors::Aborted("Session ", req->session_handle(), " is not found."));
545     return;
546   }
547 
548   SchedClosure([this, start_time, session, opts, req, resp, done]() {
549     Status status = session->Run(opts, *req, resp);
550     session->Unref();
551     uint64 done_time = env_->env->NowMicros();
552     done(status);
553     mutex_lock l(mu_);
554     last_1000_steps_.AddValue((done_time - start_time) / 1e9);
555     ++step_count_;
556   });
557 }
558 
CloseSession(const CloseSessionRequest * req,CloseSessionResponse * resp,MyClosure done)559 void Master::CloseSession(const CloseSessionRequest* req,
560                           CloseSessionResponse* resp, MyClosure done) {
561   MasterSession* session = nullptr;
562   {
563     mu_.lock();
564     auto iter = sessions_.find(req->session_handle());
565     if (iter == sessions_.end()) {
566       mu_.unlock();
567       done(errors::Aborted(
568           "Session ", req->session_handle(),
569           " is not found. Possibly, this master has restarted."));
570       return;
571     }
572     // NOTE(mrry): One reference to the session is transferred from
573     // `sessions_[req->session_handle()]` to `session`.
574     session = iter->second;
575     sessions_.erase(iter);
576     mu_.unlock();
577   }
578 
579   // Session Close() blocks on thread shutdown. Therefore, we need to
580   // delete it in non-critical thread.
581   SchedClosure([session, done]() {
582     Status s = session->Close();
583     session->Unref();
584     done(s);
585   });
586 }
587 
ListDevices(const ListDevicesRequest * req,ListDevicesResponse * resp,MyClosure done)588 void Master::ListDevices(const ListDevicesRequest* req,
589                          ListDevicesResponse* resp, MyClosure done) {
590   SchedClosure([this, req, resp, done]() {
591     if (!req->session_handle().empty()) {
592       auto session = FindMasterSession(req->session_handle());
593       if (session == nullptr) {
594         done(errors::InvalidArgument(
595             "Session ", req->session_handle(),
596             " is not found. Possibly, this master has restarted."));
597         return;
598       }
599       core::ScopedUnref ref(session);
600       Status s = session->ListDevices(resp);
601       done(s);
602       return;
603     }
604     std::vector<std::unique_ptr<Device>> remote_devices;
605     Status s = DeviceFinder::GetRemoteDevices({}, env_, env_->worker_cache,
606                                               &remote_devices);
607     if (s.ok()) {
608       for (Device* dev : env_->local_devices) {
609         *(resp->add_local_device()) = dev->attributes();
610       }
611       for (auto&& dev : remote_devices) {
612         *(resp->add_remote_device()) = dev->attributes();
613       }
614     }
615     done(s);
616   });
617 }
618 
CleanupWorkers(const ResetRequest & reset)619 void Master::CleanupWorkers(const ResetRequest& reset) {
620   std::vector<string> worker_names;
621   DeviceFinder::GetRemoteWorkers(reset.device_filters(), env_,
622                                  env_->worker_cache, &worker_names);
623   if (!worker_names.empty()) {
624     const int num_workers = worker_names.size();
625     std::vector<Notification> n(num_workers);
626     CleanupAllRequest req;
627     (*req.mutable_container()) = reset.container();
628     std::vector<CleanupAllResponse> resp(num_workers);
629     int c = 0;
630     for (int i = 0; i < num_workers; ++i) {
631       const string& worker_name = worker_names[i];
632       auto worker = env_->worker_cache->CreateWorker(worker_name);
633       if (worker) {
634         worker->CleanupAllAsync(
635             &req, &resp[i], [this, &n, worker_name, worker, c](Status s) {
636               TF_CHECK_OK(s);
637               env_->worker_cache->ReleaseWorker(worker_name, worker);
638               n[c].Notify();
639             });
640       } else {
641         n[c].Notify();
642       }
643       ++c;
644     }
645     for (size_t i = 0; i < n.size(); ++i) {
646       n[i].WaitForNotification();
647     }
648   }
649 }
650 
Reset(const ResetRequest * req,ResetResponse * resp,MyClosure done)651 void Master::Reset(const ResetRequest* req, ResetResponse* resp,
652                    MyClosure done) {
653   // Vector to hold the session pointers present in the sessions_
654   // (string->Session*) map.
655   std::vector<MasterSession*> sessions_to_close;
656   {
657     mutex_lock l(mu_);
658     // NOTE(mrry): Transfer one reference to each session from the
659     // `sessions_` map to the `sessions_to_close` vector.
660     for (const auto& entry : sessions_) {
661       sessions_to_close.push_back(entry.second);
662     }
663     sessions_.clear();
664   }
665 
666   CleanupWorkers(*req);
667 
668   SchedClosure([sessions_to_close, done]() {
669     Status s;
670     for (MasterSession* session : sessions_to_close) {
671       s.Update(session->Close());
672       session->Unref();
673     }
674     done(s);
675   });
676 }
677 
MakeCallable(const MakeCallableRequest * req,MakeCallableResponse * resp,MyClosure done)678 void Master::MakeCallable(const MakeCallableRequest* req,
679                           MakeCallableResponse* resp, MyClosure done) {
680   Status s = recent_request_ids_.TrackUnique(req->request_id(),
681                                              "MakeCallable (Master)", *req);
682   if (!s.ok()) {
683     done(s);
684     return;
685   }
686   auto session = FindMasterSession(req->session_handle());
687   if (session == nullptr) {
688     done(errors::Aborted("Session ", req->session_handle(), " is not found."));
689     return;
690   }
691 
692   SchedClosure(std::bind(
693       [session, req, resp](MyClosure done) {
694         Status s = session->MakeCallable(*req, resp);
695         session->Unref();
696         done(s);
697       },
698       std::move(done)));
699 }
700 
RunCallable(CallOptions * opts,const RunCallableRequest * req,RunCallableResponse * resp,MyClosure done)701 void Master::RunCallable(CallOptions* opts, const RunCallableRequest* req,
702                          RunCallableResponse* resp, MyClosure done) {
703   Status s = recent_request_ids_.TrackUnique(req->request_id(),
704                                              "RunCallable (Master)", *req);
705   if (!s.ok()) {
706     done(s);
707     return;
708   }
709   auto session = FindMasterSession(req->session_handle());
710   if (session == nullptr) {
711     done(errors::Aborted("Session ", req->session_handle(), " is not found."));
712     return;
713   }
714 
715   SchedClosure(std::bind(
716       [session, opts, req, resp](MyClosure done) {
717         Status s = session->RunCallable(opts, *req, resp);
718         session->Unref();
719         done(s);
720       },
721       std::move(done)));
722 }
723 
ReleaseCallable(const ReleaseCallableRequest * req,ReleaseCallableResponse * resp,MyClosure done)724 void Master::ReleaseCallable(const ReleaseCallableRequest* req,
725                              ReleaseCallableResponse* resp, MyClosure done) {
726   auto session = FindMasterSession(req->session_handle());
727   if (session == nullptr) {
728     done(errors::Aborted("Session ", req->session_handle(), " is not found."));
729     return;
730   }
731 
732   SchedClosure(std::bind(
733       [session, req, resp](MyClosure done) {
734         Status s = session->ReleaseCallable(*req, resp);
735         session->Unref();
736         done(s);
737       },
738       std::move(done)));
739 }
740 
741 }  // end namespace tensorflow
742