1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Master implements the service MasterService.
17 //
18 // A Master maintains the state of live graph computation
19 // sessions, each session orchestrates both local and remote devices
20 // to carry out the graph computation.
21 //
22 // A Master knows ahead of time local devices available as
23 // client devices.
24 //
25 // A Master discovers remote devices on-demand and keeps track of
26 // statistics of those remote devices.
27 //
28 // Each session analyzes the graph, places nodes across available
29 // devices, and ultimately drives the graph computation by initiating
30 // RunGraph on the workers.
31
32 #include "tensorflow/core/distributed_runtime/master.h"
33
34 #include <unordered_set>
35 #include <vector>
36
37 #include "tensorflow/core/common_runtime/device_set.h"
38 #include "tensorflow/core/common_runtime/process_util.h"
39 #include "tensorflow/core/distributed_runtime/remote_device.h"
40 #include "tensorflow/core/distributed_runtime/worker_cache.h"
41 #include "tensorflow/core/distributed_runtime/worker_interface.h"
42 #include "tensorflow/core/framework/graph_def_util.h"
43 #include "tensorflow/core/lib/core/errors.h"
44 #include "tensorflow/core/lib/core/notification.h"
45 #include "tensorflow/core/lib/gtl/array_slice.h"
46 #include "tensorflow/core/lib/gtl/cleanup.h"
47 #include "tensorflow/core/lib/gtl/map_util.h"
48 #include "tensorflow/core/lib/strings/str_util.h"
49 #include "tensorflow/core/platform/macros.h"
50 #include "tensorflow/core/platform/mutex.h"
51 #include "tensorflow/core/platform/types.h"
52 #include "tensorflow/core/protobuf/cluster.pb.h"
53 #include "tensorflow/core/protobuf/master.pb.h"
54 #include "tensorflow/core/protobuf/worker.pb.h"
55 #include "tensorflow/core/public/session_options.h"
56 #include "tensorflow/core/util/device_name_utils.h"
57
58 namespace tensorflow {
59
namespace {
// URL scheme prefix that may precede the master target in a
// CreateSessionRequest; it is stripped before matching ClusterDef tasks.
constexpr char kGrpcProtocol[] = "grpc://";
}  // namespace
63
Master(MasterEnv * env,double session_gc_seconds)64 Master::Master(MasterEnv* env, double session_gc_seconds)
65 : env_(env),
66 last_1000_steps_(1000),
67 step_count_(0),
68 session_gc_seconds_(session_gc_seconds),
69 recent_request_ids_(10000) {
70 // Right now, a master service must be co-located with a device.
71 // Otherwise, fetches do not work.
72 CHECK(!env->local_devices.empty());
73
74 if (session_gc_seconds_ > 0.0) {
75 gc_thread_ = env_->env->StartThread(ThreadOptions(), "TF_master_GC",
76 [this]() { GC(); });
77 } else {
78 gc_thread_ = nullptr;
79 }
80 }
81
~Master()82 Master::~Master() {
83 if (gc_thread_) {
84 mutex_lock l(mu_);
85 shutdown_ = true;
86 shutdown_cv_.notify_all();
87 delete gc_thread_;
88 }
89 }
90
GC()91 void Master::GC() {
92 Env* env = Env::Default();
93 while (true) {
94 mutex_lock l(mu_);
95 const int kTimeoutMilliseconds = 10 * 1000; // 10 seconds.
96 WaitForMilliseconds(&l, &shutdown_cv_, kTimeoutMilliseconds);
97 if (shutdown_) {
98 break;
99 }
100 std::vector<string> handles;
101 const int64_t num_micros =
102 static_cast<int64>(session_gc_seconds_ * 1000000);
103 for (const auto& entry : sessions_) {
104 int64_t lat = entry.second->last_access_time_usec();
105 if (static_cast<int64>(env->NowMicros()) - lat > num_micros) {
106 handles.push_back(entry.first);
107 auto* sess = entry.second;
108 SchedClosure([this, sess]() {
109 LOG(WARNING) << "GC session " << sess->handle() << " after "
110 << session_gc_seconds_ << " seconds. "
111 << "Note that if you are starting multiple replicas "
112 << "on a staggered delay, session_gc_seconds may need "
113 << "to be raised.";
114 sess->GarbageCollect();
115 });
116 }
117 }
118 for (const auto& handle : handles) sessions_.erase(handle);
119 }
120 }
121
FindMasterSession(const string & handle)122 MasterSession* Master::FindMasterSession(const string& handle) {
123 MasterSession* session = nullptr;
124 {
125 mutex_lock l(mu_);
126 session = gtl::FindPtrOrNull(sessions_, handle);
127 if (session != nullptr) {
128 session->Ref();
129 }
130 }
131 return session;
132 }
133
134 class DeviceFinder {
135 public:
GetRemoteDevices(const protobuf::RepeatedPtrField<string> & device_filters,MasterEnv * env,WorkerCacheInterface * worker_cache,std::vector<std::unique_ptr<Device>> * out_remote)136 static Status GetRemoteDevices(
137 const protobuf::RepeatedPtrField<string>& device_filters, MasterEnv* env,
138 WorkerCacheInterface* worker_cache,
139 std::vector<std::unique_ptr<Device>>* out_remote) {
140 DeviceFinder finder(device_filters, env, worker_cache);
141 finder.Start();
142 TF_RETURN_IF_ERROR(finder.Wait());
143 finder.GetRemoteDevices(env->local_devices, out_remote);
144 return Status::OK();
145 }
146
GetRemoteWorkers(const protobuf::RepeatedPtrField<string> & device_filters,MasterEnv * env,WorkerCacheInterface * worker_cache,std::vector<string> * workers)147 static void GetRemoteWorkers(
148 const protobuf::RepeatedPtrField<string>& device_filters, MasterEnv* env,
149 WorkerCacheInterface* worker_cache, std::vector<string>* workers) {
150 DeviceFinder finder(device_filters, env, worker_cache);
151 *workers = finder.targets_;
152 }
153
154 private:
DeviceFinder(const protobuf::RepeatedPtrField<string> & device_filters,MasterEnv * env,WorkerCacheInterface * worker_cache)155 explicit DeviceFinder(
156 const protobuf::RepeatedPtrField<string>& device_filters, MasterEnv* env,
157 WorkerCacheInterface* worker_cache)
158 : env_(env), worker_cache_(worker_cache) {
159 CHECK(worker_cache) << "Worker cache was null!";
160 auto process_filter = [this](const string& filter) {
161 DeviceNameUtils::ParsedName parsed;
162 if (DeviceNameUtils::ParseFullName(filter, &parsed)) {
163 filters_.push_back(parsed);
164 } else {
165 LOG(FATAL) << "Skipping invalid filter: " << filter;
166 }
167 };
168 for (const string& filter : device_filters) {
169 process_filter(filter);
170 }
171 // Enumerates all known workers' target. A target name is a
172 // prefix of a device name. E.g., /job:mnist/replica:0/task:10.
173 if (filters_.empty()) {
174 // If no filters were specified, we list all known workers in
175 // `worker_cache`.
176 std::vector<string> workers;
177 worker_cache->ListWorkers(&workers);
178 std::swap(workers, targets_);
179 } else {
180 // When applying filters, we must include the local worker, even if it
181 // does not match any of the filters.
182 CHECK_GT(env_->local_devices.size(), 0) << "No local devices provided.";
183 const string& local_device_name = env_->local_devices[0]->name();
184 DeviceNameUtils::ParsedName local_parsed_name;
185 CHECK(DeviceNameUtils::ParseFullName(local_device_name,
186 &local_parsed_name));
187 bool all_filters_have_job = true;
188 std::unordered_set<string> filter_job_names({local_parsed_name.job});
189 for (const DeviceNameUtils::ParsedName& filter : filters_) {
190 all_filters_have_job = all_filters_have_job && filter.has_job;
191 if (filter.has_job) {
192 filter_job_names.insert(filter.job);
193 }
194 }
195
196 std::vector<string> workers;
197 if (all_filters_have_job) {
198 // If all of the device filters have a job specified, then we only need
199 // to list the workers in the jobs named in the filter, because a worker
200 // in any other job would not match any filter.
201 for (const string& job_name : filter_job_names) {
202 VLOG(2) << "Selectively listing workers in job: " << job_name;
203 std::vector<string> workers_in_job;
204 worker_cache->ListWorkersInJob(job_name, &workers_in_job);
205 workers.insert(workers.end(), workers_in_job.begin(),
206 workers_in_job.end());
207 }
208 } else {
209 // If any of the device filters does not have a job specified, then we
210 // must list the workers from all jobs.
211 VLOG(2) << "Listing workers in all jobs because some device "
212 << "filter has no job specified. Filters were:";
213 if (device_filters.empty()) {
214 VLOG(2) << "- <NO FILTERS>";
215 } else {
216 for (const string& filter : device_filters) {
217 VLOG(2) << "- " << filter;
218 }
219 }
220 worker_cache->ListWorkers(&workers);
221 }
222 for (const string& name : workers) {
223 if (MatchFilters(name) ||
224 DeviceNameUtils::IsSameAddressSpace(name, local_device_name)) {
225 targets_.push_back(name);
226 }
227 }
228 }
229 seen_targets_.assign(targets_.size(), false);
230 }
231
~DeviceFinder()232 ~DeviceFinder() {
233 for (Device* dev : found_) delete dev;
234 }
235
Start()236 void Start() {
237 {
238 mutex_lock l(mu_);
239 num_pending_ = targets_.size();
240 if (num_pending_ == 0) {
241 pending_zero_.notify_all();
242 }
243 }
244 // Talk to all workers to get the list of available devices.
245 using std::placeholders::_1;
246 using std::placeholders::_2;
247 for (size_t i = 0; i < targets_.size(); ++i) {
248 // TODO(mrry): Propagate a timeout here, since `this->WhenFound()` may
249 // never be called.
250 NewRemoteDevices(env_->env, worker_cache_, targets_[i],
251 std::bind(&ME::WhenFound, this, i, _1, _2));
252 }
253 }
254
255 // Every `kLoggingPeriodMs`, while the DeviceFinder is still waiting
256 // to hear from workers, log a list of the workers who have not
257 // responded.
258 const int32 kLoggingPeriodMs = 10 * 1000;
259
Wait()260 Status Wait() {
261 mutex_lock l(mu_);
262 // TODO(mrry): Propagate a timeout here, since `num_pending_` may
263 // never become zero.
264 while (num_pending_ != 0) {
265 pending_zero_.wait_for(l, std::chrono::milliseconds(kLoggingPeriodMs));
266 if (num_pending_ != 0) {
267 for (size_t i = 0; i < targets_.size(); ++i) {
268 if (!seen_targets_[i]) {
269 LOG(INFO)
270 << "CreateSession still waiting for response from worker: "
271 << targets_[i];
272 }
273 }
274 }
275 }
276 return status_;
277 }
278
279 // The caller takes the ownership of returned remote devices.
GetRemoteDevices(const std::vector<Device * > & local,std::vector<std::unique_ptr<Device>> * remote)280 void GetRemoteDevices(const std::vector<Device*>& local,
281 std::vector<std::unique_ptr<Device>>* remote) {
282 std::unordered_set<string> names(local.size());
283 for (Device* dev : local) names.insert(dev->name());
284 mutex_lock l(mu_);
285 for (Device* dev : found_) {
286 const string& name = dev->name();
287 if (names.insert(name).second && MatchFilters(name)) {
288 remote->push_back(std::unique_ptr<Device>(dev));
289 } else {
290 delete dev;
291 }
292 }
293 found_.clear();
294 }
295
296 typedef DeviceFinder ME;
297 const MasterEnv* env_;
298 WorkerCacheInterface* worker_cache_;
299 std::vector<DeviceNameUtils::ParsedName> filters_;
300
301 mutex mu_;
302 int num_pending_ TF_GUARDED_BY(mu_);
303 condition_variable pending_zero_;
304 std::vector<Device*> found_ TF_GUARDED_BY(mu_);
305 // List of targets to be contacted by this DeviceFinder. The
306 // respective `bool` in `seen_targets_` indicates whether we have
307 // heard from this target or not.
308 std::vector<string> targets_;
309 std::vector<bool> seen_targets_ TF_GUARDED_BY(mu_);
310 Status status_;
311
WhenFound(int target_index,const Status & s,std::vector<Device * > * devices)312 void WhenFound(int target_index, const Status& s,
313 std::vector<Device*>* devices) {
314 mutex_lock l(mu_);
315 seen_targets_[target_index] = true;
316 if (!s.ok()) {
317 LOG(ERROR) << "CreateSession failed because worker "
318 << targets_[target_index] << " returned error: " << s;
319 status_.Update(s);
320 } else {
321 found_.insert(found_.end(), devices->begin(), devices->end());
322 devices->clear();
323 }
324 --num_pending_;
325 if (num_pending_ == 0) {
326 pending_zero_.notify_all();
327 }
328 }
329
330 // Returns true iff the set of devices allowed by 'x' intersects
331 // with the set of devices allowed by 'y'.
Intersects(const DeviceNameUtils::ParsedName & x,const DeviceNameUtils::ParsedName & y)332 bool Intersects(const DeviceNameUtils::ParsedName& x,
333 const DeviceNameUtils::ParsedName& y) {
334 return (!x.has_job || !y.has_job || x.job == y.job) &&
335 (!x.has_replica || !y.has_replica || x.replica == y.replica) &&
336 (!x.has_task || !y.has_task || x.task == y.task) &&
337 (!x.has_type || !y.has_type || x.type == y.type) &&
338 (!x.has_id || !y.has_id || x.id == y.id);
339 }
340
341 // Returns true iff 'name' matches one of the filters_.
MatchFilters(const string & name)342 bool MatchFilters(const string& name) {
343 if (filters_.empty()) return true;
344 DeviceNameUtils::ParsedName x;
345 if (DeviceNameUtils::ParseFullName(name, &x)) {
346 for (const auto& filter : filters_) {
347 if (Intersects(x, filter)) return true;
348 }
349 }
350 return false;
351 }
352
353 TF_DISALLOW_COPY_AND_ASSIGN(DeviceFinder);
354 };
355
CreateSession(const CreateSessionRequest * req,CreateSessionResponse * resp,MyClosure done)356 void Master::CreateSession(const CreateSessionRequest* req,
357 CreateSessionResponse* resp, MyClosure done) {
358 SchedClosure([this, req, resp, done]() {
359 Status status;
360 WorkerCacheFactoryOptions worker_cache_factory_options;
361 string grpc_protocol("grpc");
362 worker_cache_factory_options.protocol = &grpc_protocol;
363 auto call_done = gtl::MakeCleanup([&status, &done] { done(status); });
364 status = ValidateExternalGraphDefSyntax(req->graph_def());
365 if (!status.ok()) return;
366
367 // The following 4 variables are set differently, depending on whether this
368 // session uses a client-provided clusterspec or not.
369 WorkerCacheInterface* worker_cache = nullptr;
370 // Note: worker_cache_ptr will be null except if this session is using a
371 // client-supplied ClusterDef (ClusterSpec propagation).
372 std::unique_ptr<WorkerCacheInterface> worker_cache_ptr;
373 std::unique_ptr<DeviceSet> device_set;
374 // TODO(saeta): Convert to std::make_unique when available.
375 std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devices(
376 new std::vector<std::unique_ptr<Device>>());
377
378 if (req->config().has_cluster_def()) {
379 worker_cache_factory_options.cluster_def = &req->config().cluster_def();
380
381 // Set the server_def's job_name and task_index fields.
382 string normalized_string;
383 string grpc_protocol(kGrpcProtocol);
384 if (req->target().compare(0, grpc_protocol.length(), grpc_protocol) ==
385 0) {
386 normalized_string =
387 req->target().substr(grpc_protocol.length(), string::npos);
388 } else {
389 normalized_string = req->target();
390 }
391 for (auto&& job : req->config().cluster_def().job()) {
392 for (auto&& task : job.tasks()) {
393 if (task.second == normalized_string) {
394 if (worker_cache_factory_options.job_name != nullptr) {
395 status = errors::InvalidArgument(
396 "Found multiple matching tasks that correspond to "
397 "to the master. Master target: '",
398 req->target(), "'. ClusterDef: ",
399 req->config().cluster_def().ShortDebugString());
400 LOG(ERROR) << status;
401 return;
402 }
403 if (env_->local_devices[0]->parsed_name().job == job.name() &&
404 env_->local_devices[0]->parsed_name().task == task.first) {
405 // TODO(b/37868888): Remove this limitation when resolved
406 status = errors::InvalidArgument(
407 "The ClusterSpec names the job and task index to be the same "
408 "names that were provided when the server booted. This is "
409 "currently not allowed. Job: ",
410 job.name(), ", task index: ", task.first);
411 return;
412 }
413 worker_cache_factory_options.job_name = &job.name();
414 worker_cache_factory_options.task_index = task.first;
415 }
416 }
417 }
418 worker_cache_factory_options.rpc_options = &req->config().rpc_options();
419 // Create the worker cache from the computed server_def.
420 status = env_->worker_cache_factory(worker_cache_factory_options,
421 &worker_cache);
422 if (!status.ok()) return;
423 worker_cache_ptr = std::unique_ptr<WorkerCacheInterface>(worker_cache);
424 // Ping all the workers and build the list of devices that the
425 // session will use.
426 status =
427 DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_,
428 worker_cache, remote_devices.get());
429 if (!status.ok()) return;
430 device_set.reset(new DeviceSet);
431 for (auto&& d : *remote_devices) {
432 device_set->AddDevice(d.get());
433 DeviceNameUtils::ParsedName name = d->parsed_name();
434 if (name.job == *worker_cache_factory_options.job_name &&
435 name.task == worker_cache_factory_options.task_index &&
436 name.type == "CPU" && name.id == 0) {
437 device_set->set_client_device(d.get());
438 }
439 }
440 } else {
441 worker_cache = env_->worker_cache;
442 // Ping all the workers and build the list of devices that the
443 // session will use.
444 status =
445 DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_,
446 worker_cache, remote_devices.get());
447 if (!status.ok()) return;
448 device_set.reset(new DeviceSet);
449 for (auto&& d : *remote_devices) {
450 device_set->AddDevice(d.get());
451 }
452 int num_local_devices = 0;
453 for (Device* d : env_->local_devices) {
454 device_set->AddDevice(d);
455 if (num_local_devices == 0) {
456 // Uses the first local device as the client device.
457 device_set->set_client_device(d);
458 }
459 num_local_devices++;
460 }
461 }
462
463 CHECK(device_set->client_device()) << "No client device found. Missing "
464 << "CPU:0 device?";
465
466 SessionOptions options;
467 options.config = req->config();
468
469 std::vector<string> filtered_worker_list;
470 DeviceFinder::GetRemoteWorkers(req->config().device_filters(), env_,
471 worker_cache, &filtered_worker_list);
472
473 MasterSession* session = env_->master_session_factory(
474 options, env_, std::move(remote_devices), std::move(worker_cache_ptr),
475 std::move(device_set), std::move(filtered_worker_list));
476
477 GraphDef* gdef =
478 const_cast<CreateSessionRequest*>(req)->mutable_graph_def();
479
480 status = session->Create(std::move(*gdef), worker_cache_factory_options);
481 if (!status.ok()) {
482 session->Close().IgnoreError();
483 session->Unref();
484 return;
485 }
486 resp->set_session_handle(session->handle());
487 // Insert into the session map, which takes ownership of the session.
488 {
489 mutex_lock l(mu_);
490 CHECK(sessions_.insert({session->handle(), session}).second);
491 }
492 });
493 }
494
ExtendSession(const ExtendSessionRequest * req,ExtendSessionResponse * resp,MyClosure done)495 void Master::ExtendSession(const ExtendSessionRequest* req,
496 ExtendSessionResponse* resp, MyClosure done) {
497 auto session = FindMasterSession(req->session_handle());
498 if (session == nullptr) {
499 done(errors::Aborted("Session ", req->session_handle(), " is not found."));
500 return;
501 }
502
503 SchedClosure([session, req, resp, done]() {
504 Status status = ValidateExternalGraphDefSyntax(req->graph_def());
505 if (status.ok()) {
506 status = session->Extend(req, resp);
507 }
508 session->Unref();
509 done(status);
510 });
511 }
512
PartialRunSetup(const PartialRunSetupRequest * req,PartialRunSetupResponse * resp,MyClosure done)513 void Master::PartialRunSetup(const PartialRunSetupRequest* req,
514 PartialRunSetupResponse* resp, MyClosure done) {
515 Status s = recent_request_ids_.TrackUnique(req->request_id(),
516 "PartialRunSetup (Master)", *req);
517 if (!s.ok()) {
518 done(s);
519 return;
520 }
521 auto session = FindMasterSession(req->session_handle());
522 if (session == nullptr) {
523 done(errors::Aborted("Session ", req->session_handle(), " is not found."));
524 return;
525 }
526
527 SchedClosure([session, req, resp, done]() {
528 Status s = session->PartialRunSetup(req, resp);
529 session->Unref();
530 done(s);
531 });
532 }
533
RunStep(CallOptions * opts,const RunStepRequestWrapper * req,MutableRunStepResponseWrapper * resp,MyClosure done)534 void Master::RunStep(CallOptions* opts, const RunStepRequestWrapper* req,
535 MutableRunStepResponseWrapper* resp, MyClosure done) {
536 Status s = recent_request_ids_.TrackUnique(req->request_id(),
537 "RunStep (Master)", req);
538 if (!s.ok()) {
539 done(s);
540 return;
541 }
542 auto start_time = env_->env->NowMicros();
543 auto session = FindMasterSession(req->session_handle());
544 if (session == nullptr) {
545 done(errors::Aborted("Session ", req->session_handle(), " is not found."));
546 return;
547 }
548
549 SchedClosure([this, start_time, session, opts, req, resp, done]() {
550 Status status = session->Run(opts, *req, resp);
551 session->Unref();
552 uint64 done_time = env_->env->NowMicros();
553 done(status);
554 mutex_lock l(mu_);
555 last_1000_steps_.AddValue((done_time - start_time) / 1e9);
556 ++step_count_;
557 });
558 }
559
CloseSession(const CloseSessionRequest * req,CloseSessionResponse * resp,MyClosure done)560 void Master::CloseSession(const CloseSessionRequest* req,
561 CloseSessionResponse* resp, MyClosure done) {
562 MasterSession* session = nullptr;
563 {
564 mu_.lock();
565 auto iter = sessions_.find(req->session_handle());
566 if (iter == sessions_.end()) {
567 mu_.unlock();
568 done(errors::Aborted(
569 "Session ", req->session_handle(),
570 " is not found. Possibly, this master has restarted."));
571 return;
572 }
573 // NOTE(mrry): One reference to the session is transferred from
574 // `sessions_[req->session_handle()]` to `session`.
575 session = iter->second;
576 sessions_.erase(iter);
577 mu_.unlock();
578 }
579
580 // Session Close() blocks on thread shutdown. Therefore, we need to
581 // delete it in non-critical thread.
582 SchedClosure([session, done]() {
583 Status s = session->Close();
584 session->Unref();
585 done(s);
586 });
587 }
588
ListDevices(const ListDevicesRequest * req,ListDevicesResponse * resp,MyClosure done)589 void Master::ListDevices(const ListDevicesRequest* req,
590 ListDevicesResponse* resp, MyClosure done) {
591 SchedClosure([this, req, resp, done]() {
592 if (!req->session_handle().empty()) {
593 auto session = FindMasterSession(req->session_handle());
594 if (session == nullptr) {
595 done(errors::InvalidArgument(
596 "Session ", req->session_handle(),
597 " is not found. Possibly, this master has restarted."));
598 return;
599 }
600 core::ScopedUnref ref(session);
601 Status s = session->ListDevices(resp);
602 done(s);
603 return;
604 }
605 std::vector<std::unique_ptr<Device>> remote_devices;
606 Status s = DeviceFinder::GetRemoteDevices({}, env_, env_->worker_cache,
607 &remote_devices);
608 if (s.ok()) {
609 for (Device* dev : env_->local_devices) {
610 *(resp->add_local_device()) = dev->attributes();
611 }
612 for (auto&& dev : remote_devices) {
613 *(resp->add_remote_device()) = dev->attributes();
614 }
615 }
616 done(s);
617 });
618 }
619
CleanupWorkers(const ResetRequest & reset)620 void Master::CleanupWorkers(const ResetRequest& reset) {
621 std::vector<string> worker_names;
622 DeviceFinder::GetRemoteWorkers(reset.device_filters(), env_,
623 env_->worker_cache, &worker_names);
624 if (!worker_names.empty()) {
625 const int num_workers = worker_names.size();
626 std::vector<Notification> n(num_workers);
627 CleanupAllRequest req;
628 (*req.mutable_container()) = reset.container();
629 std::vector<CleanupAllResponse> resp(num_workers);
630 int c = 0;
631 for (int i = 0; i < num_workers; ++i) {
632 const string& worker_name = worker_names[i];
633 auto worker = env_->worker_cache->GetOrCreateWorker(worker_name);
634 if (worker) {
635 worker->CleanupAllAsync(
636 &req, &resp[i], [this, &n, worker_name, worker, c](Status s) {
637 TF_CHECK_OK(s);
638 env_->worker_cache->ReleaseWorker(worker_name, worker);
639 n[c].Notify();
640 });
641 } else {
642 n[c].Notify();
643 }
644 ++c;
645 }
646 for (size_t i = 0; i < n.size(); ++i) {
647 n[i].WaitForNotification();
648 }
649 }
650 }
651
Reset(const ResetRequest * req,ResetResponse * resp,MyClosure done)652 void Master::Reset(const ResetRequest* req, ResetResponse* resp,
653 MyClosure done) {
654 // Vector to hold the session pointers present in the sessions_
655 // (string->Session*) map.
656 std::vector<MasterSession*> sessions_to_close;
657 {
658 mutex_lock l(mu_);
659 // NOTE(mrry): Transfer one reference to each session from the
660 // `sessions_` map to the `sessions_to_close` vector.
661 for (const auto& entry : sessions_) {
662 sessions_to_close.push_back(entry.second);
663 }
664 sessions_.clear();
665 }
666
667 CleanupWorkers(*req);
668
669 SchedClosure([sessions_to_close, done]() {
670 Status s;
671 for (MasterSession* session : sessions_to_close) {
672 s.Update(session->Close());
673 session->Unref();
674 }
675 done(s);
676 });
677 }
678
MakeCallable(const MakeCallableRequest * req,MakeCallableResponse * resp,MyClosure done)679 void Master::MakeCallable(const MakeCallableRequest* req,
680 MakeCallableResponse* resp, MyClosure done) {
681 Status s = recent_request_ids_.TrackUnique(req->request_id(),
682 "MakeCallable (Master)", *req);
683 if (!s.ok()) {
684 done(s);
685 return;
686 }
687 auto session = FindMasterSession(req->session_handle());
688 if (session == nullptr) {
689 done(errors::Aborted("Session ", req->session_handle(), " is not found."));
690 return;
691 }
692
693 SchedClosure([session, req, resp, done = std::move(done)]() {
694 Status s = session->MakeCallable(*req, resp);
695 session->Unref();
696 done(s);
697 });
698 }
699
RunCallable(CallOptions * opts,const RunCallableRequest * req,RunCallableResponse * resp,MyClosure done)700 void Master::RunCallable(CallOptions* opts, const RunCallableRequest* req,
701 RunCallableResponse* resp, MyClosure done) {
702 Status s = recent_request_ids_.TrackUnique(req->request_id(),
703 "RunCallable (Master)", *req);
704 if (!s.ok()) {
705 done(s);
706 return;
707 }
708 auto session = FindMasterSession(req->session_handle());
709 if (session == nullptr) {
710 done(errors::Aborted("Session ", req->session_handle(), " is not found."));
711 return;
712 }
713
714 SchedClosure([session, opts, req, resp, done = std::move(done)]() {
715 Status s = session->RunCallable(opts, *req, resp);
716 session->Unref();
717 done(s);
718 });
719 }
720
ReleaseCallable(const ReleaseCallableRequest * req,ReleaseCallableResponse * resp,MyClosure done)721 void Master::ReleaseCallable(const ReleaseCallableRequest* req,
722 ReleaseCallableResponse* resp, MyClosure done) {
723 auto session = FindMasterSession(req->session_handle());
724 if (session == nullptr) {
725 done(errors::Aborted("Session ", req->session_handle(), " is not found."));
726 return;
727 }
728
729 SchedClosure([session, req, resp, done = std::move(done)]() {
730 Status s = session->ReleaseCallable(*req, resp);
731 session->Unref();
732 done(s);
733 });
734 }
735
736 } // end namespace tensorflow
737