1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/master_session.h"
17
18 #include <memory>
19 #include <unordered_map>
20 #include <unordered_set>
21 #include <vector>
22
23 #include "tensorflow/core/common_runtime/process_util.h"
24 #include "tensorflow/core/common_runtime/profile_handler.h"
25 #include "tensorflow/core/common_runtime/stats_publisher_interface.h"
26 #include "tensorflow/core/debug/debug_graph_utils.h"
27 #include "tensorflow/core/distributed_runtime/request_id.h"
28 #include "tensorflow/core/distributed_runtime/scheduler.h"
29 #include "tensorflow/core/distributed_runtime/worker_cache.h"
30 #include "tensorflow/core/distributed_runtime/worker_interface.h"
31 #include "tensorflow/core/framework/allocation_description.pb.h"
32 #include "tensorflow/core/framework/collective.h"
33 #include "tensorflow/core/framework/cost_graph.pb.h"
34 #include "tensorflow/core/framework/graph_def_util.h"
35 #include "tensorflow/core/framework/node_def.pb.h"
36 #include "tensorflow/core/framework/node_def_util.h"
37 #include "tensorflow/core/framework/tensor.h"
38 #include "tensorflow/core/framework/tensor.pb.h"
39 #include "tensorflow/core/framework/tensor_description.pb.h"
40 #include "tensorflow/core/graph/graph_partition.h"
41 #include "tensorflow/core/graph/tensor_id.h"
42 #include "tensorflow/core/lib/core/blocking_counter.h"
43 #include "tensorflow/core/lib/core/notification.h"
44 #include "tensorflow/core/lib/core/refcount.h"
45 #include "tensorflow/core/lib/core/status.h"
46 #include "tensorflow/core/lib/gtl/cleanup.h"
47 #include "tensorflow/core/lib/gtl/inlined_vector.h"
48 #include "tensorflow/core/lib/gtl/map_util.h"
49 #include "tensorflow/core/lib/random/random.h"
50 #include "tensorflow/core/lib/strings/numbers.h"
51 #include "tensorflow/core/lib/strings/str_util.h"
52 #include "tensorflow/core/lib/strings/strcat.h"
53 #include "tensorflow/core/lib/strings/stringprintf.h"
54 #include "tensorflow/core/platform/env.h"
55 #include "tensorflow/core/platform/logging.h"
56 #include "tensorflow/core/platform/macros.h"
57 #include "tensorflow/core/platform/mutex.h"
58 #include "tensorflow/core/platform/tracing.h"
59 #include "tensorflow/core/public/session_options.h"
60 #include "tensorflow/core/util/device_name_utils.h"
61
62 namespace tensorflow {
63
64 // MasterSession wraps ClientGraph in a reference counted object.
65 // This way, MasterSession can clear up the cache mapping Run requests to
66 // compiled graphs while the compiled graph is still being used.
67 //
68 // TODO(zhifengc): Cleanup this class. It's becoming messy.
69 class MasterSession::ReffedClientGraph : public core::RefCounted {
70 public:
71 ReffedClientGraph(const string& handle, const BuildGraphOptions& bopts,
72 std::unique_ptr<ClientGraph> client_graph,
73 const SessionOptions& session_opts,
74 const StatsPublisherFactory& stats_publisher_factory,
75 bool is_partial, WorkerCacheInterface* worker_cache,
76 bool should_deregister)
77 : session_handle_(handle),
78 bg_opts_(bopts),
79 client_graph_before_register_(std::move(client_graph)),
80 session_opts_(session_opts),
81 is_partial_(is_partial),
82 callable_opts_(bopts.callable_options),
83 worker_cache_(worker_cache),
84 should_deregister_(should_deregister),
85 collective_graph_key_(
86 client_graph_before_register_->collective_graph_key) {
87 VLOG(1) << "Created ReffedClientGraph for node with "
88 << client_graph_before_register_->graph.num_node_ids();
89
90 stats_publisher_ = stats_publisher_factory(handle, bopts, session_opts);
91
92 // Initialize a name to node map for processing device stats.
93 for (Node* n : client_graph_before_register_->graph.nodes()) {
94 name_to_node_details_.emplace(
95 n->name(),
96 NodeDetails(n->type_string(),
97 strings::StrCat(
98 "(", absl::StrJoin(n->requested_inputs(), ", "))));
99 }
100 }
101
102 ~ReffedClientGraph() override {
103 if (should_deregister_) {
104 DeregisterPartitions();
105 } else {
106 for (Part& part : partitions_) {
107 worker_cache_->ReleaseWorker(part.name, part.worker);
108 }
109 }
110 }
111
112 const CallableOptions& callable_options() { return callable_opts_; }
113
114 const BuildGraphOptions& build_graph_options() { return bg_opts_; }
115
116 int64 collective_graph_key() { return collective_graph_key_; }
117
118 std::unique_ptr<ProfileHandler> GetProfileHandler(uint64 step,
119 int64 execution_count,
120 const RunOptions& ropts) {
121 return stats_publisher_->GetProfileHandler(step, execution_count, ropts);
122 }
123
124 int64 get_and_increment_execution_count() {
125 return execution_count_.fetch_add(1);
126 }
127
128 // Turn RPC logging on or off, both at the WorkerCache used by this
129 // master process, and at each remote worker in use for the current
130 // partitions.
131 void SetRPCLogging(bool active) {
132 worker_cache_->SetLogging(active);
133 // Logging is a best-effort activity, so we make async calls to turn
134 // it on/off and don't make use of the responses.
135 for (auto& p : partitions_) {
136 LoggingRequest* req = new LoggingRequest;
137 if (active) {
138 req->set_enable_rpc_logging(true);
139 } else {
140 req->set_disable_rpc_logging(true);
141 }
142 LoggingResponse* resp = new LoggingResponse;
143 Ref();
144 p.worker->LoggingAsync(req, resp, [this, req, resp](const Status& s) {
145 delete req;
146 delete resp;
147 // ReffedClientGraph owns p.worker so we need to hold a ref to
148 // ensure that the method doesn't attempt to access p.worker after
149 // ReffedClientGraph has deleted it.
150 // TODO(suharshs): Simplify this ownership model.
151 Unref();
152 });
153 }
154 }
155
156 // Retrieve all RPC logs data accumulated for the current step, both
157 // from the local WorkerCache in use by this master process and from
158 // all the remote workers executing the remote partitions.
159 void RetrieveLogs(int64 step_id, StepStats* ss) {
160 // Get the local data first, because it sets *ss without merging.
161 worker_cache_->RetrieveLogs(step_id, ss);
162
163 // Then merge in data from all the remote workers.
164 LoggingRequest req;
165 req.add_fetch_step_id(step_id);
166 int waiting_for = partitions_.size();
167 if (waiting_for > 0) {
168 mutex scoped_mu;
169 BlockingCounter all_done(waiting_for);
170 for (auto& p : partitions_) {
171 LoggingResponse* resp = new LoggingResponse;
172 p.worker->LoggingAsync(
173 &req, resp,
174 [step_id, ss, resp, &scoped_mu, &all_done](const Status& s) {
175 {
176 mutex_lock l(scoped_mu);
177 if (s.ok()) {
178 for (auto& lss : resp->step()) {
179 if (step_id != lss.step_id()) {
180 LOG(ERROR) << "Wrong step_id in LoggingResponse";
181 continue;
182 }
183 ss->MergeFrom(lss.step_stats());
184 }
185 }
186 delete resp;
187 }
188 // Must not decrement all_done until out of critical section where
189 // *ss is updated.
190 all_done.DecrementCount();
191 });
192 }
193 all_done.Wait();
194 }
195 }
196
197 // Local execution methods.
198
199 // Partitions the graph into subgraphs and registers them on
200 // workers.
201 Status RegisterPartitions(PartitionOptions popts);
202
203 // Runs one step of all partitions.
204 Status RunPartitions(const MasterEnv* env, int64 step_id,
205 int64 execution_count, PerStepState* pss,
206 CallOptions* opts, const RunStepRequestWrapper& req,
207 MutableRunStepResponseWrapper* resp,
208 CancellationManager* cm, const bool is_last_partial_run);
209 Status RunPartitions(const MasterEnv* env, int64 step_id,
210 int64 execution_count, PerStepState* pss,
211 CallOptions* call_opts, const RunCallableRequest& req,
212 RunCallableResponse* resp, CancellationManager* cm);
213
214 // Calls workers to clean up state for the step "step_id". Calls
215 // `done` when all cleanup RPCs have completed.
216 void CleanupPartitionsAsync(int64 step_id, StatusCallback done);
217
218 // Post-processing of any runtime statistics gathered during execution.
219 void ProcessStats(int64 step_id, PerStepState* pss, ProfileHandler* ph,
220 const RunOptions& options, RunMetadata* resp);
221 void ProcessDeviceStats(ProfileHandler* ph, const DeviceStepStats& ds,
222 bool is_rpc);
223 // Checks that the requested fetches can be computed from the provided feeds.
224 Status CheckFetches(const RunStepRequestWrapper& req,
225 const RunState* run_state,
226 GraphExecutionState* execution_state);
227
228 private:
229 const string session_handle_;
230 const BuildGraphOptions bg_opts_;
231
232 // NOTE(mrry): This pointer will be null after `RegisterPartitions()` returns.
233 std::unique_ptr<ClientGraph> client_graph_before_register_ TF_GUARDED_BY(mu_);
234 const SessionOptions session_opts_;
235 const bool is_partial_;
236 const CallableOptions callable_opts_;
237 WorkerCacheInterface* const worker_cache_; // Not owned.
238
239 struct NodeDetails {
240 explicit NodeDetails(string type_string, string detail_text)
241 : type_string(std::move(type_string)),
242 detail_text(std::move(detail_text)) {}
243 const string type_string;
244 const string detail_text;
245 };
246 std::unordered_map<string, NodeDetails> name_to_node_details_;
247
248 const bool should_deregister_;
249 const int64 collective_graph_key_;
250 std::atomic<int64> execution_count_ = {0};
251
252 // Graph partitioned into per-location subgraphs.
253 struct Part {
254 // Worker name.
255 string name;
256
257 // Maps feed names to rendezvous keys. Empty most of the time.
258 std::unordered_map<string, string> feed_key;
259
260 // Maps rendezvous keys to fetch names. Empty most of the time.
261 std::unordered_map<string, string> key_fetch;
262
263 // The interface to the worker. Owned.
264 WorkerInterface* worker = nullptr;
265
266 // After registration with the worker, graph_handle identifies
267 // this partition on the worker.
268 string graph_handle;
269
270 Part() : feed_key(3), key_fetch(3) {}
271 };
272
273 // partitions_ is immutable after RegisterPartitions() call
274 // finishes. RunPartitions() can access partitions_ safely without
275 // acquiring locks.
276 std::vector<Part> partitions_;
277
278 mutable mutex mu_;
279
280 // Partition initialization and registration only needs to happen
281 // once. `!client_graph_before_register_ && !init_done_.HasBeenNotified()`
282 // indicates the initialization is ongoing.
283 Notification init_done_;
284
285 // init_result_ remembers the initialization error if any.
286 Status init_result_ TF_GUARDED_BY(mu_);
287
288 std::unique_ptr<StatsPublisherInterface> stats_publisher_;
289
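// Formats a one-line description of a node's execution for profiling output,
// prefixed with the total requested output bytes (in MB) when they exceed
// roughly 0.1 MB.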
290 string DetailText(const NodeDetails& details, const NodeExecStats& stats) {
291 int64 tot = 0;
292 for (auto& no : stats.output()) {
293 tot += no.tensor_description().allocation_description().requested_bytes();
294 }
295 string bytes;
296 if (tot >= 0.1 * 1048576.0) {
297 bytes = strings::Printf("[%.1fMB] ", tot / 1048576.0);
298 }
299 return strings::StrCat(bytes, stats.node_name(), " = ", details.type_string,
300 details.detail_text);
301 }
302
303 // Send/Recv nodes that are the result of client-added
304 // feeds and fetches must be tracked so that the tensors
305 // can be added to the local rendezvous.
306 static void TrackFeedsAndFetches(Part* part, const GraphDef& graph_def,
307 const PartitionOptions& popts);
308
309 // The actual graph partitioning and registration implementation.
310 Status DoBuildPartitions(
311 PartitionOptions popts, ClientGraph* client_graph,
312 std::unordered_map<string, GraphDef>* out_partitions);
313 Status DoRegisterPartitions(
314 const PartitionOptions& popts,
315 std::unordered_map<string, GraphDef> graph_partitions);
316
317 // Prepares a number of calls to workers. One call per partition.
318 // This is a generic method that handles Run, PartialRun, and RunCallable.
319 template <class FetchListType, class ClientRequestType,
320 class ClientResponseType>
321 Status RunPartitionsHelper(
322 const std::unordered_map<StringPiece, size_t, StringPieceHasher>& feeds,
323 const FetchListType& fetches, const MasterEnv* env, int64 step_id,
324 int64 execution_count, PerStepState* pss, CallOptions* call_opts,
325 const ClientRequestType& req, ClientResponseType* resp,
326 CancellationManager* cm, bool is_last_partial_run);
327
328 // Deregisters the partitions on the workers. Called in the
329 // destructor and does not wait for the rpc completion.
330 void DeregisterPartitions();
331
332 TF_DISALLOW_COPY_AND_ASSIGN(ReffedClientGraph);
333 };
334
335 Status MasterSession::ReffedClientGraph::RegisterPartitions(
336 PartitionOptions popts) {
337 { // Ensure registration happens only once.
338 mu_.lock();
339 if (client_graph_before_register_) {
340 // The `ClientGraph` is no longer needed after partitions are registered.
341 // Since it can account for a large amount of memory, we take ownership of
342 // it here so that it is freed as soon as registration concludes.
343
344 std::unique_ptr<ClientGraph> client_graph;
345 std::swap(client_graph_before_register_, client_graph);
346 mu_.unlock();
347 std::unordered_map<string, GraphDef> graph_defs;
348 popts.flib_def = client_graph->flib_def.get();
349 Status s = DoBuildPartitions(popts, client_graph.get(), &graph_defs);
350 if (s.ok()) {
351 // NOTE(mrry): The pointers in `graph_defs_for_publishing` do not remain
352 // valid after the call to DoRegisterPartitions begins, so
353 // `stats_publisher_` must make a copy if it wants to retain the
354 // GraphDef objects.
355 std::vector<const GraphDef*> graph_defs_for_publishing;
356 graph_defs_for_publishing.reserve(partitions_.size());
357 for (const auto& name_def : graph_defs) {
358 graph_defs_for_publishing.push_back(&name_def.second);
359 }
360 stats_publisher_->PublishGraphProto(graph_defs_for_publishing);
361 s = DoRegisterPartitions(popts, std::move(graph_defs));
362 }
363 mu_.lock();
364 init_result_ = s;
365 init_done_.Notify();
366 } else {
367 mu_.unlock();
368 init_done_.WaitForNotification();
369 mu_.lock();
370 }
371 const Status result = init_result_;
372 mu_.unlock();
373 return result;
374 }
375 }
376
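// Returns the task portion (e.g. "/job:worker/replica:0/task:1") of a node's
// assigned device name; CHECK-fails if the device name cannot be split.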
377 static string SplitByWorker(const Node* node) {
378 string task;
379 string device;
380 CHECK(DeviceNameUtils::SplitDeviceName(node->assigned_device_name(), &task,
381 &device))
382 << "node: " << node->name() << " dev: " << node->assigned_device_name();
383 return task;
384 }
385
386 void MasterSession::ReffedClientGraph::TrackFeedsAndFetches(
387 Part* part, const GraphDef& graph_def, const PartitionOptions& popts) {
388 for (int i = 0; i < graph_def.node_size(); ++i) {
389 const NodeDef& ndef = graph_def.node(i);
390 const bool is_recv = ndef.op() == "_Recv";
391 const bool is_send = ndef.op() == "_Send";
392
393 if (is_recv || is_send) {
394 // Only send/recv nodes that were added as feeds and fetches
395 // (client-terminated) should be tracked. Other send/recv nodes
396 // are for transferring data between partitions / memory spaces.
397 bool client_terminated;
398 TF_CHECK_OK(GetNodeAttr(ndef, "client_terminated", &client_terminated));
399 if (client_terminated) {
400 string name;
401 TF_CHECK_OK(GetNodeAttr(ndef, "tensor_name", &name));
402 string send_device;
403 TF_CHECK_OK(GetNodeAttr(ndef, "send_device", &send_device));
404 string recv_device;
405 TF_CHECK_OK(GetNodeAttr(ndef, "recv_device", &recv_device));
406 uint64 send_device_incarnation;
407 TF_CHECK_OK(
408 GetNodeAttr(ndef, "send_device_incarnation",
409 reinterpret_cast<int64*>(&send_device_incarnation)));
410 const string& key =
411 Rendezvous::CreateKey(send_device, send_device_incarnation,
412 recv_device, name, FrameAndIter(0, 0));
413
414 if (is_recv) {
415 part->feed_key.insert({name, key});
416 } else {
417 part->key_fetch.insert({key, name});
418 }
419 }
420 }
421 }
422 }
423
424 Status MasterSession::ReffedClientGraph::DoBuildPartitions(
425 PartitionOptions popts, ClientGraph* client_graph,
426 std::unordered_map<string, GraphDef>* out_partitions) {
427 if (popts.need_to_record_start_times) {
428 CostModel cost_model(true);
429 cost_model.InitFromGraph(client_graph->graph);
430 // TODO(yuanbyu): Use the real cost model.
431 // execution_state_->MergeFromGlobal(&cost_model);
432 SlackAnalysis sa(&client_graph->graph, &cost_model);
433 sa.ComputeAsap(&popts.start_times);
434 }
435
436 // Partition the graph.
437 return Partition(popts, &client_graph->graph, out_partitions);
438 }
439
440 Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
441 const PartitionOptions& popts,
442 std::unordered_map<string, GraphDef> graph_partitions) {
443 partitions_.reserve(graph_partitions.size());
444 Status s;
445 for (auto& name_def : graph_partitions) {
446 partitions_.emplace_back();
447 Part* part = &partitions_.back();
448 part->name = name_def.first;
449 TrackFeedsAndFetches(part, name_def.second, popts);
450 part->worker = worker_cache_->GetOrCreateWorker(part->name);
451 if (part->worker == nullptr) {
452 s = errors::NotFound("worker ", part->name);
453 break;
454 }
455 }
456 if (!s.ok()) {
457 for (Part& part : partitions_) {
458 worker_cache_->ReleaseWorker(part.name, part.worker);
459 part.worker = nullptr;
460 }
461 return s;
462 }
463 struct Call {
464 RegisterGraphRequest req;
465 RegisterGraphResponse resp;
466 Status status;
467 };
468 const int num = partitions_.size();
469 gtl::InlinedVector<Call, 4> calls(num);
470 BlockingCounter done(num);
471 for (int i = 0; i < num; ++i) {
472 const Part& part = partitions_[i];
473 Call* c = &calls[i];
474 c->req.set_session_handle(session_handle_);
475 c->req.set_create_worker_session_called(!should_deregister_);
476 c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]);
477 StripDefaultAttributes(*OpRegistry::Global(),
478 c->req.mutable_graph_def()->mutable_node());
479 *c->req.mutable_config_proto() = session_opts_.config;
480 *c->req.mutable_graph_options() = session_opts_.config.graph_options();
481 *c->req.mutable_debug_options() =
482 callable_opts_.run_options().debug_options();
483 c->req.set_collective_graph_key(collective_graph_key_);
484 VLOG(2) << "Register " << c->req.graph_def().DebugString();
485 auto cb = [c, &done](const Status& s) {
486 c->status = s;
487 done.DecrementCount();
488 };
489 part.worker->RegisterGraphAsync(&c->req, &c->resp, cb);
490 }
491 done.Wait();
492 for (int i = 0; i < num; ++i) {
493 Call* c = &calls[i];
494 s.Update(c->status);
495 partitions_[i].graph_handle = c->resp.graph_handle();
496 }
497 return s;
498 }
499
500 namespace {
501 // Helper class to manage "num" parallel RunGraph calls.
502 class RunManyGraphs {
503 public:
504 explicit RunManyGraphs(int num) : calls_(num), pending_(num) {}
505
506 ~RunManyGraphs() {}
507
508 // Returns the index-th call.
509 struct Call {
510 CallOptions opts;
511 const string* worker_name;
512 std::atomic<bool> done{false};
513 std::unique_ptr<MutableRunGraphRequestWrapper> req;
514 std::unique_ptr<MutableRunGraphResponseWrapper> resp;
515 };
516 Call* get(int index) { return &calls_[index]; }
517
518 // When the index-th call is done, updates the overall status.
519 void WhenDone(int index, const Status& s) {
520 TRACEPRINTF("Partition %d %s", index, s.ToString().c_str());
521 Call* call = get(index);
522 call->done = true;
523 auto resp = call->resp.get();
524 if (resp->status_code() != error::Code::OK) {
525 // resp->status_code will only be non-OK if s.ok().
526 mutex_lock l(mu_);
527 ReportBadStatus(Status(resp->status_code(),
528 strings::StrCat("From ", *call->worker_name, ":\n",
529 resp->status_error_message())));
530 } else if (!s.ok()) {
531 mutex_lock l(mu_);
532 ReportBadStatus(
533 Status(s.code(), strings::StrCat("From ", *call->worker_name, ":\n",
534 s.error_message())));
535 }
536 pending_.DecrementCount();
537 }
538
539 void StartCancel() {
540 mutex_lock l(mu_);
541 ReportBadStatus(errors::Cancelled("RunManyGraphs"));
542 }
543
544 void Wait() {
545 // Check the error status every 60 seconds in order to print a log message
546 // in the event of a hang.
547 const std::chrono::milliseconds kCheckErrorPeriod(1000 * 60);
548 while (true) {
549 if (pending_.WaitFor(kCheckErrorPeriod)) {
550 return;
551 }
552 if (!status().ok()) {
553 break;
554 }
555 }
556
557 // The step has failed. Wait for another 60 seconds before diagnosing a
558 // hang.
559 DCHECK(!status().ok());
560 if (pending_.WaitFor(kCheckErrorPeriod)) {
561 return;
562 }
563 LOG(ERROR)
564 << "RunStep still blocked after 60 seconds. Failed with error status: "
565 << status();
566 for (const Call& call : calls_) {
567 if (!call.done) {
568 LOG(ERROR) << "- No response from RunGraph call to worker: "
569 << *call.worker_name;
570 }
571 }
572 pending_.Wait();
573 }
574
575 Status status() const {
576 mutex_lock l(mu_);
577 // Concatenate the status objects in this StatusGroup to get the aggregated
578 // status, as each status in status_group_ is already a summarized status.
579 return status_group_.as_concatenated_status();
580 }
581
582 private:
583 gtl::InlinedVector<Call, 4> calls_;
584
585 BlockingCounter pending_;
586 mutable mutex mu_;
587 StatusGroup status_group_ TF_GUARDED_BY(mu_);
588 bool cancel_issued_ TF_GUARDED_BY(mu_) = false;
589
590 void ReportBadStatus(const Status& s) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
591 VLOG(1) << "Master received error status " << s;
592 if (!cancel_issued_ && !StatusGroup::IsDerived(s)) {
593 // Only start cancelling other workers upon receiving a non-derived
594 // error.
595 cancel_issued_ = true;
596
597 VLOG(1) << "Master received error report. Cancelling remaining workers.";
598 for (Call& call : calls_) {
599 call.opts.StartCancel();
600 }
601 }
602
603 status_group_.Update(s);
604 }
605
606 TF_DISALLOW_COPY_AND_ASSIGN(RunManyGraphs);
607 };
608
609 Status AddSendFromClientRequest(const RunStepRequestWrapper& client_req,
610 MutableRunGraphRequestWrapper* worker_req,
611 size_t index, const string& send_key) {
612 return worker_req->AddSendFromRunStepRequest(client_req, index, send_key);
613 }
614
615 Status AddSendFromClientRequest(const RunCallableRequest& client_req,
616 MutableRunGraphRequestWrapper* worker_req,
617 size_t index, const string& send_key) {
618 return worker_req->AddSendFromRunCallableRequest(client_req, index, send_key);
619 }
620
621 // TODO(mrry): Add a full-fledged wrapper that avoids TensorProto copies for
622 // in-process messages.
623 struct RunCallableResponseWrapper {
624 RunCallableResponse* resp; // Not owned.
625 std::unordered_map<string, TensorProto> fetch_key_to_protos;
626
627 RunMetadata* mutable_metadata() { return resp->mutable_metadata(); }
628
629 Status AddTensorFromRunGraphResponse(
630 const string& tensor_name, MutableRunGraphResponseWrapper* worker_resp,
631 size_t index) {
632 // TODO(b/74355905): Add a specialized implementation that avoids
633 // copying the tensor into the RunCallableResponse when at least
634 // two of the {client, master, worker} are in the same process.
635 return worker_resp->RecvValue(index, &fetch_key_to_protos[tensor_name]);
636 }
637 };
638 } // namespace
639
640 template <class FetchListType, class ClientRequestType,
641 class ClientResponseType>
642 Status MasterSession::ReffedClientGraph::RunPartitionsHelper(
643 const std::unordered_map<StringPiece, size_t, StringPieceHasher>& feeds,
644 const FetchListType& fetches, const MasterEnv* env, int64 step_id,
645 int64 execution_count, PerStepState* pss, CallOptions* call_opts,
646 const ClientRequestType& req, ClientResponseType* resp,
647 CancellationManager* cm, bool is_last_partial_run) {
648 // Collect execution cost stats on a smoothly decreasing frequency.
649 ExecutorOpts exec_opts;
650 if (pss->report_tensor_allocations_upon_oom) {
651 exec_opts.set_report_tensor_allocations_upon_oom(true);
652 }
653 if (pss->collect_costs) {
654 exec_opts.set_record_costs(true);
655 }
656 if (pss->collect_timeline) {
657 exec_opts.set_record_timeline(true);
658 }
659 if (pss->collect_rpcs) {
660 SetRPCLogging(true);
661 }
662 if (pss->collect_partition_graphs) {
663 exec_opts.set_record_partition_graphs(true);
664 }
665 if (pss->collect_costs || pss->collect_timeline) {
666 pss->step_stats.resize(partitions_.size());
667 }
668
669 const int num = partitions_.size();
670 RunManyGraphs calls(num);
671
672 for (int i = 0; i < num; ++i) {
673 const Part& part = partitions_[i];
674 RunManyGraphs::Call* c = calls.get(i);
675 c->worker_name = &part.name;
676 c->req.reset(part.worker->CreateRunGraphRequest());
677 c->resp.reset(part.worker->CreateRunGraphResponse());
678 if (is_partial_) {
679 c->req->set_is_partial(is_partial_);
680 c->req->set_is_last_partial_run(is_last_partial_run);
681 }
682 c->req->set_session_handle(session_handle_);
683 c->req->set_create_worker_session_called(!should_deregister_);
684 c->req->set_graph_handle(part.graph_handle);
685 c->req->set_step_id(step_id);
686 *c->req->mutable_exec_opts() = exec_opts;
687 c->req->set_store_errors_in_response_body(true);
688 c->req->set_request_id(GetUniqueRequestId());
689 // If any feeds are provided, send the feed values together
690 // in the RunGraph request.
691 // In the partial case, we only want to include feeds provided in the req.
692 // In the non-partial case, all feeds in the request are in the part.
693 // We keep these as separate paths for now, to ensure we aren't
694 // inadvertently slowing down the normal run path.
695 if (is_partial_) {
696 for (const auto& name_index : feeds) {
697 const auto iter = part.feed_key.find(string(name_index.first));
698 if (iter == part.feed_key.end()) {
699 // The provided feed must be for a different partition.
700 continue;
701 }
702 const string& key = iter->second;
703 TF_RETURN_IF_ERROR(AddSendFromClientRequest(req, c->req.get(),
704 name_index.second, key));
705 }
706 // TODO(suharshs): Make a map from feed to fetch_key to make this faster.
707 // For now, we just iterate through partitions to find the matching key.
708 for (const string& req_fetch : fetches) {
709 for (const auto& key_fetch : part.key_fetch) {
710 if (key_fetch.second == req_fetch) {
711 c->req->add_recv_key(key_fetch.first);
712 break;
713 }
714 }
715 }
716 } else {
717 for (const auto& feed_key : part.feed_key) {
718 const string& feed = feed_key.first;
719 const string& key = feed_key.second;
720 auto iter = feeds.find(feed);
721 if (iter == feeds.end()) {
722 return errors::Internal("No feed index found for feed: ", feed);
723 }
724 const int64 feed_index = iter->second;
725 TF_RETURN_IF_ERROR(
726 AddSendFromClientRequest(req, c->req.get(), feed_index, key));
727 }
728 for (const auto& key_fetch : part.key_fetch) {
729 const string& key = key_fetch.first;
730 c->req->add_recv_key(key);
731 }
732 }
733 }
734
735 // Issues RunGraph calls.
736 for (int i = 0; i < num; ++i) {
737 const Part& part = partitions_[i];
738 RunManyGraphs::Call* call = calls.get(i);
739 TRACEPRINTF("Partition %d %s", i, part.name.c_str());
740 part.worker->RunGraphAsync(
741 &call->opts, call->req.get(), call->resp.get(),
742 std::bind(&RunManyGraphs::WhenDone, &calls, i, std::placeholders::_1));
743 }
744
745 // Waits for the RunGraph calls.
746 call_opts->SetCancelCallback([&calls]() {
747 LOG(INFO) << "Client requested cancellation for RunStep, cancelling "
748 "worker operations.";
749 calls.StartCancel();
750 });
751 auto token = cm->get_cancellation_token();
752 const bool success =
753 cm->RegisterCallback(token, [&calls]() { calls.StartCancel(); });
754 if (!success) {
755 calls.StartCancel();
756 }
757 calls.Wait();
758 call_opts->ClearCancelCallback();
759 if (success) {
760 cm->DeregisterCallback(token);
761 } else {
762 return errors::Cancelled("Step was cancelled");
763 }
764 TF_RETURN_IF_ERROR(calls.status());
765
766 // Collects fetches and metadata.
767 Status status;
768 for (int i = 0; i < num; ++i) {
769 const Part& part = partitions_[i];
770 MutableRunGraphResponseWrapper* run_graph_resp = calls.get(i)->resp.get();
771 for (size_t j = 0; j < run_graph_resp->num_recvs(); ++j) {
772 auto iter = part.key_fetch.find(run_graph_resp->recv_key(j));
773 if (iter == part.key_fetch.end()) {
774 status.Update(errors::Internal("Unexpected fetch key: ",
775 run_graph_resp->recv_key(j)));
776 break;
777 }
778 const string& fetch = iter->second;
779 status.Update(
780 resp->AddTensorFromRunGraphResponse(fetch, run_graph_resp, j));
781 if (!status.ok()) {
782 break;
783 }
784 }
785 if (pss->collect_timeline) {
786 pss->step_stats[i].Swap(run_graph_resp->mutable_step_stats());
787 }
788 if (pss->collect_costs) {
789 CostGraphDef* cost_graph = run_graph_resp->mutable_cost_graph();
790 for (int j = 0; j < cost_graph->node_size(); ++j) {
791 resp->mutable_metadata()->mutable_cost_graph()->add_node()->Swap(
792 cost_graph->mutable_node(j));
793 }
794 }
795 if (pss->collect_partition_graphs) {
796 protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
797 resp->mutable_metadata()->mutable_partition_graphs();
798 for (size_t i = 0; i < run_graph_resp->num_partition_graphs(); i++) {
799 partition_graph_defs->Add()->Swap(
800 run_graph_resp->mutable_partition_graph(i));
801 }
802 }
803 }
804 return status;
805 }
806
807 Status MasterSession::ReffedClientGraph::RunPartitions(
808 const MasterEnv* env, int64 step_id, int64 execution_count,
809 PerStepState* pss, CallOptions* call_opts, const RunStepRequestWrapper& req,
810 MutableRunStepResponseWrapper* resp, CancellationManager* cm,
811 const bool is_last_partial_run) {
812 VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
813 << execution_count;
814 // Maps the names of fed tensors to their index in `req`.
815 std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3);
816 for (size_t i = 0; i < req.num_feeds(); ++i) {
817 if (!feeds.insert({req.feed_name(i), i}).second) {
818 return errors::InvalidArgument("Duplicated feeds: ", req.feed_name(i));
819 }
820 }
821
822 std::vector<string> fetches;
823 fetches.reserve(req.num_fetches());
824 for (size_t i = 0; i < req.num_fetches(); ++i) {
825 fetches.push_back(req.fetch_name(i));
826 }
827
828 return RunPartitionsHelper(feeds, fetches, env, step_id, execution_count, pss,
829 call_opts, req, resp, cm, is_last_partial_run);
830 }
831
832 Status MasterSession::ReffedClientGraph::RunPartitions(
833 const MasterEnv* env, int64 step_id, int64 execution_count,
834 PerStepState* pss, CallOptions* call_opts, const RunCallableRequest& req,
835 RunCallableResponse* resp, CancellationManager* cm) {
836 VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
837 << execution_count;
838 // Maps the names of fed tensors to their index in `req`.
839 std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3);
840 for (size_t i = 0, end = callable_opts_.feed_size(); i < end; ++i) {
841 if (!feeds.insert({callable_opts_.feed(i), i}).second) {
842 // MakeCallable will fail if there are two feeds with the same name.
843 return errors::Internal("Duplicated feeds in callable: ",
844 callable_opts_.feed(i));
845 }
846 }
847
848 // Create a wrapped response object to collect the fetched values and
849 // rearrange them for the RunCallableResponse.
850 RunCallableResponseWrapper wrapped_resp;
851 wrapped_resp.resp = resp;
852
853 TF_RETURN_IF_ERROR(RunPartitionsHelper(
854 feeds, callable_opts_.fetch(), env, step_id, execution_count, pss,
855 call_opts, req, &wrapped_resp, cm, false /* is_last_partial_run */));
856
857 // Collects fetches.
858 // TODO(b/74355905): Add a specialized implementation that avoids
859 // copying the tensor into the RunCallableResponse when at least
860 // two of the {client, master, worker} are in the same process.
861 for (const string& fetch : callable_opts_.fetch()) {
862 TensorProto* fetch_proto = resp->mutable_fetch()->Add();
863 auto iter = wrapped_resp.fetch_key_to_protos.find(fetch);
864 if (iter == wrapped_resp.fetch_key_to_protos.end()) {
865 return errors::Internal("Worker did not return a value for fetch: ",
866 fetch);
867 }
868 fetch_proto->Swap(&iter->second);
869 }
870 return Status::OK();
871 }
872
873 namespace {
874
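// Broadcasts a single CleanupGraph request to a fixed number of workers and
// invokes the supplied callback with the aggregated status once every
// response has been received. The helper deletes itself after the final
// callback runs.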
875 class CleanupBroadcastHelper {
876 public:
877 CleanupBroadcastHelper(int64 step_id, int num_calls, StatusCallback done)
878 : resps_(num_calls), num_pending_(num_calls), done_(std::move(done)) {
879 req_.set_step_id(step_id);
880 }
881
882 // Returns a non-owned pointer to a request buffer for all calls.
883 CleanupGraphRequest* request() { return &req_; }
884
885 // Returns a non-owned pointer to a response buffer for the ith call.
886 CleanupGraphResponse* response(int i) { return &resps_[i]; }
887
888 // Called when the ith response is received.
889 void call_done(int i, const Status& s) {
890 bool run_callback = false;
891 Status status_copy;
892 {
893 mutex_lock l(mu_);
894 status_.Update(s);
895 if (--num_pending_ == 0) {
896 run_callback = true;
897 status_copy = status_;
898 }
899 }
900 if (run_callback) {
901 done_(status_copy);
902 // This is the last call, so delete the helper object.
903 delete this;
904 }
905 }
906
907 private:
908 // A single request shared between all workers.
909 CleanupGraphRequest req_;
910 // One response buffer for each worker.
911 gtl::InlinedVector<CleanupGraphResponse, 4> resps_;
912
913 mutex mu_;
914 // Number of requests remaining to be collected.
915 int num_pending_ TF_GUARDED_BY(mu_);
916 // Aggregate status of the operation.
917 Status status_ TF_GUARDED_BY(mu_);
918 // Callback to be called when all operations complete.
919 StatusCallback done_;
920
921 TF_DISALLOW_COPY_AND_ASSIGN(CleanupBroadcastHelper);
922 };
923
924 } // namespace
925
926 void MasterSession::ReffedClientGraph::CleanupPartitionsAsync(
927 int64 step_id, StatusCallback done) {
928 const int num = partitions_.size();
929 // Helper object will be deleted when the final call completes.
930 CleanupBroadcastHelper* helper =
931 new CleanupBroadcastHelper(step_id, num, std::move(done));
932 for (int i = 0; i < num; ++i) {
933 const Part& part = partitions_[i];
934 part.worker->CleanupGraphAsync(
935 helper->request(), helper->response(i),
936 [helper, i](const Status& s) { helper->call_done(i, s); });
937 }
938 }
939
940 void MasterSession::ReffedClientGraph::ProcessStats(int64 step_id,
941 PerStepState* pss,
942 ProfileHandler* ph,
943 const RunOptions& options,
944 RunMetadata* resp) {
945 if (!pss->collect_costs && !pss->collect_timeline) return;
946
947 // Out-of-band logging data is collected now, during post-processing.
948 if (pss->collect_timeline) {
949 SetRPCLogging(false);
950 RetrieveLogs(step_id, &pss->rpc_stats);
951 }
952 for (size_t i = 0; i < partitions_.size(); ++i) {
953 const StepStats& ss = pss->step_stats[i];
954 if (ph) {
955 for (const auto& ds : ss.dev_stats()) {
956 ProcessDeviceStats(ph, ds, false /*is_rpc*/);
957 }
958 }
959 }
960 if (ph) {
961 for (const auto& ds : pss->rpc_stats.dev_stats()) {
962 ProcessDeviceStats(ph, ds, true /*is_rpc*/);
963 }
964 ph->StepDone(pss->start_micros, pss->end_micros,
965 Microseconds(0) /*cleanup_time*/, 0 /*total_runops*/,
966 Status::OK());
967 }
968 // Assemble all stats for this timeline into a merged StepStats.
969 if (pss->collect_timeline) {
970 StepStats step_stats_proto;
971 step_stats_proto.Swap(&pss->rpc_stats);
972 for (size_t i = 0; i < partitions_.size(); ++i) {
973 step_stats_proto.MergeFrom(pss->step_stats[i]);
974 pss->step_stats[i].Clear();
975 }
976 pss->step_stats.clear();
977 // Copy the stats back, but only for on-demand profiling to avoid slowing
978 // down calls that trigger the automatic profiling.
979 if (options.trace_level() == RunOptions::FULL_TRACE) {
980 resp->mutable_step_stats()->Swap(&step_stats_proto);
981 } else {
982 // In the FULL_TRACE case the stats are already returned through the
983 // Session API, so publishing them here as well would be redundant.
984 stats_publisher_->PublishStatsProto(step_stats_proto);
985 }
986 }
987 }
988
989 void MasterSession::ReffedClientGraph::ProcessDeviceStats(
990 ProfileHandler* ph, const DeviceStepStats& ds, bool is_rpc) {
991 const string& dev_name = ds.device();
992 VLOG(1) << "Device " << dev_name << " reports stats for "
993 << ds.node_stats_size() << " nodes";
994 for (const auto& ns : ds.node_stats()) {
995 if (is_rpc) {
996 // We don't have access to a good Node pointer, so we rely on
997 // sufficient data being present in the NodeExecStats.
998 ph->RecordOneOp(dev_name, ns, true /*is_copy*/, "", ns.node_name(),
999 ns.timeline_label());
1000 } else {
1001 auto iter = name_to_node_details_.find(ns.node_name());
1002 const bool found_node_in_graph = iter != name_to_node_details_.end();
1003 if (!found_node_in_graph && ns.timeline_label().empty()) {
1004 // The counter incrementing is not thread-safe. But we don't really
1005 // care.
1006 // TODO(zhengxq): we should implement a LOG_FIRST_N and LOG_EVERY_N for
1007 // more general usage.
1008 static int log_counter = 0;
1009 if (log_counter < 10) {
1010 log_counter++;
1011 LOG(WARNING) << "Failed to find node " << ns.node_name()
1012 << " for dev " << dev_name;
1013 }
1014 continue;
1015 }
1016 const string& optype =
1017 found_node_in_graph ? iter->second.type_string : ns.node_name();
1018 string details;
1019 if (!ns.timeline_label().empty()) {
1020 details = ns.timeline_label();
1021 } else if (found_node_in_graph) {
1022 details = DetailText(iter->second, ns);
1023 } else {
1024 // Leave details string empty
1025 }
1026 ph->RecordOneOp(dev_name, ns, false /*is_copy*/, ns.node_name(), optype,
1027 details);
1028 }
1029 }
1030 }
1031
1032 // TODO(suharshs): Merge with CheckFetches in DirectSession.
1033 // TODO(suharsh,mrry): Build a map from fetch target to set of feeds it depends
1034 // on once at setup time to prevent us from computing the dependencies
1035 // every time.
1036 Status MasterSession::ReffedClientGraph::CheckFetches(
1037 const RunStepRequestWrapper& req, const RunState* run_state,
1038 GraphExecutionState* execution_state) {
1039 // Build the set of pending feeds that we haven't seen.
1040 std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
1041 for (const auto& input : run_state->pending_inputs) {
1042 // Skip if already fed.
1043 if (input.second) continue;
1044 TensorId id(ParseTensorName(input.first));
1045 const Node* n = execution_state->get_node_by_name(string(id.first));
1046 if (n == nullptr) {
1047 return errors::NotFound("Feed ", input.first, ": not found");
1048 }
1049 pending_feeds.insert(id);
1050 }
1051 for (size_t i = 0; i < req.num_feeds(); ++i) {
1052 const TensorId id(ParseTensorName(req.feed_name(i)));
1053 pending_feeds.erase(id);
1054 }
1055
1056 // Initialize the stack with the fetch nodes.
1057 std::vector<const Node*> stack;
1058 for (size_t i = 0; i < req.num_fetches(); ++i) {
1059 const string& fetch = req.fetch_name(i);
1060 const TensorId id(ParseTensorName(fetch));
1061 const Node* n = execution_state->get_node_by_name(string(id.first));
1062 if (n == nullptr) {
1063 return errors::NotFound("Fetch ", fetch, ": not found");
1064 }
1065 stack.push_back(n);
1066 }
1067
1068 // Any tensor needed for fetches can't be in pending_feeds.
1069 // We need to use the original full graph from execution state.
1070 const Graph* graph = execution_state->full_graph();
1071 std::vector<bool> visited(graph->num_node_ids(), false);
1072 while (!stack.empty()) {
1073 const Node* n = stack.back();
1074 stack.pop_back();
1075
1076 for (const Edge* in_edge : n->in_edges()) {
1077 const Node* in_node = in_edge->src();
1078 if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) {
1079 return errors::InvalidArgument("Fetch ", in_node->name(), ":",
1080 in_edge->src_output(),
1081 " can't be computed from the feeds"
1082 " that have been fed so far.");
1083 }
1084 if (!visited[in_node->id()]) {
1085 visited[in_node->id()] = true;
1086 stack.push_back(in_node);
1087 }
1088 }
1089 }
1090 return Status::OK();
1091 }
1092
1093 // Asynchronously deregisters subgraphs on the workers, without waiting for the
1094 // result.
1095 void MasterSession::ReffedClientGraph::DeregisterPartitions() {
1096 struct Call {
1097 DeregisterGraphRequest req;
1098 DeregisterGraphResponse resp;
1099 };
1100 for (Part& part : partitions_) {
1101 // The graph handle may be empty if we failed during partition registration.
1102 if (!part.graph_handle.empty()) {
1103 Call* c = new Call;
1104 c->req.set_session_handle(session_handle_);
1105 c->req.set_create_worker_session_called(!should_deregister_);
1106 c->req.set_graph_handle(part.graph_handle);
1107 // NOTE(mrry): We must capture `worker_cache_` since `this`
1108 // could be deleted before the callback is called.
1109 WorkerCacheInterface* worker_cache = worker_cache_;
1110 const string name = part.name;
1111 WorkerInterface* w = part.worker;
1112 CHECK_NOTNULL(w);
1113 auto cb = [worker_cache, c, name, w](const Status& s) {
1114 if (!s.ok()) {
1115 // This error is potentially benign, so we don't log at the
1116 // error level.
1117 LOG(INFO) << "DeregisterGraph error: " << s;
1118 }
1119 delete c;
1120 worker_cache->ReleaseWorker(name, w);
1121 };
1122 w->DeregisterGraphAsync(&c->req, &c->resp, cb);
1123 }
1124 }
1125 }
1126
1127 namespace {
1128 void CopyAndSortStrings(size_t size,
1129 const std::function<string(size_t)>& input_accessor,
1130 protobuf::RepeatedPtrField<string>* output) {
1131 std::vector<string> temp;
1132 temp.reserve(size);
1133 for (size_t i = 0; i < size; ++i) {
1134 output->Add(input_accessor(i));
1135 }
1136 std::sort(output->begin(), output->end());
1137 }
1138 } // namespace
1139
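// Converts a RunStep request into BuildGraphOptions with sorted feed, fetch,
// and target lists, so that requests that differ only in the ordering of
// these lists produce identical options (and hence the same value from
// HashBuildGraphOptions below).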
1140 void BuildBuildGraphOptions(const RunStepRequestWrapper& req,
1141 const ConfigProto& config,
1142 BuildGraphOptions* opts) {
1143 CallableOptions* callable_opts = &opts->callable_options;
1144 CopyAndSortStrings(
1145 req.num_feeds(), [&req](size_t i) { return req.feed_name(i); },
1146 callable_opts->mutable_feed());
1147 CopyAndSortStrings(
1148 req.num_fetches(), [&req](size_t i) { return req.fetch_name(i); },
1149 callable_opts->mutable_fetch());
1150 CopyAndSortStrings(
1151 req.num_targets(), [&req](size_t i) { return req.target_name(i); },
1152 callable_opts->mutable_target());
1153
1154 if (!req.options().debug_options().debug_tensor_watch_opts().empty()) {
1155 *callable_opts->mutable_run_options()->mutable_debug_options() =
1156 req.options().debug_options();
1157 }
1158
1159 opts->collective_graph_key =
1160 req.options().experimental().collective_graph_key();
1161 if (config.experimental().collective_deterministic_sequential_execution()) {
1162 opts->collective_order = GraphCollectiveOrder::kEdges;
1163 } else if (config.experimental().collective_nccl()) {
1164 opts->collective_order = GraphCollectiveOrder::kAttrs;
1165 }
1166 }
1167
1168 void BuildBuildGraphOptions(const PartialRunSetupRequest& req,
1169 BuildGraphOptions* opts) {
1170 CallableOptions* callable_opts = &opts->callable_options;
1171 CopyAndSortStrings(
1172 req.feed_size(), [&req](size_t i) { return req.feed(i); },
1173 callable_opts->mutable_feed());
1174 CopyAndSortStrings(
1175 req.fetch_size(), [&req](size_t i) { return req.fetch(i); },
1176 callable_opts->mutable_fetch());
1177 CopyAndSortStrings(
1178 req.target_size(), [&req](size_t i) { return req.target(i); },
1179 callable_opts->mutable_target());
1180
1181 // TODO(cais): Add TFDBG support to partial runs.
1182 }
1183
1184 uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) {
1185 uint64 h = 0x2b992ddfa23249d6ull;
1186 for (const string& name : opts.callable_options.feed()) {
1187 h = Hash64(name.c_str(), name.size(), h);
1188 }
1189 for (const string& name : opts.callable_options.target()) {
1190 h = Hash64(name.c_str(), name.size(), h);
1191 }
1192 for (const string& name : opts.callable_options.fetch()) {
1193 h = Hash64(name.c_str(), name.size(), h);
1194 }
1195
1196 const DebugOptions& debug_options =
1197 opts.callable_options.run_options().debug_options();
1198 if (!debug_options.debug_tensor_watch_opts().empty()) {
1199 const string watch_summary =
1200 SummarizeDebugTensorWatches(debug_options.debug_tensor_watch_opts());
1201 h = Hash64(watch_summary.c_str(), watch_summary.size(), h);
1202 }
1203
1204 return h;
1205 }
1206
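// Produces a compact, human-readable listing of the feeds (FdE), targets
// (TN), fetches (FeE), and collective graph key recorded in `opts`.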
1207 string BuildGraphOptionsString(const BuildGraphOptions& opts) {
1208 string buf;
1209 for (const string& name : opts.callable_options.feed()) {
1210 strings::StrAppend(&buf, " FdE: ", name);
1211 }
1212 strings::StrAppend(&buf, "\n");
1213 for (const string& name : opts.callable_options.target()) {
1214 strings::StrAppend(&buf, " TN: ", name);
1215 }
1216 strings::StrAppend(&buf, "\n");
1217 for (const string& name : opts.callable_options.fetch()) {
1218 strings::StrAppend(&buf, " FeE: ", name);
1219 }
1220 if (opts.collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey) {
1221 strings::StrAppend(&buf, "\nGK: ", opts.collective_graph_key);
1222 }
1223 strings::StrAppend(&buf, "\n");
1224 return buf;
1225 }
1226
1227 MasterSession::MasterSession(
1228 const SessionOptions& opt, const MasterEnv* env,
1229 std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
1230 std::unique_ptr<WorkerCacheInterface> worker_cache,
1231 std::unique_ptr<DeviceSet> device_set,
1232 std::vector<string> filtered_worker_list,
1233 StatsPublisherFactory stats_publisher_factory)
1234 : session_opts_(opt),
1235 env_(env),
1236 handle_(strings::FpToString(random::New64())),
1237 remote_devs_(std::move(remote_devs)),
1238 worker_cache_(std::move(worker_cache)),
1239 devices_(std::move(device_set)),
1240 filtered_worker_list_(std::move(filtered_worker_list)),
1241 stats_publisher_factory_(std::move(stats_publisher_factory)),
1242 graph_version_(0),
1243 run_graphs_(5),
1244 partial_run_graphs_(5) {
1245 UpdateLastAccessTime();
1246 CHECK(devices_) << "device_set was null!";
1247
1248 VLOG(1) << "Session " << handle_ << " #local " << env->local_devices.size()
1249 << " #remote " << remote_devs_->size();
1250 VLOG(1) << "Start master session " << handle_
1251 << " with config: " << session_opts_.config.ShortDebugString();
1252 }
1253
1254 MasterSession::~MasterSession() {
1255 for (const auto& iter : run_graphs_) iter.second->Unref();
1256 for (const auto& iter : partial_run_graphs_) iter.second->Unref();
1257 }
1258
1259 void MasterSession::UpdateLastAccessTime() {
1260 last_access_time_usec_.store(Env::Default()->NowMicros());
1261 }
1262
1263 Status MasterSession::Create(GraphDef&& graph_def,
1264 const WorkerCacheFactoryOptions& options) {
1265 if (session_opts_.config.use_per_session_threads() ||
1266 session_opts_.config.session_inter_op_thread_pool_size() > 0) {
1267 return errors::InvalidArgument(
1268 "Distributed session does not support session thread pool options.");
1269 }
1270 if (session_opts_.config.graph_options().place_pruned_graph()) {
1271 // TODO(b/29900832): Fix this or remove the option.
1272 LOG(WARNING) << "Distributed session does not support the "
1273 "place_pruned_graph option.";
1274 session_opts_.config.mutable_graph_options()->set_place_pruned_graph(false);
1275 }
1276
1277 GraphExecutionStateOptions execution_options;
1278 execution_options.device_set = devices_.get();
1279 execution_options.session_options = &session_opts_;
1280 {
1281 mutex_lock l(mu_);
1282 TF_RETURN_IF_ERROR(GraphExecutionState::MakeForBaseGraph(
1283 std::move(graph_def), execution_options, &execution_state_));
1284 }
1285 should_delete_worker_sessions_ = true;
1286 return CreateWorkerSessions(options);
1287 }
1288
1289 Status MasterSession::CreateWorkerSessions(
1290 const WorkerCacheFactoryOptions& options) {
1291 const std::vector<string> worker_names = filtered_worker_list_;
1292 WorkerCacheInterface* worker_cache = get_worker_cache();
1293
1294 struct WorkerGroup {
1295 // The worker name. (Not owned.)
1296 const string* name;
1297
1298 // The worker referenced by name. (Not owned.)
1299 WorkerInterface* worker = nullptr;
1300
1301 // Request and responses used for a given worker.
1302 CreateWorkerSessionRequest request;
1303 CreateWorkerSessionResponse response;
1304 Status status = Status::OK();
1305 };
1306 BlockingCounter done(worker_names.size());
1307 std::vector<WorkerGroup> workers(worker_names.size());
1308
1309 // Release the workers.
1310 auto cleanup = gtl::MakeCleanup([&workers, worker_cache] {
1311 for (auto&& worker_group : workers) {
1312 if (worker_group.worker != nullptr) {
1313 worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker);
1314 }
1315 }
1316 });
1317
1318 string task_name;
1319 string local_device_name;
1320 DeviceNameUtils::SplitDeviceName(devices_->client_device()->name(),
1321 &task_name, &local_device_name);
1322 const int64 client_device_incarnation =
1323 devices_->client_device()->attributes().incarnation();
1324
1325 Status status = Status::OK();
1326 // Create all the workers & kick off the computations.
1327 for (size_t i = 0; i < worker_names.size(); ++i) {
1328 workers[i].name = &worker_names[i];
1329 workers[i].worker = worker_cache->GetOrCreateWorker(worker_names[i]);
1330 workers[i].request.set_session_handle(handle_);
1331 workers[i].request.set_master_task(task_name);
1332 workers[i].request.set_master_incarnation(client_device_incarnation);
1333 if (session_opts_.config.share_cluster_devices_in_session() ||
1334 session_opts_.config.experimental()
1335 .share_cluster_devices_in_session()) {
1336 for (const auto& remote_dev : devices_->devices()) {
1337 *workers[i].request.add_cluster_device_attributes() =
1338 remote_dev->attributes();
1339 }
1340
1341 if (!session_opts_.config.share_cluster_devices_in_session() &&
1342 session_opts_.config.experimental()
1343 .share_cluster_devices_in_session()) {
1344 LOG(WARNING)
1345 << "ConfigProto.Experimental.share_cluster_devices_in_session has "
1346 "been promoted to a non-experimental API. Please use "
1347 "ConfigProto.share_cluster_devices_in_session instead. The "
1348 "experimental option will be removed in the future.";
1349 }
1350 }
1351
1352 DeviceNameUtils::ParsedName name;
1353 if (!DeviceNameUtils::ParseFullName(worker_names[i], &name)) {
1354 status = errors::Internal("Could not parse name ", worker_names[i]);
1355 LOG(WARNING) << status;
1356 return status;
1357 }
1358 if (!name.has_job || !name.has_task) {
1359 status = errors::Internal("Incomplete worker name ", worker_names[i]);
1360 LOG(WARNING) << status;
1361 return status;
1362 }
1363
1364 if (options.cluster_def) {
1365 *workers[i].request.mutable_server_def()->mutable_cluster() =
1366 *options.cluster_def;
1367 workers[i].request.mutable_server_def()->set_protocol(*options.protocol);
1368 workers[i].request.mutable_server_def()->set_job_name(name.job);
1369 workers[i].request.mutable_server_def()->set_task_index(name.task);
1370 // Session state is always isolated when ClusterSpec propagation
1371 // is in use.
1372 workers[i].request.set_isolate_session_state(true);
1373 } else {
1374 // NOTE(mrry): Do not set any component of the ServerDef,
1375 // because the worker will use its local configuration.
1376 workers[i].request.set_isolate_session_state(
1377 session_opts_.config.isolate_session_state());
1378 }
1379 if (session_opts_.config.experimental()
1380 .share_session_state_in_clusterspec_propagation()) {
1381 // In a dynamic cluster, the ClusterSpec info is usually propagated by
1382 // master sessions. However, in data-parallel training with multiple
1383 // masters ("between-graph replication"), we need to disable isolation
1384 // so that different worker sessions can update the same variables in
1385 // PS tasks.
1386 workers[i].request.set_isolate_session_state(false);
1387 }
1388 }
1389
1390 for (size_t i = 0; i < worker_names.size(); ++i) {
1391 auto cb = [i, &workers, &done](const Status& s) {
1392 workers[i].status = s;
1393 done.DecrementCount();
1394 };
1395 workers[i].worker->CreateWorkerSessionAsync(&workers[i].request,
1396 &workers[i].response, cb);
1397 }
1398
1399 done.Wait();
1400 for (size_t i = 0; i < workers.size(); ++i) {
1401 status.Update(workers[i].status);
1402 }
1403 return status;
1404 }
1405
1406 Status MasterSession::DeleteWorkerSessions() {
1407 WorkerCacheInterface* worker_cache = get_worker_cache();
1408 const std::vector<string>& worker_names = filtered_worker_list_;
1409
1410 struct WorkerGroup {
1411 // The worker name. (Not owned.)
1412 const string* name;
1413
1414 // The worker referenced by name. (Not owned.)
1415 WorkerInterface* worker = nullptr;
1416
1417 CallOptions call_opts;
1418
1419 // Request and responses used for a given worker.
1420 DeleteWorkerSessionRequest request;
1421 DeleteWorkerSessionResponse response;
1422 Status status = Status::OK();
1423 };
1424 BlockingCounter done(worker_names.size());
1425 std::vector<WorkerGroup> workers(worker_names.size());
1426
1427 // Release the workers.
1428 auto cleanup = gtl::MakeCleanup([&workers, worker_cache] {
1429 for (auto&& worker_group : workers) {
1430 if (worker_group.worker != nullptr) {
1431 worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker);
1432 }
1433 }
1434 });
1435
1436 Status status = Status::OK();
1437 // Create all the workers & kick off the computations.
1438 for (size_t i = 0; i < worker_names.size(); ++i) {
1439 workers[i].name = &worker_names[i];
1440 workers[i].worker = worker_cache->GetOrCreateWorker(worker_names[i]);
1441 workers[i].request.set_session_handle(handle_);
1442 // Since the worker may have gone away, set a timeout to avoid blocking the
1443 // session-close operation.
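    // The timeout is specified in milliseconds (10 seconds here).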
1444 workers[i].call_opts.SetTimeout(10000);
1445 }
1446
1447 for (size_t i = 0; i < worker_names.size(); ++i) {
1448 auto cb = [i, &workers, &done](const Status& s) {
1449 workers[i].status = s;
1450 done.DecrementCount();
1451 };
1452 workers[i].worker->DeleteWorkerSessionAsync(
1453 &workers[i].call_opts, &workers[i].request, &workers[i].response, cb);
1454 }
1455
1456 done.Wait();
1457 for (size_t i = 0; i < workers.size(); ++i) {
1458 status.Update(workers[i].status);
1459 }
1460 return status;
1461 }
1462
1463 Status MasterSession::ListDevices(ListDevicesResponse* resp) const {
1464 if (worker_cache_) {
1465 // This is a ClusterSpec-propagated session, and thus env_->local_devices
1466 // are invalid.
1467
1468 // Mark the "client_device" as the sole local device.
1469 const Device* client_device = devices_->client_device();
1470 for (const Device* dev : devices_->devices()) {
1471 if (dev != client_device) {
1472 *(resp->add_remote_device()) = dev->attributes();
1473 }
1474 }
1475 *(resp->add_local_device()) = client_device->attributes();
1476 } else {
1477 for (Device* dev : env_->local_devices) {
1478 *(resp->add_local_device()) = dev->attributes();
1479 }
1480 for (auto&& dev : *remote_devs_) {
1481 *(resp->add_local_device()) = dev->attributes();
1482 }
1483 }
1484 return Status::OK();
1485 }
1486
1487 Status MasterSession::Extend(const ExtendSessionRequest* req,
1488 ExtendSessionResponse* resp) {
1489 UpdateLastAccessTime();
1490 std::unique_ptr<GraphExecutionState> extended_execution_state;
1491 {
1492 mutex_lock l(mu_);
1493 if (closed_) {
1494 return errors::FailedPrecondition("Session is closed.");
1495 }
1496
1497 if (graph_version_ != req->current_graph_version()) {
1498 return errors::Aborted("Current version is ", graph_version_,
1499 " but caller expected ",
1500 req->current_graph_version(), ".");
1501 }
1502
1503 CHECK(execution_state_);
1504 TF_RETURN_IF_ERROR(
1505 execution_state_->Extend(req->graph_def(), &extended_execution_state));
1506
1507 CHECK(extended_execution_state);
1508 // The old execution state will be released outside the lock.
1509 execution_state_.swap(extended_execution_state);
1510 ++graph_version_;
1511 resp->set_new_graph_version(graph_version_);
1512 }
1513 return Status::OK();
1514 }
1515
1516 WorkerCacheInterface* MasterSession::get_worker_cache() const {
1517 if (worker_cache_) {
1518 return worker_cache_.get();
1519 }
1520 return env_->worker_cache;
1521 }
1522
1523 Status MasterSession::StartStep(const BuildGraphOptions& opts, bool is_partial,
1524 ReffedClientGraph** out_rcg, int64* out_count) {
1525 const uint64 hash = HashBuildGraphOptions(opts);
1526 {
1527 mutex_lock l(mu_);
1528     // TODO(suharshs): We cache partial run graphs and run graphs separately
1529     // because there is preprocessing that needs to be run only for partial
1530     // run calls.
1531 RCGMap* m = is_partial ? &partial_run_graphs_ : &run_graphs_;
1532 auto iter = m->find(hash);
1533 if (iter == m->end()) {
1534 // We have not seen this subgraph before. Build the subgraph and
1535 // cache it.
1536 VLOG(1) << "Unseen hash " << hash << " for "
1537 << BuildGraphOptionsString(opts) << " is_partial = " << is_partial
1538 << "\n";
1539 std::unique_ptr<ClientGraph> client_graph;
1540 TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
1541 WorkerCacheInterface* worker_cache = get_worker_cache();
1542 auto entry = new ReffedClientGraph(
1543 handle_, opts, std::move(client_graph), session_opts_,
1544 stats_publisher_factory_, is_partial, worker_cache,
1545 !should_delete_worker_sessions_);
1546 iter = m->insert({hash, entry}).first;
1547 VLOG(1) << "Preparing to execute new graph";
1548 }
1549 *out_rcg = iter->second;
1550 (*out_rcg)->Ref();
1551 *out_count = (*out_rcg)->get_and_increment_execution_count();
1552 }
1553 return Status::OK();
1554 }
1555
1556 void MasterSession::ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
1557 RCGMap* rcg_map) {
1558 VLOG(1) << "Discarding all reffed graphs";
1559 for (auto p : *rcg_map) {
1560 ReffedClientGraph* rcg = p.second;
1561 if (to_unref) {
1562 to_unref->push_back(rcg);
1563 } else {
1564 rcg->Unref();
1565 }
1566 }
1567 rcg_map->clear();
1568 }
1569
1570 uint64 MasterSession::NewStepId(int64 graph_key) {
1571 if (graph_key == BuildGraphOptions::kNoCollectiveGraphKey) {
1572 // StepId must leave the most-significant 7 bits empty for future use.
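    // The mask below has only the low 57 bits set, so the top 7 bits of the
    // returned id are always zero.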
1573 return random::New64() & (((1uLL << 56) - 1) | (1uLL << 56));
1574 } else {
1575 uint64 step_id = env_->collective_executor_mgr->NextStepId(graph_key);
1576 int32 retry_count = 0;
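    // If NextStepId() returned kInvalidId, the step-id sequence for this
    // graph_key is not yet available; refresh it and retry, backing off by up
    // to 60 seconds between attempts.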
1577 while (static_cast<int64>(step_id) == CollectiveExecutor::kInvalidId) {
1578 Notification note;
1579 Status status;
1580 env_->collective_executor_mgr->RefreshStepIdSequenceAsync(
1581 graph_key, [&status, ¬e](const Status& s) {
1582 status = s;
1583 note.Notify();
1584 });
1585 note.WaitForNotification();
1586 if (!status.ok()) {
1587 LOG(ERROR) << "Bad status from "
1588 "collective_executor_mgr->RefreshStepIdSequence: "
1589 << status << ". Retrying.";
1590 int64 delay_micros = std::min(60000000LL, 1000000LL * ++retry_count);
1591 Env::Default()->SleepForMicroseconds(delay_micros);
1592 } else {
1593 step_id = env_->collective_executor_mgr->NextStepId(graph_key);
1594 }
1595 }
1596 return step_id;
1597 }
1598 }
1599
1600 Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req,
1601 PartialRunSetupResponse* resp) {
1602 std::vector<string> inputs, outputs, targets;
1603 for (const auto& feed : req->feed()) {
1604 inputs.push_back(feed);
1605 }
1606 for (const auto& fetch : req->fetch()) {
1607 outputs.push_back(fetch);
1608 }
1609 for (const auto& target : req->target()) {
1610 targets.push_back(target);
1611 }
1612
1613 string handle = std::to_string(partial_run_handle_counter_.fetch_add(1));
1614
1615 ReffedClientGraph* rcg = nullptr;
1616
1617 // Prepare.
1618 BuildGraphOptions opts;
1619 BuildBuildGraphOptions(*req, &opts);
1620 int64 count = 0;
1621 TF_RETURN_IF_ERROR(StartStep(opts, true, &rcg, &count));
1622
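  // The RunState below takes ownership of this reference; it is released in
  // ~RunState().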
1623 rcg->Ref();
1624 RunState* run_state =
1625 new RunState(inputs, outputs, rcg,
1626 NewStepId(BuildGraphOptions::kNoCollectiveGraphKey), count);
1627 {
1628 mutex_lock l(mu_);
1629 partial_runs_.emplace(
1630 std::make_pair(handle, std::unique_ptr<RunState>(run_state)));
1631 }
1632
1633 TF_RETURN_IF_ERROR(BuildAndRegisterPartitions(rcg));
1634
1635 resp->set_partial_run_handle(handle);
1636 return Status::OK();
1637 }
1638
1639 Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req,
1640 MutableRunStepResponseWrapper* resp) {
1641 UpdateLastAccessTime();
1642 {
1643 mutex_lock l(mu_);
1644 if (closed_) {
1645 return errors::FailedPrecondition("Session is closed.");
1646 }
1647 ++num_running_;
1648 // Note: all code paths must eventually call MarkRunCompletion()
1649     // in order to appropriately decrement the num_running_ counter.
1650 }
1651 Status status;
1652 if (!req.partial_run_handle().empty()) {
1653 status = DoPartialRun(opts, req, resp);
1654 } else {
1655 status = DoRunWithLocalExecution(opts, req, resp);
1656 }
1657 return status;
1658 }
1659
1660 // Decrements num_running_ and broadcasts if num_running_ is zero.
1661 void MasterSession::MarkRunCompletion() {
1662 mutex_lock l(mu_);
1663 --num_running_;
1664 if (num_running_ == 0) {
1665 num_running_is_zero_.notify_all();
1666 }
1667 }
1668
1669 Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
1670   // Registers subgraphs if we haven't done so yet.
1671 PartitionOptions popts;
1672 popts.node_to_loc = SplitByWorker;
1673 // The closures popts.{new_name,get_incarnation} are called synchronously in
1674 // RegisterPartitions() below, so do not need a Ref()/Unref() pair to keep
1675 // "this" alive during the closure.
1676 popts.new_name = [this](const string& prefix) {
1677 mutex_lock l(mu_);
1678 return strings::StrCat(prefix, "_S", next_node_id_++);
1679 };
1680 popts.get_incarnation = [this](const string& name) -> int64 {
1681 Device* d = devices_->FindDeviceByName(name);
1682 if (d == nullptr) {
1683 return PartitionOptions::kIllegalIncarnation;
1684 } else {
1685 return d->attributes().incarnation();
1686 }
1687 };
1688 popts.control_flow_added = false;
1689 const bool enable_bfloat16_sendrecv =
1690 session_opts_.config.graph_options().enable_bfloat16_sendrecv();
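  // When enable_bfloat16_sendrecv is set, DT_FLOAT tensors are cast to
  // DT_BFLOAT16 on cross-partition Send/Recv edges; control edges carry no
  // tensor data.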
1691 popts.should_cast = [enable_bfloat16_sendrecv](const Edge* e) {
1692 if (e->IsControlEdge()) {
1693 return DT_FLOAT;
1694 }
1695 DataType dtype = BaseType(e->src()->output_type(e->src_output()));
1696 if (enable_bfloat16_sendrecv && dtype == DT_FLOAT) {
1697 return DT_BFLOAT16;
1698 } else {
1699 return dtype;
1700 }
1701 };
1702 if (session_opts_.config.graph_options().enable_recv_scheduling()) {
1703 popts.scheduling_for_recvs = true;
1704 popts.need_to_record_start_times = true;
1705 }
1706
1707 TF_RETURN_IF_ERROR(rcg->RegisterPartitions(std::move(popts)));
1708
1709 return Status::OK();
1710 }
1711
1712 Status MasterSession::DoPartialRun(CallOptions* opts,
1713 const RunStepRequestWrapper& req,
1714 MutableRunStepResponseWrapper* resp) {
1715 auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
1716 const string& prun_handle = req.partial_run_handle();
1717 RunState* run_state = nullptr;
1718 {
1719 mutex_lock l(mu_);
1720 auto it = partial_runs_.find(prun_handle);
1721 if (it == partial_runs_.end()) {
1722 return errors::InvalidArgument(
1723 "Must run PartialRunSetup before performing partial runs");
1724 }
1725 run_state = it->second.get();
1726 }
1727 // CollectiveOps are not supported in partial runs.
1728 if (req.options().experimental().collective_graph_key() !=
1729 BuildGraphOptions::kNoCollectiveGraphKey) {
1730 return errors::InvalidArgument(
1731 "PartialRun does not support Collective ops. collective_graph_key "
1732 "must be kNoCollectiveGraphKey.");
1733 }
1734
1735 // If this is the first partial run, initialize the PerStepState.
1736 if (!run_state->step_started) {
1737 run_state->step_started = true;
1738 PerStepState pss;
1739
1740 const auto count = run_state->count;
1741 pss.collect_timeline =
1742 req.options().trace_level() == RunOptions::FULL_TRACE;
1743 pss.collect_rpcs = req.options().trace_level() == RunOptions::FULL_TRACE;
1744 pss.report_tensor_allocations_upon_oom =
1745 req.options().report_tensor_allocations_upon_oom();
1746
1747     // Build the cost model every 'build_cost_model_every' steps after
1748     // skipping an initial 'build_cost_model_after' steps.
1750 const int64 build_cost_model_after =
1751 session_opts_.config.graph_options().build_cost_model_after();
1752 const int64 build_cost_model_every =
1753 session_opts_.config.graph_options().build_cost_model();
1754 pss.collect_costs =
1755 build_cost_model_every > 0 &&
1756 ((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
1757 pss.collect_partition_graphs = req.options().output_partition_graphs();
1758
1759 std::unique_ptr<ProfileHandler> ph = run_state->rcg->GetProfileHandler(
1760 run_state->step_id, count, req.options());
1761 if (ph) {
1762 pss.collect_timeline = true;
1763 pss.collect_rpcs = ph->should_collect_rpcs();
1764 }
1765
1766 run_state->pss = std::move(pss);
1767 run_state->ph = std::move(ph);
1768 }
1769
1770 // Make sure that this is a new set of feeds that are still pending.
1771 for (size_t i = 0; i < req.num_feeds(); ++i) {
1772 const string& feed = req.feed_name(i);
1773 auto it = run_state->pending_inputs.find(feed);
1774 if (it == run_state->pending_inputs.end()) {
1775 return errors::InvalidArgument(
1776 "The feed ", feed, " was not specified in partial_run_setup.");
1777 } else if (it->second) {
1778 return errors::InvalidArgument("The feed ", feed,
1779 " has already been fed.");
1780 }
1781 }
1782 // Check that this is a new set of fetches that are still pending.
1783 for (size_t i = 0; i < req.num_fetches(); ++i) {
1784 const string& fetch = req.fetch_name(i);
1785 auto it = run_state->pending_outputs.find(fetch);
1786 if (it == run_state->pending_outputs.end()) {
1787 return errors::InvalidArgument(
1788 "The fetch ", fetch, " was not specified in partial_run_setup.");
1789 } else if (it->second) {
1790 return errors::InvalidArgument("The fetch ", fetch,
1791 " has already been fetched.");
1792 }
1793 }
1794
1795 // Ensure that the requested fetches can be computed from the provided feeds.
1796 {
1797 mutex_lock l(mu_);
1798 TF_RETURN_IF_ERROR(
1799 run_state->rcg->CheckFetches(req, run_state, execution_state_.get()));
1800 }
1801
1802 // Determine if this partial run satisfies all the pending inputs and outputs.
1803 for (size_t i = 0; i < req.num_feeds(); ++i) {
1804 auto it = run_state->pending_inputs.find(req.feed_name(i));
1805 it->second = true;
1806 }
1807 for (size_t i = 0; i < req.num_fetches(); ++i) {
1808 auto it = run_state->pending_outputs.find(req.fetch_name(i));
1809 it->second = true;
1810 }
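  // The partial run is complete once every feed and fetch declared in
  // PartialRunSetup has been satisfied (see RunState::PendingDone()).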
1811 bool is_last_partial_run = run_state->PendingDone();
1812
1813 Status s = run_state->rcg->RunPartitions(
1814 env_, run_state->step_id, run_state->count, &run_state->pss, opts, req,
1815 resp, &cancellation_manager_, is_last_partial_run);
1816
1817 // Delete the run state if there is an error or all fetches are done.
1818 if (!s.ok() || is_last_partial_run) {
1819 ReffedClientGraph* rcg = run_state->rcg;
1820 run_state->pss.end_micros = Env::Default()->NowMicros();
1821 // Schedule post-processing and cleanup to be done asynchronously.
1822 Ref();
1823 rcg->Ref();
1824 rcg->ProcessStats(run_state->step_id, &run_state->pss, run_state->ph.get(),
1825 req.options(), resp->mutable_metadata());
1826 cleanup.release(); // MarkRunCompletion called in done closure.
1827 rcg->CleanupPartitionsAsync(
1828 run_state->step_id, [this, rcg, prun_handle](const Status& s) {
1829 if (!s.ok()) {
1830 LOG(ERROR) << "Cleanup partition error: " << s;
1831 }
1832 rcg->Unref();
1833 MarkRunCompletion();
1834 Unref();
1835 });
1836 mutex_lock l(mu_);
1837 partial_runs_.erase(prun_handle);
1838 }
1839 return s;
1840 }
1841
1842 Status MasterSession::CreateDebuggerState(
1843 const DebugOptions& debug_options, const RunStepRequestWrapper& req,
1844 int64 rcg_execution_count,
1845 std::unique_ptr<DebuggerStateInterface>* debugger_state) {
1846 TF_RETURN_IF_ERROR(
1847 DebuggerStateRegistry::CreateState(debug_options, debugger_state));
1848
1849 std::vector<string> input_names;
1850 for (size_t i = 0; i < req.num_feeds(); ++i) {
1851 input_names.push_back(req.feed_name(i));
1852 }
1853 std::vector<string> output_names;
1854 for (size_t i = 0; i < req.num_fetches(); ++i) {
1855 output_names.push_back(req.fetch_name(i));
1856 }
1857 std::vector<string> target_names;
1858 for (size_t i = 0; i < req.num_targets(); ++i) {
1859 target_names.push_back(req.target_name(i));
1860 }
1861
1862 // TODO(cais): We currently use -1 as a dummy value for session run count.
1863 // While this counter value is straightforward to define and obtain for
1864 // DirectSessions, it is less so for non-direct Sessions. Devise a better
1865 // way to get its value when the need arises.
1866 TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
1867 debug_options.global_step(), rcg_execution_count, rcg_execution_count,
1868 input_names, output_names, target_names));
1869
1870 return Status::OK();
1871 }
1872
1873 void MasterSession::FillPerStepState(MasterSession::ReffedClientGraph* rcg,
1874 const RunOptions& run_options,
1875 uint64 step_id, int64 count,
1876 PerStepState* out_pss,
1877 std::unique_ptr<ProfileHandler>* out_ph) {
1878 out_pss->collect_timeline =
1879 run_options.trace_level() == RunOptions::FULL_TRACE;
1880 out_pss->collect_rpcs = run_options.trace_level() == RunOptions::FULL_TRACE;
1881 out_pss->report_tensor_allocations_upon_oom =
1882 run_options.report_tensor_allocations_upon_oom();
1883 // Build the cost model every 'build_cost_model_every' steps after skipping an
1884 // initial 'build_cost_model_after' steps.
1885 const int64 build_cost_model_after =
1886 session_opts_.config.graph_options().build_cost_model_after();
1887 const int64 build_cost_model_every =
1888 session_opts_.config.graph_options().build_cost_model();
1889 out_pss->collect_costs =
1890 build_cost_model_every > 0 &&
1891 ((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
1892 out_pss->collect_partition_graphs = run_options.output_partition_graphs();
1893
1894 *out_ph = rcg->GetProfileHandler(step_id, count, run_options);
1895 if (*out_ph) {
1896 out_pss->collect_timeline = true;
1897 out_pss->collect_rpcs = (*out_ph)->should_collect_rpcs();
1898 }
1899 }
1900
1901 Status MasterSession::PostRunCleanup(MasterSession::ReffedClientGraph* rcg,
1902 uint64 step_id,
1903 const RunOptions& run_options,
1904 PerStepState* pss,
1905 const std::unique_ptr<ProfileHandler>& ph,
1906 const Status& run_status,
1907 RunMetadata* out_run_metadata) {
1908 Status s = run_status;
1909 if (s.ok()) {
1910 pss->end_micros = Env::Default()->NowMicros();
1911 if (rcg->collective_graph_key() !=
1912 BuildGraphOptions::kNoCollectiveGraphKey) {
1913 env_->collective_executor_mgr->RetireStepId(rcg->collective_graph_key(),
1914 step_id);
1915 }
1916 // Schedule post-processing and cleanup to be done asynchronously.
1917 rcg->ProcessStats(step_id, pss, ph.get(), run_options, out_run_metadata);
1918 } else if (errors::IsCancelled(s)) {
1919 mutex_lock l(mu_);
1920 if (closed_) {
1921 if (garbage_collected_) {
1922 s = errors::Cancelled(
1923 "Step was cancelled because the session was garbage collected due "
1924 "to inactivity.");
1925 } else {
1926 s = errors::Cancelled(
1927 "Step was cancelled by an explicit call to `Session::Close()`.");
1928 }
1929 }
1930 }
1931 Ref();
1932 rcg->Ref();
1933 rcg->CleanupPartitionsAsync(step_id, [this, rcg](const Status& s) {
1934 if (!s.ok()) {
1935 LOG(ERROR) << "Cleanup partition error: " << s;
1936 }
1937 rcg->Unref();
1938 MarkRunCompletion();
1939 Unref();
1940 });
1941 return s;
1942 }
1943
1944 Status MasterSession::DoRunWithLocalExecution(
1945 CallOptions* opts, const RunStepRequestWrapper& req,
1946 MutableRunStepResponseWrapper* resp) {
1947 VLOG(2) << "DoRunWithLocalExecution req: " << req.DebugString();
1948 PerStepState pss;
1949 pss.start_micros = Env::Default()->NowMicros();
1950 auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
1951
1952 // Prepare.
1953 BuildGraphOptions bgopts;
1954 BuildBuildGraphOptions(req, session_opts_.config, &bgopts);
1955 ReffedClientGraph* rcg = nullptr;
1956 int64 count;
1957 TF_RETURN_IF_ERROR(StartStep(bgopts, false, &rcg, &count));
1958
1959 // Unref "rcg" when out of scope.
1960 core::ScopedUnref unref(rcg);
1961
1962 std::unique_ptr<DebuggerStateInterface> debugger_state;
1963 const DebugOptions& debug_options = req.options().debug_options();
1964
1965 if (!debug_options.debug_tensor_watch_opts().empty()) {
1966 TF_RETURN_IF_ERROR(
1967 CreateDebuggerState(debug_options, req, count, &debugger_state));
1968 }
1969 TF_RETURN_IF_ERROR(BuildAndRegisterPartitions(rcg));
1970
1971   // NewStepId() reserves some of the high-order bits of the step_id for
1972   // future use.
1973 uint64 step_id = NewStepId(rcg->collective_graph_key());
1974 TRACEPRINTF("stepid %llu", step_id);
1975
1976 std::unique_ptr<ProfileHandler> ph;
1977 FillPerStepState(rcg, req.options(), step_id, count, &pss, &ph);
1978
1979 if (pss.collect_partition_graphs &&
1980 session_opts_.config.experimental().disable_output_partition_graphs()) {
1981 return errors::InvalidArgument(
1982 "RunOptions.output_partition_graphs() is not supported when "
1983 "disable_output_partition_graphs is true.");
1984 }
1985
1986 Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp,
1987 &cancellation_manager_, false);
1988
1989 cleanup.release(); // MarkRunCompletion called in PostRunCleanup().
1990 return PostRunCleanup(rcg, step_id, req.options(), &pss, ph, s,
1991 resp->mutable_metadata());
1992 }
1993
1994 Status MasterSession::MakeCallable(const MakeCallableRequest& req,
1995 MakeCallableResponse* resp) {
1996 UpdateLastAccessTime();
1997
1998 BuildGraphOptions opts;
1999 opts.callable_options = req.options();
2000 opts.use_function_convention = false;
2001
2002 ReffedClientGraph* callable;
2003
2004 {
2005 mutex_lock l(mu_);
2006 if (closed_) {
2007 return errors::FailedPrecondition("Session is closed.");
2008 }
2009 std::unique_ptr<ClientGraph> client_graph;
2010 TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
2011 callable = new ReffedClientGraph(handle_, opts, std::move(client_graph),
2012 session_opts_, stats_publisher_factory_,
2013 false /* is_partial */, get_worker_cache(),
2014 !should_delete_worker_sessions_);
2015 }
2016
2017 Status s = BuildAndRegisterPartitions(callable);
2018 if (!s.ok()) {
2019 callable->Unref();
2020 return s;
2021 }
2022
2023 uint64 handle;
2024 {
2025 mutex_lock l(mu_);
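    // Callable handles increase monotonically and are never reused, so
    // RunCallable() can tell a released handle apart from one that never
    // existed.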
2026 handle = next_callable_handle_++;
2027 callables_[handle] = callable;
2028 }
2029
2030 resp->set_handle(handle);
2031 return Status::OK();
2032 }
2033
2034 Status MasterSession::DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg,
2035 const RunCallableRequest& req,
2036 RunCallableResponse* resp) {
2037 VLOG(2) << "DoRunCallable req: " << req.DebugString();
2038 PerStepState pss;
2039 pss.start_micros = Env::Default()->NowMicros();
2040 auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
2041
2042 // Prepare.
2043 int64 count = rcg->get_and_increment_execution_count();
2044
2045 const uint64 step_id = NewStepId(rcg->collective_graph_key());
2046 TRACEPRINTF("stepid %llu", step_id);
2047
2048 const RunOptions& run_options = rcg->callable_options().run_options();
2049
2050 if (run_options.timeout_in_ms() != 0) {
2051 opts->SetTimeout(run_options.timeout_in_ms());
2052 }
2053
2054 std::unique_ptr<ProfileHandler> ph;
2055 FillPerStepState(rcg, run_options, step_id, count, &pss, &ph);
2056 Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp,
2057 &cancellation_manager_);
2058 cleanup.release(); // MarkRunCompletion called in PostRunCleanup().
2059 return PostRunCleanup(rcg, step_id, run_options, &pss, ph, s,
2060 resp->mutable_metadata());
2061 }
2062
2063 Status MasterSession::RunCallable(CallOptions* opts,
2064 const RunCallableRequest& req,
2065 RunCallableResponse* resp) {
2066 UpdateLastAccessTime();
2067 ReffedClientGraph* callable;
2068 {
2069 mutex_lock l(mu_);
2070 if (closed_) {
2071 return errors::FailedPrecondition("Session is closed.");
2072 }
2073 int64 handle = req.handle();
2074 if (handle >= next_callable_handle_) {
2075 return errors::InvalidArgument("No such callable handle: ", handle);
2076 }
2077 auto iter = callables_.find(req.handle());
2078 if (iter == callables_.end()) {
2079 return errors::InvalidArgument(
2080 "Attempted to run callable after handle was released: ", handle);
2081 }
2082 callable = iter->second;
2083 callable->Ref();
2084 ++num_running_;
2085 }
2086 core::ScopedUnref unref_callable(callable);
2087 return DoRunCallable(opts, callable, req, resp);
2088 }
2089
2090 Status MasterSession::ReleaseCallable(const ReleaseCallableRequest& req,
2091 ReleaseCallableResponse* resp) {
2092 UpdateLastAccessTime();
2093 ReffedClientGraph* to_unref = nullptr;
2094 {
2095 mutex_lock l(mu_);
2096 auto iter = callables_.find(req.handle());
2097 if (iter != callables_.end()) {
2098 to_unref = iter->second;
2099 callables_.erase(iter);
2100 }
2101 }
2102 if (to_unref != nullptr) {
2103 to_unref->Unref();
2104 }
2105 return Status::OK();
2106 }
2107
2108 Status MasterSession::Close() {
2109 {
2110 mutex_lock l(mu_);
2111 closed_ = true; // All subsequent calls to Run() or Extend() will fail.
2112 }
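  // Cancel any in-flight steps so that the wait on num_running_ below
  // completes promptly.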
2113 cancellation_manager_.StartCancel();
2114 std::vector<ReffedClientGraph*> to_unref;
2115 {
2116 mutex_lock l(mu_);
2117 while (num_running_ != 0) {
2118 num_running_is_zero_.wait(l);
2119 }
2120 ClearRunsTable(&to_unref, &run_graphs_);
2121 ClearRunsTable(&to_unref, &partial_run_graphs_);
2122 ClearRunsTable(&to_unref, &callables_);
2123 }
2124 for (ReffedClientGraph* rcg : to_unref) rcg->Unref();
2125 if (should_delete_worker_sessions_) {
2126 Status s = DeleteWorkerSessions();
2127 if (!s.ok()) {
2128 LOG(WARNING) << s;
2129 }
2130 }
2131 return Status::OK();
2132 }
2133
2134 void MasterSession::GarbageCollect() {
2135 {
2136 mutex_lock l(mu_);
2137 closed_ = true;
2138 garbage_collected_ = true;
2139 }
2140 cancellation_manager_.StartCancel();
2141 Unref();
2142 }
2143
2144 MasterSession::RunState::RunState(const std::vector<string>& input_names,
2145 const std::vector<string>& output_names,
2146 ReffedClientGraph* rcg, const uint64 step_id,
2147 const int64 count)
2148 : rcg(rcg), step_id(step_id), count(count) {
2149 // Initially all the feeds and fetches are pending.
2150 for (auto& name : input_names) {
2151 pending_inputs[name] = false;
2152 }
2153 for (auto& name : output_names) {
2154 pending_outputs[name] = false;
2155 }
2156 }
2157
2158 MasterSession::RunState::~RunState() {
2159 if (rcg) rcg->Unref();
2160 }
2161
2162 bool MasterSession::RunState::PendingDone() const {
2163 for (const auto& it : pending_inputs) {
2164 if (!it.second) return false;
2165 }
2166 for (const auto& it : pending_outputs) {
2167 if (!it.second) return false;
2168 }
2169 return true;
2170 }
2171
2172 } // end namespace tensorflow
2173