/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Master implements the MasterService service.
//
// A Master maintains the state of live graph computation sessions;
// each session orchestrates both local and remote devices to carry
// out the graph computation.
//
// A Master knows ahead of time which local devices are available to
// serve as client devices.
//
// A Master discovers remote devices on demand and keeps track of
// statistics for those remote devices.
//
// Each session analyzes the graph, places nodes across the available
// devices, and ultimately drives the graph computation by initiating
// RunGraph calls on the workers.
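//
// For orientation, a client typically drives this service with the
// following RPC sequence (an illustrative sketch only: the stub used to
// issue the calls lives outside this file and is assumed here, while the
// request/response messages are the real ones from master.proto):
//
//   CreateSessionRequest create_req;
//   *create_req.mutable_graph_def() = graph_def;  // the client's GraphDef
//   CreateSessionResponse create_resp;
//   // ... issue CreateSession(create_req, &create_resp) via the stub ...
//
//   RunStepRequest run_req;
//   run_req.set_session_handle(create_resp.session_handle());
//   // ... add feeds/fetches, issue RunStep (repeatedly), then CloseSession ...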

#include "tensorflow/core/distributed_runtime/master.h"

#include <unordered_set>
#include <vector>

#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/remote_device.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

namespace {
const char* const kGrpcProtocol = "grpc://";
}  // namespace

Master::Master(MasterEnv* env, double session_gc_seconds)
    : env_(env),
      last_1000_steps_(1000),
      step_count_(0),
      session_gc_seconds_(session_gc_seconds),
      recent_request_ids_(10000) {
  // Right now, a master service must be co-located with a device.
  // Otherwise, fetches do not work.
  CHECK(!env->local_devices.empty());

  if (session_gc_seconds_ > 0.0) {
    gc_thread_ = env_->env->StartThread(ThreadOptions(), "TF_master_GC",
                                        [this]() { GC(); });
  } else {
    gc_thread_ = nullptr;
  }
}

Master::~Master() {
  if (gc_thread_) {
    mutex_lock l(mu_);
    shutdown_ = true;
    shutdown_cv_.notify_all();
    delete gc_thread_;
  }
}

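// Background loop run by `gc_thread_`: wakes up every 10 seconds and
// garbage-collects any session whose last access time is older than
// `session_gc_seconds_`, then removes it from `sessions_`.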
void Master::GC() {
  Env* env = Env::Default();
  while (true) {
    mutex_lock l(mu_);
    const int kTimeoutMilliseconds = 10 * 1000;  // 10 seconds.
    WaitForMilliseconds(&l, &shutdown_cv_, kTimeoutMilliseconds);
    if (shutdown_) {
      break;
    }
    std::vector<string> handles;
    const int64_t num_micros =
        static_cast<int64>(session_gc_seconds_ * 1000000);
    for (const auto& entry : sessions_) {
      int64_t lat = entry.second->last_access_time_usec();
      if (static_cast<int64>(env->NowMicros()) - lat > num_micros) {
        handles.push_back(entry.first);
        auto* sess = entry.second;
        SchedClosure([this, sess]() {
          LOG(WARNING) << "GC session " << sess->handle() << " after "
                       << session_gc_seconds_ << " seconds.  "
                       << "Note that if you are starting multiple replicas "
                       << "on a staggered delay, session_gc_seconds may need "
                       << "to be raised.";
          sess->GarbageCollect();
        });
      }
    }
    for (const auto& handle : handles) sessions_.erase(handle);
  }
}

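// Looks up the session for `handle` in `sessions_` and returns it with an
// extra reference held, or nullptr if no such session exists. The caller
// is responsible for calling Unref() on the result.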
MasterSession* Master::FindMasterSession(const string& handle) {
  MasterSession* session = nullptr;
  {
    mutex_lock l(mu_);
    session = gtl::FindPtrOrNull(sessions_, handle);
    if (session != nullptr) {
      session->Ref();
    }
  }
  return session;
}

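// DeviceFinder contacts the workers known to `worker_cache` (optionally
// restricted by the session's device filters) to enumerate their devices,
// and waits until every contacted worker has responded.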
class DeviceFinder {
 public:
  static Status GetRemoteDevices(
      const protobuf::RepeatedPtrField<string>& device_filters, MasterEnv* env,
      WorkerCacheInterface* worker_cache,
      std::vector<std::unique_ptr<Device>>* out_remote) {
    DeviceFinder finder(device_filters, env, worker_cache);
    finder.Start();
    TF_RETURN_IF_ERROR(finder.Wait());
    finder.GetRemoteDevices(env->local_devices, out_remote);
    return Status::OK();
  }

  static void GetRemoteWorkers(
      const protobuf::RepeatedPtrField<string>& device_filters, MasterEnv* env,
      WorkerCacheInterface* worker_cache, std::vector<string>* workers) {
    DeviceFinder finder(device_filters, env, worker_cache);
    *workers = finder.targets_;
  }

 private:
  explicit DeviceFinder(
      const protobuf::RepeatedPtrField<string>& device_filters, MasterEnv* env,
      WorkerCacheInterface* worker_cache)
      : env_(env), worker_cache_(worker_cache) {
    CHECK(worker_cache) << "Worker cache was null!";
    auto process_filter = [this](const string& filter) {
      DeviceNameUtils::ParsedName parsed;
      if (DeviceNameUtils::ParseFullName(filter, &parsed)) {
        filters_.push_back(parsed);
      } else {
        LOG(FATAL) << "Skipping invalid filter: " << filter;
      }
    };
    for (const string& filter : device_filters) {
      process_filter(filter);
    }
    // Enumerate all known workers' targets. A target name is a
    // prefix of a device name. E.g., /job:mnist/replica:0/task:10.
    if (filters_.empty()) {
      // If no filters were specified, we list all known workers in
      // `worker_cache`.
      std::vector<string> workers;
      worker_cache->ListWorkers(&workers);
      std::swap(workers, targets_);
    } else {
      // When applying filters, we must include the local worker, even if it
      // does not match any of the filters.
      CHECK_GT(env_->local_devices.size(), 0) << "No local devices provided.";
      const string& local_device_name = env_->local_devices[0]->name();
      DeviceNameUtils::ParsedName local_parsed_name;
      CHECK(DeviceNameUtils::ParseFullName(local_device_name,
                                           &local_parsed_name));
      bool all_filters_have_job = true;
      std::unordered_set<string> filter_job_names({local_parsed_name.job});
      for (const DeviceNameUtils::ParsedName& filter : filters_) {
        all_filters_have_job = all_filters_have_job && filter.has_job;
        if (filter.has_job) {
          filter_job_names.insert(filter.job);
        }
      }

      std::vector<string> workers;
      if (all_filters_have_job) {
        // If all of the device filters have a job specified, then we only need
        // to list the workers in the jobs named in the filter, because a worker
        // in any other job would not match any filter.
        for (const string& job_name : filter_job_names) {
          VLOG(2) << "Selectively listing workers in job: " << job_name;
          std::vector<string> workers_in_job;
          worker_cache->ListWorkersInJob(job_name, &workers_in_job);
          workers.insert(workers.end(), workers_in_job.begin(),
                         workers_in_job.end());
        }
      } else {
        // If any of the device filters does not have a job specified, then we
        // must list the workers from all jobs.
        VLOG(2) << "Listing workers in all jobs because some device "
                << "filter has no job specified. Filters were:";
        if (device_filters.empty()) {
          VLOG(2) << "- <NO FILTERS>";
        } else {
          for (const string& filter : device_filters) {
            VLOG(2) << "- " << filter;
          }
        }
        worker_cache->ListWorkers(&workers);
      }
      for (const string& name : workers) {
        if (MatchFilters(name) ||
            DeviceNameUtils::IsSameAddressSpace(name, local_device_name)) {
          targets_.push_back(name);
        }
      }
    }
    seen_targets_.assign(targets_.size(), false);
  }

  ~DeviceFinder() {
    for (Device* dev : found_) delete dev;
  }

  void Start() {
    {
      mutex_lock l(mu_);
      num_pending_ = targets_.size();
      if (num_pending_ == 0) {
        pending_zero_.notify_all();
      }
    }
    // Talk to all workers to get the list of available devices.
    using std::placeholders::_1;
    using std::placeholders::_2;
    for (size_t i = 0; i < targets_.size(); ++i) {
      // TODO(mrry): Propagate a timeout here, since `this->WhenFound()` may
      // never be called.
      NewRemoteDevices(env_->env, worker_cache_, targets_[i],
                       std::bind(&ME::WhenFound, this, i, _1, _2));
    }
  }

  // Every `kLoggingPeriodMs`, while the DeviceFinder is still waiting
  // to hear from workers, log a list of the workers who have not
  // responded.
  const int32 kLoggingPeriodMs = 10 * 1000;

  Status Wait() {
    mutex_lock l(mu_);
    // TODO(mrry): Propagate a timeout here, since `num_pending_` may
    // never become zero.
    while (num_pending_ != 0) {
      pending_zero_.wait_for(l, std::chrono::milliseconds(kLoggingPeriodMs));
      if (num_pending_ != 0) {
        for (size_t i = 0; i < targets_.size(); ++i) {
          if (!seen_targets_[i]) {
            LOG(INFO)
                << "CreateSession still waiting for response from worker: "
                << targets_[i];
          }
        }
      }
    }
    return status_;
  }

  // The caller takes ownership of the returned remote devices.
  void GetRemoteDevices(const std::vector<Device*>& local,
                        std::vector<std::unique_ptr<Device>>* remote) {
    std::unordered_set<string> names(local.size());
    for (Device* dev : local) names.insert(dev->name());
    mutex_lock l(mu_);
    for (Device* dev : found_) {
      const string& name = dev->name();
      if (names.insert(name).second && MatchFilters(name)) {
        remote->push_back(std::unique_ptr<Device>(dev));
      } else {
        delete dev;
      }
    }
    found_.clear();
  }

  typedef DeviceFinder ME;
  const MasterEnv* env_;
  WorkerCacheInterface* worker_cache_;
  std::vector<DeviceNameUtils::ParsedName> filters_;

  mutex mu_;
  int num_pending_ TF_GUARDED_BY(mu_);
  condition_variable pending_zero_;
  std::vector<Device*> found_ TF_GUARDED_BY(mu_);
  // List of targets to be contacted by this DeviceFinder. The
  // respective `bool` in `seen_targets_` indicates whether we have
  // heard from this target or not.
  std::vector<string> targets_;
  std::vector<bool> seen_targets_ TF_GUARDED_BY(mu_);
  Status status_;

  void WhenFound(int target_index, const Status& s,
                 std::vector<Device*>* devices) {
    mutex_lock l(mu_);
    seen_targets_[target_index] = true;
    if (!s.ok()) {
      LOG(ERROR) << "CreateSession failed because worker "
                 << targets_[target_index] << " returned error: " << s;
      status_.Update(s);
    } else {
      found_.insert(found_.end(), devices->begin(), devices->end());
      devices->clear();
    }
    --num_pending_;
    if (num_pending_ == 0) {
      pending_zero_.notify_all();
    }
  }

  // Returns true iff the set of devices allowed by 'x' intersects
  // with the set of devices allowed by 'y'.
  bool Intersects(const DeviceNameUtils::ParsedName& x,
                  const DeviceNameUtils::ParsedName& y) {
    return (!x.has_job || !y.has_job || x.job == y.job) &&
           (!x.has_replica || !y.has_replica || x.replica == y.replica) &&
           (!x.has_task || !y.has_task || x.task == y.task) &&
           (!x.has_type || !y.has_type || x.type == y.type) &&
           (!x.has_id || !y.has_id || x.id == y.id);
  }

  // Returns true iff 'name' matches one of the filters_.
  bool MatchFilters(const string& name) {
    if (filters_.empty()) return true;
    DeviceNameUtils::ParsedName x;
    if (DeviceNameUtils::ParseFullName(name, &x)) {
      for (const auto& filter : filters_) {
        if (Intersects(x, filter)) return true;
      }
    }
    return false;
  }

  TF_DISALLOW_COPY_AND_ASSIGN(DeviceFinder);
};

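// Handles the CreateSession RPC: validates the graph, builds a worker cache
// (from a client-supplied ClusterDef if one is present in the request),
// discovers the devices the session may use, creates a MasterSession, and
// registers it in `sessions_` under the returned handle.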
void Master::CreateSession(const CreateSessionRequest* req,
                           CreateSessionResponse* resp, MyClosure done) {
  SchedClosure([this, req, resp, done]() {
    Status status;
    WorkerCacheFactoryOptions worker_cache_factory_options;
    string grpc_protocol("grpc");
    worker_cache_factory_options.protocol = &grpc_protocol;
    auto call_done = gtl::MakeCleanup([&status, &done] { done(status); });
    status = ValidateExternalGraphDefSyntax(req->graph_def());
    if (!status.ok()) return;

    // The following 4 variables are set differently, depending on whether this
    // session uses a client-provided clusterspec or not.
    WorkerCacheInterface* worker_cache = nullptr;
    // Note: worker_cache_ptr will be null except if this session is using a
    // client-supplied ClusterDef (ClusterSpec propagation).
    std::unique_ptr<WorkerCacheInterface> worker_cache_ptr;
    std::unique_ptr<DeviceSet> device_set;
    // TODO(saeta): Convert to std::make_unique when available.
    std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devices(
        new std::vector<std::unique_ptr<Device>>());

    if (req->config().has_cluster_def()) {
      worker_cache_factory_options.cluster_def = &req->config().cluster_def();

      // Set the server_def's job_name and task_index fields.
      string normalized_string;
      string grpc_protocol(kGrpcProtocol);
      if (req->target().compare(0, grpc_protocol.length(), grpc_protocol) ==
          0) {
        normalized_string =
            req->target().substr(grpc_protocol.length(), string::npos);
      } else {
        normalized_string = req->target();
      }
      for (auto&& job : req->config().cluster_def().job()) {
        for (auto&& task : job.tasks()) {
          if (task.second == normalized_string) {
            if (worker_cache_factory_options.job_name != nullptr) {
              status = errors::InvalidArgument(
                  "Found multiple matching tasks that correspond to "
                  "the master. Master target: '",
                  req->target(), "'. ClusterDef: ",
                  req->config().cluster_def().ShortDebugString());
              LOG(ERROR) << status;
              return;
            }
            if (env_->local_devices[0]->parsed_name().job == job.name() &&
                env_->local_devices[0]->parsed_name().task == task.first) {
              // TODO(b/37868888): Remove this limitation when resolved
              status = errors::InvalidArgument(
                  "The ClusterSpec names the job and task index to be the same "
                  "names that were provided when the server booted. This is "
                  "currently not allowed. Job: ",
                  job.name(), ", task index: ", task.first);
              return;
            }
            worker_cache_factory_options.job_name = &job.name();
            worker_cache_factory_options.task_index = task.first;
          }
        }
      }
      worker_cache_factory_options.rpc_options = &req->config().rpc_options();
      // Create the worker cache from the computed server_def.
      status = env_->worker_cache_factory(worker_cache_factory_options,
                                          &worker_cache);
      if (!status.ok()) return;
      worker_cache_ptr = std::unique_ptr<WorkerCacheInterface>(worker_cache);
      // Ping all the workers and build the list of devices that the
      // session will use.
      status =
          DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_,
                                         worker_cache, remote_devices.get());
      if (!status.ok()) return;
      device_set.reset(new DeviceSet);
      for (auto&& d : *remote_devices) {
        device_set->AddDevice(d.get());
        DeviceNameUtils::ParsedName name = d->parsed_name();
        if (name.job == *worker_cache_factory_options.job_name &&
            name.task == worker_cache_factory_options.task_index &&
            name.type == "CPU" && name.id == 0) {
          device_set->set_client_device(d.get());
        }
      }
    } else {
      worker_cache = env_->worker_cache;
      // Ping all the workers and build the list of devices that the
      // session will use.
      status =
          DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_,
                                         worker_cache, remote_devices.get());
      if (!status.ok()) return;
      device_set.reset(new DeviceSet);
      for (auto&& d : *remote_devices) {
        device_set->AddDevice(d.get());
      }
      int num_local_devices = 0;
      for (Device* d : env_->local_devices) {
        device_set->AddDevice(d);
        if (num_local_devices == 0) {
          // Uses the first local device as the client device.
          device_set->set_client_device(d);
        }
        num_local_devices++;
      }
    }

    CHECK(device_set->client_device()) << "No client device found. Missing "
                                       << "CPU:0 device?";

    SessionOptions options;
    options.config = req->config();

    std::vector<string> filtered_worker_list;
    DeviceFinder::GetRemoteWorkers(req->config().device_filters(), env_,
                                   worker_cache, &filtered_worker_list);

    MasterSession* session = env_->master_session_factory(
        options, env_, std::move(remote_devices), std::move(worker_cache_ptr),
        std::move(device_set), std::move(filtered_worker_list));

    GraphDef* gdef =
        const_cast<CreateSessionRequest*>(req)->mutable_graph_def();

    status = session->Create(std::move(*gdef), worker_cache_factory_options);
    if (!status.ok()) {
      session->Close().IgnoreError();
      session->Unref();
      return;
    }
    resp->set_session_handle(session->handle());
    // Insert into the session map, which takes ownership of the session.
    {
      mutex_lock l(mu_);
      CHECK(sessions_.insert({session->handle(), session}).second);
    }
  });
}

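// Handles the ExtendSession RPC: validates the appended graph and extends
// the identified session's graph with it.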
void Master::ExtendSession(const ExtendSessionRequest* req,
                           ExtendSessionResponse* resp, MyClosure done) {
  auto session = FindMasterSession(req->session_handle());
  if (session == nullptr) {
    done(errors::Aborted("Session ", req->session_handle(), " is not found."));
    return;
  }

  SchedClosure([session, req, resp, done]() {
    Status status = ValidateExternalGraphDefSyntax(req->graph_def());
    if (status.ok()) {
      status = session->Extend(req, resp);
    }
    session->Unref();
    done(status);
  });
}

void Master::PartialRunSetup(const PartialRunSetupRequest* req,
                             PartialRunSetupResponse* resp, MyClosure done) {
  Status s = recent_request_ids_.TrackUnique(req->request_id(),
                                             "PartialRunSetup (Master)", *req);
  if (!s.ok()) {
    done(s);
    return;
  }
  auto session = FindMasterSession(req->session_handle());
  if (session == nullptr) {
    done(errors::Aborted("Session ", req->session_handle(), " is not found."));
    return;
  }

  SchedClosure([session, req, resp, done]() {
    Status s = session->PartialRunSetup(req, resp);
    session->Unref();
    done(s);
  });
}

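// Handles the RunStep RPC: rejects duplicate request ids, runs one step of
// the identified session, and records the step latency in
// `last_1000_steps_`.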
void Master::RunStep(CallOptions* opts, const RunStepRequestWrapper* req,
                     MutableRunStepResponseWrapper* resp, MyClosure done) {
  Status s = recent_request_ids_.TrackUnique(req->request_id(),
                                             "RunStep (Master)", req);
  if (!s.ok()) {
    done(s);
    return;
  }
  auto start_time = env_->env->NowMicros();
  auto session = FindMasterSession(req->session_handle());
  if (session == nullptr) {
    done(errors::Aborted("Session ", req->session_handle(), " is not found."));
    return;
  }

  SchedClosure([this, start_time, session, opts, req, resp, done]() {
    Status status = session->Run(opts, *req, resp);
    session->Unref();
    uint64 done_time = env_->env->NowMicros();
    done(status);
    mutex_lock l(mu_);
    last_1000_steps_.AddValue((done_time - start_time) / 1e9);
    ++step_count_;
  });
}

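// Handles the CloseSession RPC: removes the session from `sessions_` and
// closes it on a separate thread, since MasterSession::Close() blocks.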
void Master::CloseSession(const CloseSessionRequest* req,
                          CloseSessionResponse* resp, MyClosure done) {
  MasterSession* session = nullptr;
  {
    mu_.lock();
    auto iter = sessions_.find(req->session_handle());
    if (iter == sessions_.end()) {
      mu_.unlock();
      done(errors::Aborted(
          "Session ", req->session_handle(),
          " is not found. Possibly, this master has restarted."));
      return;
    }
    // NOTE(mrry): One reference to the session is transferred from
    // `sessions_[req->session_handle()]` to `session`.
    session = iter->second;
    sessions_.erase(iter);
    mu_.unlock();
  }

  // Session Close() blocks on thread shutdown. Therefore, we need to
  // delete it in a non-critical thread.
  SchedClosure([session, done]() {
    Status s = session->Close();
    session->Unref();
    done(s);
  });
}

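// Handles the ListDevices RPC: if a session handle is given, lists that
// session's devices; otherwise lists the master's local devices plus all
// remote devices discoverable through the default worker cache.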
void Master::ListDevices(const ListDevicesRequest* req,
                         ListDevicesResponse* resp, MyClosure done) {
  SchedClosure([this, req, resp, done]() {
    if (!req->session_handle().empty()) {
      auto session = FindMasterSession(req->session_handle());
      if (session == nullptr) {
        done(errors::InvalidArgument(
            "Session ", req->session_handle(),
            " is not found. Possibly, this master has restarted."));
        return;
      }
      core::ScopedUnref ref(session);
      Status s = session->ListDevices(resp);
      done(s);
      return;
    }
    std::vector<std::unique_ptr<Device>> remote_devices;
    Status s = DeviceFinder::GetRemoteDevices({}, env_, env_->worker_cache,
                                              &remote_devices);
    if (s.ok()) {
      for (Device* dev : env_->local_devices) {
        *(resp->add_local_device()) = dev->attributes();
      }
      for (auto&& dev : remote_devices) {
        *(resp->add_remote_device()) = dev->attributes();
      }
    }
    done(s);
  });
}

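// Broadcasts CleanupAll to every worker selected by the reset request's
// device filters and blocks until all of them have responded.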
void Master::CleanupWorkers(const ResetRequest& reset) {
  std::vector<string> worker_names;
  DeviceFinder::GetRemoteWorkers(reset.device_filters(), env_,
                                 env_->worker_cache, &worker_names);
  if (!worker_names.empty()) {
    const int num_workers = worker_names.size();
    std::vector<Notification> n(num_workers);
    CleanupAllRequest req;
    (*req.mutable_container()) = reset.container();
    std::vector<CleanupAllResponse> resp(num_workers);
    int c = 0;
    for (int i = 0; i < num_workers; ++i) {
      const string& worker_name = worker_names[i];
      auto worker = env_->worker_cache->GetOrCreateWorker(worker_name);
      if (worker) {
        worker->CleanupAllAsync(
            &req, &resp[i], [this, &n, worker_name, worker, c](Status s) {
              TF_CHECK_OK(s);
              env_->worker_cache->ReleaseWorker(worker_name, worker);
              n[c].Notify();
            });
      } else {
        n[c].Notify();
      }
      ++c;
    }
    for (size_t i = 0; i < n.size(); ++i) {
      n[i].WaitForNotification();
    }
  }
}

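// Handles the Reset RPC: drops all live sessions owned by this master, asks
// the selected workers to clean up their per-container state, and then
// closes the dropped sessions asynchronously.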
void Master::Reset(const ResetRequest* req, ResetResponse* resp,
                   MyClosure done) {
  // Vector to hold the session pointers present in the sessions_
  // (string->Session*) map.
  std::vector<MasterSession*> sessions_to_close;
  {
    mutex_lock l(mu_);
    // NOTE(mrry): Transfer one reference to each session from the
    // `sessions_` map to the `sessions_to_close` vector.
    for (const auto& entry : sessions_) {
      sessions_to_close.push_back(entry.second);
    }
    sessions_.clear();
  }

  CleanupWorkers(*req);

  SchedClosure([sessions_to_close, done]() {
    Status s;
    for (MasterSession* session : sessions_to_close) {
      s.Update(session->Close());
      session->Unref();
    }
    done(s);
  });
}

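// The three handlers below mirror MasterSession's callable API:
// MakeCallable creates a callable in the target session, RunCallable
// executes it, and ReleaseCallable destroys it.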
void Master::MakeCallable(const MakeCallableRequest* req,
                          MakeCallableResponse* resp, MyClosure done) {
  Status s = recent_request_ids_.TrackUnique(req->request_id(),
                                             "MakeCallable (Master)", *req);
  if (!s.ok()) {
    done(s);
    return;
  }
  auto session = FindMasterSession(req->session_handle());
  if (session == nullptr) {
    done(errors::Aborted("Session ", req->session_handle(), " is not found."));
    return;
  }

  SchedClosure([session, req, resp, done = std::move(done)]() {
    Status s = session->MakeCallable(*req, resp);
    session->Unref();
    done(s);
  });
}

void Master::RunCallable(CallOptions* opts, const RunCallableRequest* req,
                         RunCallableResponse* resp, MyClosure done) {
  Status s = recent_request_ids_.TrackUnique(req->request_id(),
                                             "RunCallable (Master)", *req);
  if (!s.ok()) {
    done(s);
    return;
  }
  auto session = FindMasterSession(req->session_handle());
  if (session == nullptr) {
    done(errors::Aborted("Session ", req->session_handle(), " is not found."));
    return;
  }

  SchedClosure([session, opts, req, resp, done = std::move(done)]() {
    Status s = session->RunCallable(opts, *req, resp);
    session->Unref();
    done(s);
  });
}

void Master::ReleaseCallable(const ReleaseCallableRequest* req,
                             ReleaseCallableResponse* resp, MyClosure done) {
  auto session = FindMasterSession(req->session_handle());
  if (session == nullptr) {
    done(errors::Aborted("Session ", req->session_handle(), " is not found."));
    return;
  }

  SchedClosure([session, req, resp, done = std::move(done)]() {
    Status s = session->ReleaseCallable(*req, resp);
    session->Unref();
    done(s);
  });
}

}  // end namespace tensorflow