1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// Master implements the service MasterService.
17 //
18 // A Master maintains the state of live graph computation
19 // sessions, each session orchestrates both local and remote devices
20 // to carry out the graph computation.
21 //
22 // A Master knows ahead of time local devices available as
23 // client devices.
24 //
25 // A Master discovers remote devices on-demand and keeps track of
26 // statistics of those remote devices.
27 //
28 // Each session analyzes the graph, places nodes across available
29 // devices, and ultimately drives the graph computation by initiating
30 // RunGraph on the workers.
31
32 #include "tensorflow/core/distributed_runtime/master.h"
33
34 #include <unordered_set>
35 #include <vector>
36
37 #include "tensorflow/core/common_runtime/device_set.h"
38 #include "tensorflow/core/common_runtime/process_util.h"
39 #include "tensorflow/core/distributed_runtime/remote_device.h"
40 #include "tensorflow/core/distributed_runtime/worker_cache.h"
41 #include "tensorflow/core/distributed_runtime/worker_interface.h"
42 #include "tensorflow/core/framework/graph_def_util.h"
43 #include "tensorflow/core/lib/core/errors.h"
44 #include "tensorflow/core/lib/core/notification.h"
45 #include "tensorflow/core/lib/gtl/array_slice.h"
46 #include "tensorflow/core/lib/gtl/cleanup.h"
47 #include "tensorflow/core/lib/gtl/map_util.h"
48 #include "tensorflow/core/lib/strings/str_util.h"
49 #include "tensorflow/core/platform/macros.h"
50 #include "tensorflow/core/platform/mutex.h"
51 #include "tensorflow/core/platform/types.h"
52 #include "tensorflow/core/protobuf/cluster.pb.h"
53 #include "tensorflow/core/protobuf/master.pb.h"
54 #include "tensorflow/core/protobuf/worker.pb.h"
55 #include "tensorflow/core/public/session_options.h"
56 #include "tensorflow/core/util/device_name_utils.h"
57
58 namespace tensorflow {
59
namespace {
// URL scheme prefix that may precede a master target address. It is stripped
// from the request target before the target is compared against the task
// addresses of a ClusterDef.
constexpr char kGrpcProtocol[] = "grpc://";
}  // namespace
63
Master(MasterEnv * env,double session_gc_seconds)64 Master::Master(MasterEnv* env, double session_gc_seconds)
65 : env_(env),
66 last_1000_steps_(1000),
67 step_count_(0),
68 session_gc_seconds_(session_gc_seconds),
69 recent_request_ids_(10000) {
70 // Right now, a master service must be co-located with a device.
71 // Otherwise, fetches do not work.
72 CHECK(!env->local_devices.empty());
73
74 if (session_gc_seconds_ > 0.0) {
75 gc_thread_ = env_->env->StartThread(ThreadOptions(), "TF_master_GC",
76 [this]() { GC(); });
77 } else {
78 gc_thread_ = nullptr;
79 }
80 }
81
~Master()82 Master::~Master() {
83 if (gc_thread_) {
84 mutex_lock l(mu_);
85 shutdown_ = true;
86 shutdown_cv_.notify_all();
87 delete gc_thread_;
88 }
89 }
90
GC()91 void Master::GC() {
92 Env* env = Env::Default();
93 while (true) {
94 mutex_lock l(mu_);
95 const int kTimeoutMilliseconds = 10 * 1000; // 10 seconds.
96 WaitForMilliseconds(&l, &shutdown_cv_, kTimeoutMilliseconds);
97 if (shutdown_) {
98 break;
99 }
100 std::vector<string> handles;
101 const int64 num_micros = static_cast<int64>(session_gc_seconds_ * 1000000);
102 for (const auto& entry : sessions_) {
103 int64 lat = entry.second->last_access_time_usec();
104 if (static_cast<int64>(env->NowMicros()) - lat > num_micros) {
105 handles.push_back(entry.first);
106 auto* sess = entry.second;
107 SchedClosure([this, sess]() {
108 LOG(WARNING) << "GC session " << sess->handle() << " after "
109 << session_gc_seconds_ << " seconds. "
110 << "Note that if you are starting multiple replicas "
111 << "on a staggered delay, session_gc_seconds may need "
112 << "to be raised.";
113 sess->GarbageCollect();
114 });
115 }
116 }
117 for (const auto& handle : handles) sessions_.erase(handle);
118 }
119 }
120
FindMasterSession(const string & handle)121 MasterSession* Master::FindMasterSession(const string& handle) {
122 MasterSession* session = nullptr;
123 {
124 mutex_lock l(mu_);
125 session = gtl::FindPtrOrNull(sessions_, handle);
126 if (session != nullptr) {
127 session->Ref();
128 }
129 }
130 return session;
131 }
132
133 class DeviceFinder {
134 public:
GetRemoteDevices(const protobuf::RepeatedPtrField<string> & device_filters,MasterEnv * env,WorkerCacheInterface * worker_cache,std::vector<std::unique_ptr<Device>> * out_remote)135 static Status GetRemoteDevices(
136 const protobuf::RepeatedPtrField<string>& device_filters, MasterEnv* env,
137 WorkerCacheInterface* worker_cache,
138 std::vector<std::unique_ptr<Device>>* out_remote) {
139 DeviceFinder finder(device_filters, env, worker_cache);
140 finder.Start();
141 TF_RETURN_IF_ERROR(finder.Wait());
142 finder.GetRemoteDevices(env->local_devices, out_remote);
143 return Status::OK();
144 }
145
GetRemoteWorkers(const protobuf::RepeatedPtrField<string> & device_filters,MasterEnv * env,WorkerCacheInterface * worker_cache,std::vector<string> * workers)146 static void GetRemoteWorkers(
147 const protobuf::RepeatedPtrField<string>& device_filters, MasterEnv* env,
148 WorkerCacheInterface* worker_cache, std::vector<string>* workers) {
149 DeviceFinder finder(device_filters, env, worker_cache);
150 *workers = finder.targets_;
151 }
152
153 private:
DeviceFinder(const protobuf::RepeatedPtrField<string> & device_filters,MasterEnv * env,WorkerCacheInterface * worker_cache)154 explicit DeviceFinder(
155 const protobuf::RepeatedPtrField<string>& device_filters, MasterEnv* env,
156 WorkerCacheInterface* worker_cache)
157 : env_(env), worker_cache_(worker_cache) {
158 CHECK(worker_cache) << "Worker cache was null!";
159 auto process_filter = [this](const string& filter) {
160 DeviceNameUtils::ParsedName parsed;
161 if (DeviceNameUtils::ParseFullName(filter, &parsed)) {
162 filters_.push_back(parsed);
163 } else {
164 LOG(FATAL) << "Skipping invalid filter: " << filter;
165 }
166 };
167 for (const string& filter : device_filters) {
168 process_filter(filter);
169 }
170 // Enumerates all known workers' target. A target name is a
171 // prefix of a device name. E.g., /job:mnist/replica:0/task:10.
172 if (filters_.empty()) {
173 // If no filters were specified, we list all known workers in
174 // `worker_cache`.
175 std::vector<string> workers;
176 worker_cache->ListWorkers(&workers);
177 std::swap(workers, targets_);
178 } else {
179 // When applying filters, we must include the local worker, even if it
180 // does not match any of the filters.
181 CHECK_GT(env_->local_devices.size(), 0) << "No local devices provided.";
182 const string& local_device_name = env_->local_devices[0]->name();
183 DeviceNameUtils::ParsedName local_parsed_name;
184 CHECK(DeviceNameUtils::ParseFullName(local_device_name,
185 &local_parsed_name));
186 bool all_filters_have_job = true;
187 std::unordered_set<string> filter_job_names({local_parsed_name.job});
188 for (const DeviceNameUtils::ParsedName& filter : filters_) {
189 all_filters_have_job = all_filters_have_job && filter.has_job;
190 if (filter.has_job) {
191 filter_job_names.insert(filter.job);
192 }
193 }
194
195 std::vector<string> workers;
196 if (all_filters_have_job) {
197 // If all of the device filters have a job specified, then we only need
198 // to list the workers in the jobs named in the filter, because a worker
199 // in any other job would not match any filter.
200 for (const string& job_name : filter_job_names) {
201 VLOG(2) << "Selectively listing workers in job: " << job_name;
202 std::vector<string> workers_in_job;
203 worker_cache->ListWorkersInJob(job_name, &workers_in_job);
204 workers.insert(workers.end(), workers_in_job.begin(),
205 workers_in_job.end());
206 }
207 } else {
208 // If any of the device filters does not have a job specified, then we
209 // must list the workers from all jobs.
210 VLOG(2) << "Listing workers in all jobs because some device "
211 << "filter has no job specified. Filters were:";
212 if (device_filters.empty()) {
213 VLOG(2) << "- <NO FILTERS>";
214 } else {
215 for (const string& filter : device_filters) {
216 VLOG(2) << "- " << filter;
217 }
218 }
219 worker_cache->ListWorkers(&workers);
220 }
221 for (const string& name : workers) {
222 if (MatchFilters(name) ||
223 DeviceNameUtils::IsSameAddressSpace(name, local_device_name)) {
224 targets_.push_back(name);
225 }
226 }
227 }
228 seen_targets_.assign(targets_.size(), false);
229 }
230
~DeviceFinder()231 ~DeviceFinder() {
232 for (Device* dev : found_) delete dev;
233 }
234
Start()235 void Start() {
236 {
237 mutex_lock l(mu_);
238 num_pending_ = targets_.size();
239 if (num_pending_ == 0) {
240 pending_zero_.notify_all();
241 }
242 }
243 // Talk to all workers to get the list of available devices.
244 using std::placeholders::_1;
245 using std::placeholders::_2;
246 for (size_t i = 0; i < targets_.size(); ++i) {
247 // TODO(mrry): Propagate a timeout here, since `this->WhenFound()` may
248 // never be called.
249 NewRemoteDevices(env_->env, worker_cache_, targets_[i],
250 std::bind(&ME::WhenFound, this, i, _1, _2));
251 }
252 }
253
254 // Every `kLoggingPeriodMs`, while the DeviceFinder is still waiting
255 // to hear from workers, log a list of the workers who have not
256 // responded.
257 const int32 kLoggingPeriodMs = 10 * 1000;
258
Wait()259 Status Wait() {
260 mutex_lock l(mu_);
261 // TODO(mrry): Propagate a timeout here, since `num_pending_` may
262 // never become zero.
263 while (num_pending_ != 0) {
264 pending_zero_.wait_for(l, std::chrono::milliseconds(kLoggingPeriodMs));
265 if (num_pending_ != 0) {
266 for (size_t i = 0; i < targets_.size(); ++i) {
267 if (!seen_targets_[i]) {
268 LOG(INFO)
269 << "CreateSession still waiting for response from worker: "
270 << targets_[i];
271 }
272 }
273 }
274 }
275 return status_;
276 }
277
278 // The caller takes the ownership of returned remote devices.
GetRemoteDevices(const std::vector<Device * > & local,std::vector<std::unique_ptr<Device>> * remote)279 void GetRemoteDevices(const std::vector<Device*>& local,
280 std::vector<std::unique_ptr<Device>>* remote) {
281 std::unordered_set<string> names(local.size());
282 for (Device* dev : local) names.insert(dev->name());
283 mutex_lock l(mu_);
284 for (Device* dev : found_) {
285 const string& name = dev->name();
286 if (names.insert(name).second && MatchFilters(name)) {
287 remote->push_back(std::unique_ptr<Device>(dev));
288 } else {
289 delete dev;
290 }
291 }
292 found_.clear();
293 }
294
295 typedef DeviceFinder ME;
296 const MasterEnv* env_;
297 WorkerCacheInterface* worker_cache_;
298 std::vector<DeviceNameUtils::ParsedName> filters_;
299
300 mutex mu_;
301 int num_pending_ GUARDED_BY(mu_);
302 condition_variable pending_zero_;
303 std::vector<Device*> found_ GUARDED_BY(mu_);
304 // List of targets to be contacted by this DeviceFinder. The
305 // respective `bool` in `seen_targets_` indicates whether we have
306 // heard from this target or not.
307 std::vector<string> targets_;
308 std::vector<bool> seen_targets_ GUARDED_BY(mu_);
309 Status status_;
310
WhenFound(int target_index,const Status & s,std::vector<Device * > * devices)311 void WhenFound(int target_index, const Status& s,
312 std::vector<Device*>* devices) {
313 mutex_lock l(mu_);
314 seen_targets_[target_index] = true;
315 if (!s.ok()) {
316 LOG(ERROR) << "CreateSession failed because worker "
317 << targets_[target_index] << " returned error: " << s;
318 status_.Update(s);
319 } else {
320 found_.insert(found_.end(), devices->begin(), devices->end());
321 devices->clear();
322 }
323 --num_pending_;
324 if (num_pending_ == 0) {
325 pending_zero_.notify_all();
326 }
327 }
328
329 // Returns true iff the set of devices allowed by 'x' intersects
330 // with the set of devices allowed by 'y'.
Intersects(const DeviceNameUtils::ParsedName & x,const DeviceNameUtils::ParsedName & y)331 bool Intersects(const DeviceNameUtils::ParsedName& x,
332 const DeviceNameUtils::ParsedName& y) {
333 return (!x.has_job || !y.has_job || x.job == y.job) &&
334 (!x.has_replica || !y.has_replica || x.replica == y.replica) &&
335 (!x.has_task || !y.has_task || x.task == y.task) &&
336 (!x.has_type || !y.has_type || x.type == y.type) &&
337 (!x.has_id || !y.has_id || x.id == y.id);
338 }
339
340 // Returns true iff 'name' matches one of the filters_.
MatchFilters(const string & name)341 bool MatchFilters(const string& name) {
342 if (filters_.empty()) return true;
343 DeviceNameUtils::ParsedName x;
344 if (DeviceNameUtils::ParseFullName(name, &x)) {
345 for (const auto& filter : filters_) {
346 if (Intersects(x, filter)) return true;
347 }
348 }
349 return false;
350 }
351
352 TF_DISALLOW_COPY_AND_ASSIGN(DeviceFinder);
353 };
354
CreateSession(const CreateSessionRequest * req,CreateSessionResponse * resp,MyClosure done)355 void Master::CreateSession(const CreateSessionRequest* req,
356 CreateSessionResponse* resp, MyClosure done) {
357 SchedClosure([this, req, resp, done]() {
358 Status status;
359 WorkerCacheFactoryOptions worker_cache_factory_options;
360 string grpc_protocol("grpc");
361 worker_cache_factory_options.protocol = &grpc_protocol;
362 auto call_done = gtl::MakeCleanup([&status, &done] { done(status); });
363 status = ValidateExternalGraphDefSyntax(req->graph_def());
364 if (!status.ok()) return;
365
366 // The following 4 variables are set differently, depending on whether this
367 // session uses a client-provided clusterspec or not.
368 WorkerCacheInterface* worker_cache = nullptr;
369 // Note: worker_cache_ptr will be null except if this session is using a
370 // client-supplied ClusterDef (ClusterSpec propagation).
371 std::unique_ptr<WorkerCacheInterface> worker_cache_ptr;
372 std::unique_ptr<DeviceSet> device_set;
373 // TODO(saeta): Convert to std::make_unique when available.
374 std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devices(
375 new std::vector<std::unique_ptr<Device>>());
376
377 if (req->config().has_cluster_def()) {
378 worker_cache_factory_options.cluster_def = &req->config().cluster_def();
379
380 // Set the server_def's job_name and task_index fields.
381 string normalized_string;
382 string grpc_protocol(kGrpcProtocol);
383 if (req->target().compare(0, grpc_protocol.length(), grpc_protocol) ==
384 0) {
385 normalized_string =
386 req->target().substr(grpc_protocol.length(), string::npos);
387 } else {
388 normalized_string = req->target();
389 }
390 for (auto&& job : req->config().cluster_def().job()) {
391 for (auto&& task : job.tasks()) {
392 if (task.second == normalized_string) {
393 if (worker_cache_factory_options.job_name != nullptr) {
394 status = errors::InvalidArgument(
395 "Found multiple matching tasks that correspond to "
396 "to the master. Master target: '",
397 req->target(), "'. ClusterDef: ",
398 req->config().cluster_def().ShortDebugString());
399 LOG(ERROR) << status;
400 return;
401 }
402 if (env_->local_devices[0]->parsed_name().job == job.name() &&
403 env_->local_devices[0]->parsed_name().task == task.first) {
404 // TODO(b/37868888): Remove this limitation when resolved
405 status = errors::InvalidArgument(
406 "The ClusterSpec names the job and task index to be the same "
407 "names that were provided when the server booted. This is "
408 "currently not allowed. Job: ",
409 job.name(), ", task index: ", task.first);
410 return;
411 }
412 worker_cache_factory_options.job_name = &job.name();
413 worker_cache_factory_options.task_index = task.first;
414 }
415 }
416 }
417
418 // Create the worker cache from the computed server_def.
419 status = env_->worker_cache_factory(worker_cache_factory_options,
420 &worker_cache);
421 if (!status.ok()) return;
422 worker_cache_ptr = std::unique_ptr<WorkerCacheInterface>(worker_cache);
423 // Ping all the workers and build the list of devices that the
424 // session will use.
425 status =
426 DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_,
427 worker_cache, remote_devices.get());
428 if (!status.ok()) return;
429 device_set.reset(new DeviceSet);
430 for (auto&& d : *remote_devices) {
431 device_set->AddDevice(d.get());
432 DeviceNameUtils::ParsedName name = d->parsed_name();
433 if (name.job == *worker_cache_factory_options.job_name &&
434 name.task == worker_cache_factory_options.task_index &&
435 name.type == "CPU" && name.id == 0) {
436 device_set->set_client_device(d.get());
437 }
438 }
439 } else {
440 worker_cache = env_->worker_cache;
441 // Ping all the workers and build the list of devices that the
442 // session will use.
443 status =
444 DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_,
445 worker_cache, remote_devices.get());
446 if (!status.ok()) return;
447 device_set.reset(new DeviceSet);
448 for (auto&& d : *remote_devices) {
449 device_set->AddDevice(d.get());
450 }
451 int num_local_devices = 0;
452 for (Device* d : env_->local_devices) {
453 device_set->AddDevice(d);
454 if (num_local_devices == 0) {
455 // Uses the first local device as the client device.
456 device_set->set_client_device(d);
457 }
458 num_local_devices++;
459 }
460 }
461
462 CHECK(device_set->client_device()) << "No client device found. Missing "
463 << "CPU:0 device?";
464
465 SessionOptions options;
466 options.config = req->config();
467
468 std::vector<string> filtered_worker_list;
469 DeviceFinder::GetRemoteWorkers(req->config().device_filters(), env_,
470 worker_cache, &filtered_worker_list);
471
472 MasterSession* session = env_->master_session_factory(
473 options, env_, std::move(remote_devices), std::move(worker_cache_ptr),
474 std::move(device_set), std::move(filtered_worker_list));
475
476 GraphDef* gdef =
477 const_cast<CreateSessionRequest*>(req)->mutable_graph_def();
478
479 status = session->Create(gdef, worker_cache_factory_options);
480 if (!status.ok()) {
481 session->Close().IgnoreError();
482 session->Unref();
483 return;
484 }
485 resp->set_session_handle(session->handle());
486 // Insert into the session map, which takes ownership of the session.
487 {
488 mutex_lock l(mu_);
489 CHECK(sessions_.insert({session->handle(), session}).second);
490 }
491 });
492 }
493
ExtendSession(const ExtendSessionRequest * req,ExtendSessionResponse * resp,MyClosure done)494 void Master::ExtendSession(const ExtendSessionRequest* req,
495 ExtendSessionResponse* resp, MyClosure done) {
496 auto session = FindMasterSession(req->session_handle());
497 if (session == nullptr) {
498 done(errors::Aborted("Session ", req->session_handle(), " is not found."));
499 return;
500 }
501
502 SchedClosure([session, req, resp, done]() {
503 Status status = ValidateExternalGraphDefSyntax(req->graph_def());
504 if (status.ok()) {
505 status = session->Extend(req, resp);
506 }
507 session->Unref();
508 done(status);
509 });
510 }
511
PartialRunSetup(const PartialRunSetupRequest * req,PartialRunSetupResponse * resp,MyClosure done)512 void Master::PartialRunSetup(const PartialRunSetupRequest* req,
513 PartialRunSetupResponse* resp, MyClosure done) {
514 Status s = recent_request_ids_.TrackUnique(req->request_id(),
515 "PartialRunSetup (Master)", *req);
516 if (!s.ok()) {
517 done(s);
518 return;
519 }
520 auto session = FindMasterSession(req->session_handle());
521 if (session == nullptr) {
522 done(errors::Aborted("Session ", req->session_handle(), " is not found."));
523 return;
524 }
525
526 SchedClosure([session, req, resp, done]() {
527 Status s = session->PartialRunSetup(req, resp);
528 session->Unref();
529 done(s);
530 });
531 }
532
RunStep(CallOptions * opts,const RunStepRequestWrapper * req,MutableRunStepResponseWrapper * resp,MyClosure done)533 void Master::RunStep(CallOptions* opts, const RunStepRequestWrapper* req,
534 MutableRunStepResponseWrapper* resp, MyClosure done) {
535 Status s = recent_request_ids_.TrackUnique(req->request_id(),
536 "RunStep (Master)", req);
537 if (!s.ok()) {
538 done(s);
539 return;
540 }
541 auto start_time = env_->env->NowMicros();
542 auto session = FindMasterSession(req->session_handle());
543 if (session == nullptr) {
544 done(errors::Aborted("Session ", req->session_handle(), " is not found."));
545 return;
546 }
547
548 SchedClosure([this, start_time, session, opts, req, resp, done]() {
549 Status status = session->Run(opts, *req, resp);
550 session->Unref();
551 uint64 done_time = env_->env->NowMicros();
552 done(status);
553 mutex_lock l(mu_);
554 last_1000_steps_.AddValue((done_time - start_time) / 1e9);
555 ++step_count_;
556 });
557 }
558
CloseSession(const CloseSessionRequest * req,CloseSessionResponse * resp,MyClosure done)559 void Master::CloseSession(const CloseSessionRequest* req,
560 CloseSessionResponse* resp, MyClosure done) {
561 MasterSession* session = nullptr;
562 {
563 mu_.lock();
564 auto iter = sessions_.find(req->session_handle());
565 if (iter == sessions_.end()) {
566 mu_.unlock();
567 done(errors::Aborted(
568 "Session ", req->session_handle(),
569 " is not found. Possibly, this master has restarted."));
570 return;
571 }
572 // NOTE(mrry): One reference to the session is transferred from
573 // `sessions_[req->session_handle()]` to `session`.
574 session = iter->second;
575 sessions_.erase(iter);
576 mu_.unlock();
577 }
578
579 // Session Close() blocks on thread shutdown. Therefore, we need to
580 // delete it in non-critical thread.
581 SchedClosure([session, done]() {
582 Status s = session->Close();
583 session->Unref();
584 done(s);
585 });
586 }
587
ListDevices(const ListDevicesRequest * req,ListDevicesResponse * resp,MyClosure done)588 void Master::ListDevices(const ListDevicesRequest* req,
589 ListDevicesResponse* resp, MyClosure done) {
590 SchedClosure([this, req, resp, done]() {
591 if (!req->session_handle().empty()) {
592 auto session = FindMasterSession(req->session_handle());
593 if (session == nullptr) {
594 done(errors::InvalidArgument(
595 "Session ", req->session_handle(),
596 " is not found. Possibly, this master has restarted."));
597 return;
598 }
599 core::ScopedUnref ref(session);
600 Status s = session->ListDevices(resp);
601 done(s);
602 return;
603 }
604 std::vector<std::unique_ptr<Device>> remote_devices;
605 Status s = DeviceFinder::GetRemoteDevices({}, env_, env_->worker_cache,
606 &remote_devices);
607 if (s.ok()) {
608 for (Device* dev : env_->local_devices) {
609 *(resp->add_local_device()) = dev->attributes();
610 }
611 for (auto&& dev : remote_devices) {
612 *(resp->add_remote_device()) = dev->attributes();
613 }
614 }
615 done(s);
616 });
617 }
618
CleanupWorkers(const ResetRequest & reset)619 void Master::CleanupWorkers(const ResetRequest& reset) {
620 std::vector<string> worker_names;
621 DeviceFinder::GetRemoteWorkers(reset.device_filters(), env_,
622 env_->worker_cache, &worker_names);
623 if (!worker_names.empty()) {
624 const int num_workers = worker_names.size();
625 std::vector<Notification> n(num_workers);
626 CleanupAllRequest req;
627 (*req.mutable_container()) = reset.container();
628 std::vector<CleanupAllResponse> resp(num_workers);
629 int c = 0;
630 for (int i = 0; i < num_workers; ++i) {
631 const string& worker_name = worker_names[i];
632 auto worker = env_->worker_cache->CreateWorker(worker_name);
633 if (worker) {
634 worker->CleanupAllAsync(
635 &req, &resp[i], [this, &n, worker_name, worker, c](Status s) {
636 TF_CHECK_OK(s);
637 env_->worker_cache->ReleaseWorker(worker_name, worker);
638 n[c].Notify();
639 });
640 } else {
641 n[c].Notify();
642 }
643 ++c;
644 }
645 for (size_t i = 0; i < n.size(); ++i) {
646 n[i].WaitForNotification();
647 }
648 }
649 }
650
Reset(const ResetRequest * req,ResetResponse * resp,MyClosure done)651 void Master::Reset(const ResetRequest* req, ResetResponse* resp,
652 MyClosure done) {
653 // Vector to hold the session pointers present in the sessions_
654 // (string->Session*) map.
655 std::vector<MasterSession*> sessions_to_close;
656 {
657 mutex_lock l(mu_);
658 // NOTE(mrry): Transfer one reference to each session from the
659 // `sessions_` map to the `sessions_to_close` vector.
660 for (const auto& entry : sessions_) {
661 sessions_to_close.push_back(entry.second);
662 }
663 sessions_.clear();
664 }
665
666 CleanupWorkers(*req);
667
668 SchedClosure([sessions_to_close, done]() {
669 Status s;
670 for (MasterSession* session : sessions_to_close) {
671 s.Update(session->Close());
672 session->Unref();
673 }
674 done(s);
675 });
676 }
677
MakeCallable(const MakeCallableRequest * req,MakeCallableResponse * resp,MyClosure done)678 void Master::MakeCallable(const MakeCallableRequest* req,
679 MakeCallableResponse* resp, MyClosure done) {
680 Status s = recent_request_ids_.TrackUnique(req->request_id(),
681 "MakeCallable (Master)", *req);
682 if (!s.ok()) {
683 done(s);
684 return;
685 }
686 auto session = FindMasterSession(req->session_handle());
687 if (session == nullptr) {
688 done(errors::Aborted("Session ", req->session_handle(), " is not found."));
689 return;
690 }
691
692 SchedClosure(std::bind(
693 [session, req, resp](MyClosure done) {
694 Status s = session->MakeCallable(*req, resp);
695 session->Unref();
696 done(s);
697 },
698 std::move(done)));
699 }
700
RunCallable(CallOptions * opts,const RunCallableRequest * req,RunCallableResponse * resp,MyClosure done)701 void Master::RunCallable(CallOptions* opts, const RunCallableRequest* req,
702 RunCallableResponse* resp, MyClosure done) {
703 Status s = recent_request_ids_.TrackUnique(req->request_id(),
704 "RunCallable (Master)", *req);
705 if (!s.ok()) {
706 done(s);
707 return;
708 }
709 auto session = FindMasterSession(req->session_handle());
710 if (session == nullptr) {
711 done(errors::Aborted("Session ", req->session_handle(), " is not found."));
712 return;
713 }
714
715 SchedClosure(std::bind(
716 [session, opts, req, resp](MyClosure done) {
717 Status s = session->RunCallable(opts, *req, resp);
718 session->Unref();
719 done(s);
720 },
721 std::move(done)));
722 }
723
ReleaseCallable(const ReleaseCallableRequest * req,ReleaseCallableResponse * resp,MyClosure done)724 void Master::ReleaseCallable(const ReleaseCallableRequest* req,
725 ReleaseCallableResponse* resp, MyClosure done) {
726 auto session = FindMasterSession(req->session_handle());
727 if (session == nullptr) {
728 done(errors::Aborted("Session ", req->session_handle(), " is not found."));
729 return;
730 }
731
732 SchedClosure(std::bind(
733 [session, req, resp](MyClosure done) {
734 Status s = session->ReleaseCallable(*req, resp);
735 session->Unref();
736 done(s);
737 },
738 std::move(done)));
739 }
740
741 } // end namespace tensorflow
742