1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/master_session.h"
17
18 #include <memory>
19 #include <unordered_map>
20 #include <unordered_set>
21 #include <vector>
22
23 #include "tensorflow/core/common_runtime/process_util.h"
24 #include "tensorflow/core/common_runtime/profile_handler.h"
25 #include "tensorflow/core/common_runtime/stats_publisher_interface.h"
26 #include "tensorflow/core/debug/debug_graph_utils.h"
27 #include "tensorflow/core/distributed_runtime/request_id.h"
28 #include "tensorflow/core/distributed_runtime/scheduler.h"
29 #include "tensorflow/core/distributed_runtime/worker_cache.h"
30 #include "tensorflow/core/distributed_runtime/worker_interface.h"
31 #include "tensorflow/core/framework/allocation_description.pb.h"
32 #include "tensorflow/core/framework/collective.h"
33 #include "tensorflow/core/framework/cost_graph.pb.h"
34 #include "tensorflow/core/framework/graph_def_util.h"
35 #include "tensorflow/core/framework/node_def.pb.h"
36 #include "tensorflow/core/framework/node_def_util.h"
37 #include "tensorflow/core/framework/tensor.h"
38 #include "tensorflow/core/framework/tensor.pb.h"
39 #include "tensorflow/core/framework/tensor_description.pb.h"
40 #include "tensorflow/core/graph/graph_partition.h"
41 #include "tensorflow/core/graph/tensor_id.h"
42 #include "tensorflow/core/lib/core/blocking_counter.h"
43 #include "tensorflow/core/lib/core/notification.h"
44 #include "tensorflow/core/lib/core/refcount.h"
45 #include "tensorflow/core/lib/core/status.h"
46 #include "tensorflow/core/lib/gtl/cleanup.h"
47 #include "tensorflow/core/lib/gtl/inlined_vector.h"
48 #include "tensorflow/core/lib/gtl/map_util.h"
49 #include "tensorflow/core/lib/random/random.h"
50 #include "tensorflow/core/lib/strings/numbers.h"
51 #include "tensorflow/core/lib/strings/str_util.h"
52 #include "tensorflow/core/lib/strings/strcat.h"
53 #include "tensorflow/core/lib/strings/stringprintf.h"
54 #include "tensorflow/core/platform/env.h"
55 #include "tensorflow/core/platform/logging.h"
56 #include "tensorflow/core/platform/macros.h"
57 #include "tensorflow/core/platform/mutex.h"
58 #include "tensorflow/core/platform/tracing.h"
59 #include "tensorflow/core/public/session_options.h"
60 #include "tensorflow/core/util/device_name_utils.h"
61
62 namespace tensorflow {
63
64 // MasterSession wraps ClientGraph in a reference counted object.
65 // This way, MasterSession can clear up the cache mapping Run requests to
66 // compiled graphs while the compiled graph is still being used.
67 //
68 // TODO(zhifengc): Cleanup this class. It's becoming messy.
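//
// A minimal sketch of the ref-counting this enables (hypothetical call
// sites, for orientation only):
//
//   rcg->Ref();    // an in-flight RunStep takes its own reference
//   rcg->Unref();  // the cache entry is dropped; rcg survives via the
//                  // RunStep's reference
//   rcg->Unref();  // the RunStep finishes; rcg is now deleted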
69 class MasterSession::ReffedClientGraph : public core::RefCounted {
70 public:
71   ReffedClientGraph(const string& handle, const BuildGraphOptions& bopts,
72 std::unique_ptr<ClientGraph> client_graph,
73 const SessionOptions& session_opts,
74 const StatsPublisherFactory& stats_publisher_factory,
75 bool is_partial, WorkerCacheInterface* worker_cache,
76 bool should_deregister)
77 : session_handle_(handle),
78 bg_opts_(bopts),
79 client_graph_before_register_(std::move(client_graph)),
80 session_opts_(session_opts),
81 is_partial_(is_partial),
82 callable_opts_(bopts.callable_options),
83 worker_cache_(worker_cache),
84 should_deregister_(should_deregister),
85 collective_graph_key_(
86 client_graph_before_register_->collective_graph_key) {
87     VLOG(1) << "Created ReffedClientGraph with "
88             << client_graph_before_register_->graph.num_node_ids() << " node ids";
89
90 stats_publisher_ = stats_publisher_factory(handle, bopts, session_opts);
91
92 // Initialize a name to node map for processing device stats.
93 for (Node* n : client_graph_before_register_->graph.nodes()) {
94 name_to_node_details_.emplace(
95 n->name(),
96 NodeDetails(n->type_string(),
97 strings::StrCat(
98 "(", absl::StrJoin(n->requested_inputs(), ", "))));
99 }
100 }
101
102   ~ReffedClientGraph() override {
103 if (should_deregister_) {
104 DeregisterPartitions();
105 } else {
106 for (Part& part : partitions_) {
107 worker_cache_->ReleaseWorker(part.name, part.worker);
108 }
109 }
110 }
111
112   const CallableOptions& callable_options() { return callable_opts_; }
113
114   const BuildGraphOptions& build_graph_options() { return bg_opts_; }
115
116   int64 collective_graph_key() { return collective_graph_key_; }
117
118   std::unique_ptr<ProfileHandler> GetProfileHandler(uint64 step,
119 int64_t execution_count,
120 const RunOptions& ropts) {
121 return stats_publisher_->GetProfileHandler(step, execution_count, ropts);
122 }
123
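  // Returns the current execution count, then atomically increments it.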
124   int64 get_and_increment_execution_count() {
125 return execution_count_.fetch_add(1);
126 }
127
128 // Turn RPC logging on or off, both at the WorkerCache used by this
129 // master process, and at each remote worker in use for the current
130 // partitions.
131   void SetRPCLogging(bool active) {
132 worker_cache_->SetLogging(active);
133 // Logging is a best-effort activity, so we make async calls to turn
134 // it on/off and don't make use of the responses.
135 for (auto& p : partitions_) {
136 LoggingRequest* req = new LoggingRequest;
137 if (active) {
138 req->set_enable_rpc_logging(true);
139 } else {
140 req->set_disable_rpc_logging(true);
141 }
142 LoggingResponse* resp = new LoggingResponse;
143 Ref();
144 p.worker->LoggingAsync(req, resp, [this, req, resp](const Status& s) {
145 delete req;
146 delete resp;
147 // ReffedClientGraph owns p.worker so we need to hold a ref to
148 // ensure that the method doesn't attempt to access p.worker after
149         // ReffedClientGraph has deleted it.
150 // TODO(suharshs): Simplify this ownership model.
151 Unref();
152 });
153 }
154 }
155
156 // Retrieve all RPC logs data accumulated for the current step, both
157 // from the local WorkerCache in use by this master process and from
158 // all the remote workers executing the remote partitions.
159   void RetrieveLogs(int64_t step_id, StepStats* ss) {
160 // Get the local data first, because it sets *ss without merging.
161 worker_cache_->RetrieveLogs(step_id, ss);
162
163 // Then merge in data from all the remote workers.
164 LoggingRequest req;
165 req.add_fetch_step_id(step_id);
166 int waiting_for = partitions_.size();
167 if (waiting_for > 0) {
168 mutex scoped_mu;
169 BlockingCounter all_done(waiting_for);
170 for (auto& p : partitions_) {
171 LoggingResponse* resp = new LoggingResponse;
172 p.worker->LoggingAsync(
173 &req, resp,
174 [step_id, ss, resp, &scoped_mu, &all_done](const Status& s) {
175 {
176 mutex_lock l(scoped_mu);
177 if (s.ok()) {
178 for (auto& lss : resp->step()) {
179 if (step_id != lss.step_id()) {
180 LOG(ERROR) << "Wrong step_id in LoggingResponse";
181 continue;
182 }
183 ss->MergeFrom(lss.step_stats());
184 }
185 }
186 delete resp;
187 }
188 // Must not decrement all_done until out of critical section where
189 // *ss is updated.
190 all_done.DecrementCount();
191 });
192 }
193 all_done.Wait();
194 }
195 }
196
197 // Local execution methods.
198
199 // Partitions the graph into subgraphs and registers them on
200 // workers.
201 Status RegisterPartitions(PartitionOptions popts);
202
203 // Runs one step of all partitions.
204 Status RunPartitions(const MasterEnv* env, int64_t step_id,
205 int64_t execution_count, PerStepState* pss,
206 CallOptions* opts, const RunStepRequestWrapper& req,
207 MutableRunStepResponseWrapper* resp,
208 CancellationManager* cm, const bool is_last_partial_run);
209 Status RunPartitions(const MasterEnv* env, int64_t step_id,
210 int64_t execution_count, PerStepState* pss,
211 CallOptions* call_opts, const RunCallableRequest& req,
212 RunCallableResponse* resp, CancellationManager* cm);
213
214 // Calls workers to cleanup states for the step "step_id". Calls
215 // `done` when all cleanup RPCs have completed.
216 void CleanupPartitionsAsync(int64_t step_id, StatusCallback done);
217
218 // Post-processing of any runtime statistics gathered during execution.
219 void ProcessStats(int64_t step_id, PerStepState* pss, ProfileHandler* ph,
220 const RunOptions& options, RunMetadata* resp);
221 void ProcessDeviceStats(ProfileHandler* ph, const DeviceStepStats& ds,
222 bool is_rpc);
223 // Checks that the requested fetches can be computed from the provided feeds.
224 Status CheckFetches(const RunStepRequestWrapper& req,
225 const RunState* run_state,
226 GraphExecutionState* execution_state);
227
228 private:
229 const string session_handle_;
230 const BuildGraphOptions bg_opts_;
231
232 // NOTE(mrry): This pointer will be null after `RegisterPartitions()` returns.
233 std::unique_ptr<ClientGraph> client_graph_before_register_ TF_GUARDED_BY(mu_);
234 const SessionOptions session_opts_;
235 const bool is_partial_;
236 const CallableOptions callable_opts_;
237 WorkerCacheInterface* const worker_cache_; // Not owned.
238
239 struct NodeDetails {
240     explicit NodeDetails(string type_string, string detail_text)
241 : type_string(std::move(type_string)),
242 detail_text(std::move(detail_text)) {}
243 const string type_string;
244 const string detail_text;
245 };
246 std::unordered_map<string, NodeDetails> name_to_node_details_;
247
248 const bool should_deregister_;
249 const int64 collective_graph_key_;
250 std::atomic<int64> execution_count_ = {0};
251
252 // Graph partitioned into per-location subgraphs.
253 struct Part {
254 // Worker name.
255 string name;
256
257 // Maps feed names to rendezvous keys. Empty most of the time.
258 std::unordered_map<string, string> feed_key;
259
260 // Maps rendezvous keys to fetch names. Empty most of the time.
261 std::unordered_map<string, string> key_fetch;
262
263 // The interface to the worker. Owned.
264 WorkerInterface* worker = nullptr;
265
266 // After registration with the worker, graph_handle identifies
267 // this partition on the worker.
268 string graph_handle;
269
270     Part() : feed_key(3), key_fetch(3) {}
271 };
272
273 // partitions_ is immutable after RegisterPartitions() call
274 // finishes. RunPartitions() can access partitions_ safely without
275 // acquiring locks.
276 std::vector<Part> partitions_;
277
278 mutable mutex mu_;
279
280 // Partition initialization and registration only needs to happen
281 // once. `!client_graph_before_register_ && !init_done_.HasBeenNotified()`
282 // indicates the initialization is ongoing.
283 Notification init_done_;
284
285 // init_result_ remembers the initialization error if any.
286 Status init_result_ TF_GUARDED_BY(mu_);
287
288 std::unique_ptr<StatsPublisherInterface> stats_publisher_;
289
290   string DetailText(const NodeDetails& details, const NodeExecStats& stats) {
291 int64_t tot = 0;
292 for (auto& no : stats.output()) {
293 tot += no.tensor_description().allocation_description().requested_bytes();
294 }
295 string bytes;
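    // Annotate with the total output size only when it is at least ~0.1 MiB.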
296 if (tot >= 0.1 * 1048576.0) {
297 bytes = strings::Printf("[%.1fMB] ", tot / 1048576.0);
298 }
299 return strings::StrCat(bytes, stats.node_name(), " = ", details.type_string,
300 details.detail_text);
301 }
302
303 // Send/Recv nodes that are the result of client-added
304 // feeds and fetches must be tracked so that the tensors
305 // can be added to the local rendezvous.
306 static void TrackFeedsAndFetches(Part* part, const GraphDef& graph_def,
307 const PartitionOptions& popts);
308
309 // The actual graph partitioning and registration implementation.
310 Status DoBuildPartitions(
311 PartitionOptions popts, ClientGraph* client_graph,
312 std::unordered_map<string, GraphDef>* out_partitions);
313 Status DoRegisterPartitions(
314 const PartitionOptions& popts,
315 std::unordered_map<string, GraphDef> graph_partitions);
316
317 // Prepares a number of calls to workers. One call per partition.
318 // This is a generic method that handles Run, PartialRun, and RunCallable.
319 template <class FetchListType, class ClientRequestType,
320 class ClientResponseType>
321 Status RunPartitionsHelper(
322 const std::unordered_map<StringPiece, size_t, StringPieceHasher>& feeds,
323 const FetchListType& fetches, const MasterEnv* env, int64_t step_id,
324 int64_t execution_count, PerStepState* pss, CallOptions* call_opts,
325 const ClientRequestType& req, ClientResponseType* resp,
326 CancellationManager* cm, bool is_last_partial_run);
327
328 // Deregisters the partitions on the workers. Called in the
329 // destructor and does not wait for the rpc completion.
330 void DeregisterPartitions();
331
332 TF_DISALLOW_COPY_AND_ASSIGN(ReffedClientGraph);
333 };
334
335 Status MasterSession::ReffedClientGraph::RegisterPartitions(
336 PartitionOptions popts) {
337 { // Ensure register once.
338 mu_.lock();
339 if (client_graph_before_register_) {
340 // The `ClientGraph` is no longer needed after partitions are registered.
341 // Since it can account for a large amount of memory, we consume it here,
342 // and it will be freed after concluding with registration.
343
344 std::unique_ptr<ClientGraph> client_graph;
345 std::swap(client_graph_before_register_, client_graph);
346 mu_.unlock();
347 std::unordered_map<string, GraphDef> graph_defs;
348 popts.flib_def = client_graph->flib_def.get();
349 Status s = DoBuildPartitions(popts, client_graph.get(), &graph_defs);
350 if (s.ok()) {
351 // NOTE(mrry): The pointers in `graph_defs_for_publishing` do not remain
352 // valid after the call to DoRegisterPartitions begins, so
353 // `stats_publisher_` must make a copy if it wants to retain the
354 // GraphDef objects.
355 std::vector<const GraphDef*> graph_defs_for_publishing;
356 graph_defs_for_publishing.reserve(partitions_.size());
357 for (const auto& name_def : graph_defs) {
358 graph_defs_for_publishing.push_back(&name_def.second);
359 }
360 stats_publisher_->PublishGraphProto(graph_defs_for_publishing);
361 s = DoRegisterPartitions(popts, std::move(graph_defs));
362 }
363 mu_.lock();
364 init_result_ = s;
365 init_done_.Notify();
366 } else {
367 mu_.unlock();
368 init_done_.WaitForNotification();
369 mu_.lock();
370 }
371 const Status result = init_result_;
372 mu_.unlock();
373 return result;
374 }
375 }
376
377 static string SplitByWorker(const Node* node) {
378 string task;
379 string device;
380 CHECK(DeviceNameUtils::SplitDeviceName(node->assigned_device_name(), &task,
381 &device))
382 << "node: " << node->name() << " dev: " << node->assigned_device_name();
383 return task;
384 }
385
386 void MasterSession::ReffedClientGraph::TrackFeedsAndFetches(
387 Part* part, const GraphDef& graph_def, const PartitionOptions& popts) {
388 for (int i = 0; i < graph_def.node_size(); ++i) {
389 const NodeDef& ndef = graph_def.node(i);
390 const bool is_recv = ndef.op() == "_Recv";
391 const bool is_send = ndef.op() == "_Send";
392
393 if (is_recv || is_send) {
394 // Only send/recv nodes that were added as feeds and fetches
395 // (client-terminated) should be tracked. Other send/recv nodes
396 // are for transferring data between partitions / memory spaces.
397 bool client_terminated;
398 TF_CHECK_OK(GetNodeAttr(ndef, "client_terminated", &client_terminated));
399 if (client_terminated) {
400 string name;
401 TF_CHECK_OK(GetNodeAttr(ndef, "tensor_name", &name));
402 string send_device;
403 TF_CHECK_OK(GetNodeAttr(ndef, "send_device", &send_device));
404 string recv_device;
405 TF_CHECK_OK(GetNodeAttr(ndef, "recv_device", &recv_device));
406 uint64 send_device_incarnation;
407 TF_CHECK_OK(
408 GetNodeAttr(ndef, "send_device_incarnation",
409 reinterpret_cast<int64*>(&send_device_incarnation)));
410 const string& key =
411 Rendezvous::CreateKey(send_device, send_device_incarnation,
412 recv_device, name, FrameAndIter(0, 0));
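        // The resulting rendezvous key is roughly of the form (illustrative
        // only): "<send_device>;<incarnation>;<recv_device>;<tensor_name>;0:0"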
413
414 if (is_recv) {
415 part->feed_key.insert({name, key});
416 } else {
417 part->key_fetch.insert({key, name});
418 }
419 }
420 }
421 }
422 }
423
424 Status MasterSession::ReffedClientGraph::DoBuildPartitions(
425 PartitionOptions popts, ClientGraph* client_graph,
426 std::unordered_map<string, GraphDef>* out_partitions) {
427 if (popts.need_to_record_start_times) {
428 CostModel cost_model(true);
429 cost_model.InitFromGraph(client_graph->graph);
430 // TODO(yuanbyu): Use the real cost model.
431 // execution_state_->MergeFromGlobal(&cost_model);
432 SlackAnalysis sa(&client_graph->graph, &cost_model);
433 sa.ComputeAsap(&popts.start_times);
434 }
435
436 // Partition the graph.
437 return Partition(popts, &client_graph->graph, out_partitions);
438 }
439
440 Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
441 const PartitionOptions& popts,
442 std::unordered_map<string, GraphDef> graph_partitions) {
443 partitions_.reserve(graph_partitions.size());
444 Status s;
445 for (auto& name_def : graph_partitions) {
446 partitions_.emplace_back();
447 Part* part = &partitions_.back();
448 part->name = name_def.first;
449 TrackFeedsAndFetches(part, name_def.second, popts);
450 part->worker = worker_cache_->GetOrCreateWorker(part->name);
451 if (part->worker == nullptr) {
452 s = errors::NotFound("worker ", part->name);
453 break;
454 }
455 }
456 if (!s.ok()) {
457 for (Part& part : partitions_) {
458 worker_cache_->ReleaseWorker(part.name, part.worker);
459 part.worker = nullptr;
460 }
461 return s;
462 }
463 struct Call {
464 RegisterGraphRequest req;
465 RegisterGraphResponse resp;
466 Status status;
467 };
468 const int num = partitions_.size();
469 gtl::InlinedVector<Call, 4> calls(num);
470 BlockingCounter done(num);
471 for (int i = 0; i < num; ++i) {
472 const Part& part = partitions_[i];
473 Call* c = &calls[i];
474 c->req.set_session_handle(session_handle_);
475 c->req.set_create_worker_session_called(!should_deregister_);
476 c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]);
477 StripDefaultAttributes(*OpRegistry::Global(),
478 c->req.mutable_graph_def()->mutable_node());
479 *c->req.mutable_config_proto() = session_opts_.config;
480 *c->req.mutable_graph_options() = session_opts_.config.graph_options();
481 *c->req.mutable_debug_options() =
482 callable_opts_.run_options().debug_options();
483 c->req.set_collective_graph_key(collective_graph_key_);
484 VLOG(2) << "Register " << c->req.graph_def().DebugString();
485 auto cb = [c, &done](const Status& s) {
486 c->status = s;
487 done.DecrementCount();
488 };
489 part.worker->RegisterGraphAsync(&c->req, &c->resp, cb);
490 }
491 done.Wait();
492 for (int i = 0; i < num; ++i) {
493 Call* c = &calls[i];
494 s.Update(c->status);
495 partitions_[i].graph_handle = c->resp.graph_handle();
496 }
497 return s;
498 }
499
500 namespace {
501 // Helper class to manage "num" parallel RunGraph calls.
502 class RunManyGraphs {
503 public:
504   explicit RunManyGraphs(int num) : calls_(num), pending_(num) {}
505
506   ~RunManyGraphs() {}
507
508 // Returns the index-th call.
509 struct Call {
510 CallOptions opts;
511 const string* worker_name;
512 std::atomic<bool> done{false};
513 std::unique_ptr<MutableRunGraphRequestWrapper> req;
514 std::unique_ptr<MutableRunGraphResponseWrapper> resp;
515 };
516   Call* get(int index) { return &calls_[index]; }
517
518 // When the index-th call is done, updates the overall status.
519   void WhenDone(int index, const Status& s) {
520 TRACEPRINTF("Partition %d %s", index, s.ToString().c_str());
521 Call* call = get(index);
522 call->done = true;
523 auto resp = call->resp.get();
524 if (resp->status_code() != error::Code::OK) {
525       // A non-OK status_code() can only appear when s.ok(): workers return
526       // their errors in the response body.
526 mutex_lock l(mu_);
527 ReportBadStatus(Status(resp->status_code(),
528 strings::StrCat("From ", *call->worker_name, ":\n",
529 resp->status_error_message())));
530 } else if (!s.ok()) {
531 mutex_lock l(mu_);
532 ReportBadStatus(
533 Status(s.code(), strings::StrCat("From ", *call->worker_name, ":\n",
534 s.error_message())));
535 }
536 pending_.DecrementCount();
537 }
538
539   void StartCancel() {
540 mutex_lock l(mu_);
541 ReportBadStatus(errors::Cancelled("RunManyGraphs"));
542 }
543
544   void Wait() {
545     // Check the error status every 60 seconds in order to print a log message
546 // in the event of a hang.
547 const std::chrono::milliseconds kCheckErrorPeriod(1000 * 60);
548 while (true) {
549 if (pending_.WaitFor(kCheckErrorPeriod)) {
550 return;
551 }
552 if (!status().ok()) {
553 break;
554 }
555 }
556
557 // The step has failed. Wait for another 60 seconds before diagnosing a
558 // hang.
559 DCHECK(!status().ok());
560 if (pending_.WaitFor(kCheckErrorPeriod)) {
561 return;
562 }
563 LOG(ERROR)
564 << "RunStep still blocked after 60 seconds. Failed with error status: "
565 << status();
566 for (const Call& call : calls_) {
567 if (!call.done) {
568 LOG(ERROR) << "- No response from RunGraph call to worker: "
569 << *call.worker_name;
570 }
571 }
572 pending_.Wait();
573 }
574
575   Status status() const {
576 mutex_lock l(mu_);
577     // Concatenate the status objects in status_group_ to produce the
578     // aggregated status, as each status in the group is already summarized.
579 return status_group_.as_concatenated_status();
580 }
581
582 private:
583 gtl::InlinedVector<Call, 4> calls_;
584
585 BlockingCounter pending_;
586 mutable mutex mu_;
587 StatusGroup status_group_ TF_GUARDED_BY(mu_);
588 bool cancel_issued_ TF_GUARDED_BY(mu_) = false;
589
590   void ReportBadStatus(const Status& s) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
591 VLOG(1) << "Master received error status " << s;
592 if (!cancel_issued_ && !StatusGroup::IsDerived(s)) {
593 // Only start cancelling other workers upon receiving a non-derived
594 // error
595 cancel_issued_ = true;
596
597 VLOG(1) << "Master received error report. Cancelling remaining workers.";
598 for (Call& call : calls_) {
599 call.opts.StartCancel();
600 }
601 }
602
603 status_group_.Update(s);
604 }
605
606 TF_DISALLOW_COPY_AND_ASSIGN(RunManyGraphs);
607 };
608
609 Status AddSendFromClientRequest(const RunStepRequestWrapper& client_req,
610 MutableRunGraphRequestWrapper* worker_req,
611 size_t index, const string& send_key) {
612 return worker_req->AddSendFromRunStepRequest(client_req, index, send_key);
613 }
614
615 Status AddSendFromClientRequest(const RunCallableRequest& client_req,
616 MutableRunGraphRequestWrapper* worker_req,
617 size_t index, const string& send_key) {
618 return worker_req->AddSendFromRunCallableRequest(client_req, index, send_key);
619 }
620
621 // TODO(mrry): Add a full-fledged wrapper that avoids TensorProto copies for
622 // in-process messages.
623 struct RunCallableResponseWrapper {
624 RunCallableResponse* resp; // Not owned.
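  // Fetched tensors keyed by fetch name; RunPartitions() later moves these
  // into the RunCallableResponse in the order given by the CallableOptions.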
625 std::unordered_map<string, TensorProto> fetch_key_to_protos;
626
627   RunMetadata* mutable_metadata() { return resp->mutable_metadata(); }
628
629   Status AddTensorFromRunGraphResponse(
630 const string& tensor_name, MutableRunGraphResponseWrapper* worker_resp,
631 size_t index) {
632 return worker_resp->RecvValue(index, &fetch_key_to_protos[tensor_name]);
633 }
634 };
635 } // namespace
636
637 template <class FetchListType, class ClientRequestType,
638 class ClientResponseType>
639 Status MasterSession::ReffedClientGraph::RunPartitionsHelper(
640 const std::unordered_map<StringPiece, size_t, StringPieceHasher>& feeds,
641 const FetchListType& fetches, const MasterEnv* env, int64_t step_id,
642 int64_t execution_count, PerStepState* pss, CallOptions* call_opts,
643 const ClientRequestType& req, ClientResponseType* resp,
644 CancellationManager* cm, bool is_last_partial_run) {
645 // Collect execution cost stats on a smoothly decreasing frequency.
646 ExecutorOpts exec_opts;
647 if (pss->report_tensor_allocations_upon_oom) {
648 exec_opts.set_report_tensor_allocations_upon_oom(true);
649 }
650 if (pss->collect_costs) {
651 exec_opts.set_record_costs(true);
652 }
653 if (pss->collect_timeline) {
654 exec_opts.set_record_timeline(true);
655 }
656 if (pss->collect_rpcs) {
657 SetRPCLogging(true);
658 }
659 if (pss->collect_partition_graphs) {
660 exec_opts.set_record_partition_graphs(true);
661 }
662 if (pss->collect_costs || pss->collect_timeline) {
663 pss->step_stats.resize(partitions_.size());
664 }
665
666 const int num = partitions_.size();
667 RunManyGraphs calls(num);
668
669 for (int i = 0; i < num; ++i) {
670 const Part& part = partitions_[i];
671 RunManyGraphs::Call* c = calls.get(i);
672 c->worker_name = &part.name;
673 c->req.reset(part.worker->CreateRunGraphRequest());
674 c->resp.reset(part.worker->CreateRunGraphResponse());
675 if (is_partial_) {
676 c->req->set_is_partial(is_partial_);
677 c->req->set_is_last_partial_run(is_last_partial_run);
678 }
679 c->req->set_session_handle(session_handle_);
680 c->req->set_create_worker_session_called(!should_deregister_);
681 c->req->set_graph_handle(part.graph_handle);
682 c->req->set_step_id(step_id);
683 *c->req->mutable_exec_opts() = exec_opts;
684 c->req->set_store_errors_in_response_body(true);
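    // A unique request id lets the worker discard duplicate deliveries of
    // this RunGraph call (e.g. after an RPC retry).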
685 c->req->set_request_id(GetUniqueRequestId());
686 // If any feeds are provided, send the feed values together
687 // in the RunGraph request.
688 // In the partial case, we only want to include feeds provided in the req.
689 // In the non-partial case, all feeds in the request are in the part.
690 // We keep these as separate paths for now, to ensure we aren't
691 // inadvertently slowing down the normal run path.
692 if (is_partial_) {
693 for (const auto& name_index : feeds) {
694 const auto iter = part.feed_key.find(string(name_index.first));
695 if (iter == part.feed_key.end()) {
696 // The provided feed must be for a different partition.
697 continue;
698 }
699 const string& key = iter->second;
700 TF_RETURN_IF_ERROR(AddSendFromClientRequest(req, c->req.get(),
701 name_index.second, key));
702 }
703 // TODO(suharshs): Make a map from feed to fetch_key to make this faster.
704 // For now, we just iterate through partitions to find the matching key.
705 for (const string& req_fetch : fetches) {
706 for (const auto& key_fetch : part.key_fetch) {
707 if (key_fetch.second == req_fetch) {
708 c->req->add_recv_key(key_fetch.first);
709 break;
710 }
711 }
712 }
713 } else {
714 for (const auto& feed_key : part.feed_key) {
715 const string& feed = feed_key.first;
716 const string& key = feed_key.second;
717 auto iter = feeds.find(feed);
718 if (iter == feeds.end()) {
719 return errors::Internal("No feed index found for feed: ", feed);
720 }
721 const int64_t feed_index = iter->second;
722 TF_RETURN_IF_ERROR(
723 AddSendFromClientRequest(req, c->req.get(), feed_index, key));
724 }
725 for (const auto& key_fetch : part.key_fetch) {
726 const string& key = key_fetch.first;
727 c->req->add_recv_key(key);
728 }
729 }
730 }
731
732 // Issues RunGraph calls.
733 for (int i = 0; i < num; ++i) {
734 const Part& part = partitions_[i];
735 RunManyGraphs::Call* call = calls.get(i);
736 TRACEPRINTF("Partition %d %s", i, part.name.c_str());
737 part.worker->RunGraphAsync(
738 &call->opts, call->req.get(), call->resp.get(),
739 std::bind(&RunManyGraphs::WhenDone, &calls, i, std::placeholders::_1));
740 }
741
742 // Waits for the RunGraph calls.
743 call_opts->SetCancelCallback([&calls]() {
744 LOG(INFO) << "Client requested cancellation for RunStep, cancelling "
745 "worker operations.";
746 calls.StartCancel();
747 });
748 auto token = cm->get_cancellation_token();
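  // RegisterCallback() returns false if the step has already been cancelled,
  // in which case we cancel the outstanding worker calls immediately.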
749 const bool success =
750 cm->RegisterCallback(token, [&calls]() { calls.StartCancel(); });
751 if (!success) {
752 calls.StartCancel();
753 }
754 calls.Wait();
755 call_opts->ClearCancelCallback();
756 if (success) {
757 cm->DeregisterCallback(token);
758 } else {
759 return errors::Cancelled("Step was cancelled");
760 }
761 TF_RETURN_IF_ERROR(calls.status());
762
763 // Collects fetches and metadata.
764 Status status;
765 for (int i = 0; i < num; ++i) {
766 const Part& part = partitions_[i];
767 MutableRunGraphResponseWrapper* run_graph_resp = calls.get(i)->resp.get();
768 for (size_t j = 0; j < run_graph_resp->num_recvs(); ++j) {
769 auto iter = part.key_fetch.find(run_graph_resp->recv_key(j));
770 if (iter == part.key_fetch.end()) {
771 status.Update(errors::Internal("Unexpected fetch key: ",
772 run_graph_resp->recv_key(j)));
773 break;
774 }
775 const string& fetch = iter->second;
776 status.Update(
777 resp->AddTensorFromRunGraphResponse(fetch, run_graph_resp, j));
778 if (!status.ok()) {
779 break;
780 }
781 }
782 if (pss->collect_timeline) {
783 pss->step_stats[i].Swap(run_graph_resp->mutable_step_stats());
784 }
785 if (pss->collect_costs) {
786 CostGraphDef* cost_graph = run_graph_resp->mutable_cost_graph();
787 for (int j = 0; j < cost_graph->node_size(); ++j) {
788 resp->mutable_metadata()->mutable_cost_graph()->add_node()->Swap(
789 cost_graph->mutable_node(j));
790 }
791 }
792 if (pss->collect_partition_graphs) {
793 protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
794 resp->mutable_metadata()->mutable_partition_graphs();
795 for (size_t i = 0; i < run_graph_resp->num_partition_graphs(); i++) {
796 partition_graph_defs->Add()->Swap(
797 run_graph_resp->mutable_partition_graph(i));
798 }
799 }
800 }
801 return status;
802 }
803
804 Status MasterSession::ReffedClientGraph::RunPartitions(
805 const MasterEnv* env, int64_t step_id, int64_t execution_count,
806 PerStepState* pss, CallOptions* call_opts, const RunStepRequestWrapper& req,
807 MutableRunStepResponseWrapper* resp, CancellationManager* cm,
808 const bool is_last_partial_run) {
809 VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
810 << execution_count;
811 // Maps the names of fed tensors to their index in `req`.
812 std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3);
813 for (size_t i = 0; i < req.num_feeds(); ++i) {
814 if (!feeds.insert({req.feed_name(i), i}).second) {
815 return errors::InvalidArgument("Duplicated feeds: ", req.feed_name(i));
816 }
817 }
818
819 std::vector<string> fetches;
820 fetches.reserve(req.num_fetches());
821 for (size_t i = 0; i < req.num_fetches(); ++i) {
822 fetches.push_back(req.fetch_name(i));
823 }
824
825 return RunPartitionsHelper(feeds, fetches, env, step_id, execution_count, pss,
826 call_opts, req, resp, cm, is_last_partial_run);
827 }
828
829 Status MasterSession::ReffedClientGraph::RunPartitions(
830 const MasterEnv* env, int64_t step_id, int64_t execution_count,
831 PerStepState* pss, CallOptions* call_opts, const RunCallableRequest& req,
832 RunCallableResponse* resp, CancellationManager* cm) {
833 VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
834 << execution_count;
835 // Maps the names of fed tensors to their index in `req`.
836 std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3);
837 for (size_t i = 0, end = callable_opts_.feed_size(); i < end; ++i) {
838 if (!feeds.insert({callable_opts_.feed(i), i}).second) {
839 // MakeCallable will fail if there are two feeds with the same name.
840 return errors::Internal("Duplicated feeds in callable: ",
841 callable_opts_.feed(i));
842 }
843 }
844
845 // Create a wrapped response object to collect the fetched values and
846 // rearrange them for the RunCallableResponse.
847 RunCallableResponseWrapper wrapped_resp;
848 wrapped_resp.resp = resp;
849
850 TF_RETURN_IF_ERROR(RunPartitionsHelper(
851 feeds, callable_opts_.fetch(), env, step_id, execution_count, pss,
852 call_opts, req, &wrapped_resp, cm, false /* is_last_partial_run */));
853
854 // Collects fetches.
855 for (const string& fetch : callable_opts_.fetch()) {
856 TensorProto* fetch_proto = resp->mutable_fetch()->Add();
857 auto iter = wrapped_resp.fetch_key_to_protos.find(fetch);
858 if (iter == wrapped_resp.fetch_key_to_protos.end()) {
859 return errors::Internal("Worker did not return a value for fetch: ",
860 fetch);
861 }
862 fetch_proto->Swap(&iter->second);
863 }
864 return Status::OK();
865 }
866
867 namespace {
868
869 class CleanupBroadcastHelper {
870 public:
871   CleanupBroadcastHelper(int64_t step_id, int num_calls, StatusCallback done)
872 : resps_(num_calls), num_pending_(num_calls), done_(std::move(done)) {
873 req_.set_step_id(step_id);
874 }
875
876 // Returns a non-owned pointer to a request buffer for all calls.
877   CleanupGraphRequest* request() { return &req_; }
878
879 // Returns a non-owned pointer to a response buffer for the ith call.
880   CleanupGraphResponse* response(int i) { return &resps_[i]; }
881
882 // Called when the ith response is received.
883   void call_done(int i, const Status& s) {
884 bool run_callback = false;
885 Status status_copy;
886 {
887 mutex_lock l(mu_);
888 status_.Update(s);
889 if (--num_pending_ == 0) {
890 run_callback = true;
891 status_copy = status_;
892 }
893 }
894 if (run_callback) {
895 done_(status_copy);
896 // This is the last call, so delete the helper object.
897 delete this;
898 }
899 }
900
901 private:
902 // A single request shared between all workers.
903 CleanupGraphRequest req_;
904 // One response buffer for each worker.
905 gtl::InlinedVector<CleanupGraphResponse, 4> resps_;
906
907 mutex mu_;
908 // Number of requests remaining to be collected.
909 int num_pending_ TF_GUARDED_BY(mu_);
910 // Aggregate status of the operation.
911 Status status_ TF_GUARDED_BY(mu_);
912 // Callback to be called when all operations complete.
913 StatusCallback done_;
914
915 TF_DISALLOW_COPY_AND_ASSIGN(CleanupBroadcastHelper);
916 };
917
918 } // namespace
919
920 void MasterSession::ReffedClientGraph::CleanupPartitionsAsync(
921 int64_t step_id, StatusCallback done) {
922 const int num = partitions_.size();
923 // Helper object will be deleted when the final call completes.
924 CleanupBroadcastHelper* helper =
925 new CleanupBroadcastHelper(step_id, num, std::move(done));
926 for (int i = 0; i < num; ++i) {
927 const Part& part = partitions_[i];
928 part.worker->CleanupGraphAsync(
929 helper->request(), helper->response(i),
930 [helper, i](const Status& s) { helper->call_done(i, s); });
931 }
932 }
933
934 void MasterSession::ReffedClientGraph::ProcessStats(int64_t step_id,
935 PerStepState* pss,
936 ProfileHandler* ph,
937 const RunOptions& options,
938 RunMetadata* resp) {
939 if (!pss->collect_costs && !pss->collect_timeline) return;
940
941 // Out-of-band logging data is collected now, during post-processing.
942 if (pss->collect_timeline) {
943 SetRPCLogging(false);
944 RetrieveLogs(step_id, &pss->rpc_stats);
945 }
946 for (size_t i = 0; i < partitions_.size(); ++i) {
947 const StepStats& ss = pss->step_stats[i];
948 if (ph) {
949 for (const auto& ds : ss.dev_stats()) {
950 ProcessDeviceStats(ph, ds, false /*is_rpc*/);
951 }
952 }
953 }
954 if (ph) {
955 for (const auto& ds : pss->rpc_stats.dev_stats()) {
956 ProcessDeviceStats(ph, ds, true /*is_rpc*/);
957 }
958 ph->StepDone(pss->start_micros, pss->end_micros,
959 Microseconds(0) /*cleanup_time*/, 0 /*total_runops*/,
960 Status::OK());
961 }
962 // Assemble all stats for this timeline into a merged StepStats.
963 if (pss->collect_timeline) {
964 StepStats step_stats_proto;
965 step_stats_proto.Swap(&pss->rpc_stats);
966 for (size_t i = 0; i < partitions_.size(); ++i) {
967 step_stats_proto.MergeFrom(pss->step_stats[i]);
968 pss->step_stats[i].Clear();
969 }
970 pss->step_stats.clear();
971 // Copy the stats back, but only for on-demand profiling to avoid slowing
972 // down calls that trigger the automatic profiling.
973 if (options.trace_level() == RunOptions::FULL_TRACE) {
974 resp->mutable_step_stats()->Swap(&step_stats_proto);
975 } else {
976       // In the FULL_TRACE case the stats are already returned through the
977       // Session API above, so publishing them here would be redundant.
978 stats_publisher_->PublishStatsProto(step_stats_proto);
979 }
980 }
981 }
982
983 void MasterSession::ReffedClientGraph::ProcessDeviceStats(
984 ProfileHandler* ph, const DeviceStepStats& ds, bool is_rpc) {
985 const string& dev_name = ds.device();
986 VLOG(1) << "Device " << dev_name << " reports stats for "
987 << ds.node_stats_size() << " nodes";
988 for (const auto& ns : ds.node_stats()) {
989 if (is_rpc) {
990 // We don't have access to a good Node pointer, so we rely on
991 // sufficient data being present in the NodeExecStats.
992 ph->RecordOneOp(dev_name, ns, true /*is_copy*/, "", ns.node_name(),
993 ns.timeline_label());
994 } else {
995 auto iter = name_to_node_details_.find(ns.node_name());
996 const bool found_node_in_graph = iter != name_to_node_details_.end();
997 if (!found_node_in_graph && ns.timeline_label().empty()) {
998 // The counter incrementing is not thread-safe. But we don't really
999 // care.
1000 // TODO(zhengxq): we should implement a LOG_FIRST_N and LOG_EVERY_N for
1001 // more general usage.
1002 static int log_counter = 0;
1003 if (log_counter < 10) {
1004 log_counter++;
1005 LOG(WARNING) << "Failed to find node " << ns.node_name()
1006 << " for dev " << dev_name;
1007 }
1008 continue;
1009 }
1010 const string& optype =
1011 found_node_in_graph ? iter->second.type_string : ns.node_name();
1012 string details;
1013 if (!ns.timeline_label().empty()) {
1014 details = ns.timeline_label();
1015 } else if (found_node_in_graph) {
1016 details = DetailText(iter->second, ns);
1017 } else {
1018 // Leave details string empty
1019 }
1020 ph->RecordOneOp(dev_name, ns, false /*is_copy*/, ns.node_name(), optype,
1021 details);
1022 }
1023 }
1024 }
1025
1026 // TODO(suharshs): Merge with CheckFetches in DirectSession.
1027 // TODO(suharsh,mrry): Build a map from fetch target to set of feeds it depends
1028 // on once at setup time to prevent us from computing the dependencies
1029 // every time.
1030 Status MasterSession::ReffedClientGraph::CheckFetches(
1031 const RunStepRequestWrapper& req, const RunState* run_state,
1032 GraphExecutionState* execution_state) {
1033 // Build the set of pending feeds that we haven't seen.
1034 std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
1035 for (const auto& input : run_state->pending_inputs) {
1036 // Skip if already fed.
1037 if (input.second) continue;
1038 TensorId id(ParseTensorName(input.first));
1039 const Node* n = execution_state->get_node_by_name(string(id.first));
1040 if (n == nullptr) {
1041 return errors::NotFound("Feed ", input.first, ": not found");
1042 }
1043 pending_feeds.insert(id);
1044 }
1045 for (size_t i = 0; i < req.num_feeds(); ++i) {
1046 const TensorId id(ParseTensorName(req.feed_name(i)));
1047 pending_feeds.erase(id);
1048 }
1049
1050 // Initialize the stack with the fetch nodes.
1051 std::vector<const Node*> stack;
1052 for (size_t i = 0; i < req.num_fetches(); ++i) {
1053 const string& fetch = req.fetch_name(i);
1054 const TensorId id(ParseTensorName(fetch));
1055 const Node* n = execution_state->get_node_by_name(string(id.first));
1056 if (n == nullptr) {
1057 return errors::NotFound("Fetch ", fetch, ": not found");
1058 }
1059 stack.push_back(n);
1060 }
1061
1062 // Any tensor needed for fetches can't be in pending_feeds.
1063 // We need to use the original full graph from execution state.
1064 const Graph* graph = execution_state->full_graph();
1065 std::vector<bool> visited(graph->num_node_ids(), false);
1066 while (!stack.empty()) {
1067 const Node* n = stack.back();
1068 stack.pop_back();
1069
1070 for (const Edge* in_edge : n->in_edges()) {
1071 const Node* in_node = in_edge->src();
1072 if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) {
1073 return errors::InvalidArgument("Fetch ", in_node->name(), ":",
1074 in_edge->src_output(),
1075 " can't be computed from the feeds"
1076 " that have been fed so far.");
1077 }
1078 if (!visited[in_node->id()]) {
1079 visited[in_node->id()] = true;
1080 stack.push_back(in_node);
1081 }
1082 }
1083 }
1084 return Status::OK();
1085 }
1086
1087 // Asynchronously deregisters subgraphs on the workers, without waiting for the
1088 // result.
1089 void MasterSession::ReffedClientGraph::DeregisterPartitions() {
1090 struct Call {
1091 DeregisterGraphRequest req;
1092 DeregisterGraphResponse resp;
1093 };
1094 for (Part& part : partitions_) {
1095 // The graph handle may be empty if we failed during partition registration.
1096 if (!part.graph_handle.empty()) {
1097 Call* c = new Call;
1098 c->req.set_session_handle(session_handle_);
1099 c->req.set_create_worker_session_called(!should_deregister_);
1100 c->req.set_graph_handle(part.graph_handle);
1101 // NOTE(mrry): We must capture `worker_cache_` since `this`
1102 // could be deleted before the callback is called.
1103 WorkerCacheInterface* worker_cache = worker_cache_;
1104 const string name = part.name;
1105 WorkerInterface* w = part.worker;
1106 CHECK_NOTNULL(w);
1107 auto cb = [worker_cache, c, name, w](const Status& s) {
1108 if (!s.ok()) {
1109 // This error is potentially benign, so we don't log at the
1110 // error level.
1111 LOG(INFO) << "DeregisterGraph error: " << s;
1112 }
1113 delete c;
1114 worker_cache->ReleaseWorker(name, w);
1115 };
1116 w->DeregisterGraphAsync(&c->req, &c->resp, cb);
1117 }
1118 }
1119 }
1120
1121 namespace {
1122 void CopyAndSortStrings(size_t size,
1123 const std::function<string(size_t)>& input_accessor,
1124 protobuf::RepeatedPtrField<string>* output) {
1125 std::vector<string> temp;
1126 temp.reserve(size);
1127 for (size_t i = 0; i < size; ++i) {
1128 output->Add(input_accessor(i));
1129 }
1130 std::sort(output->begin(), output->end());
1131 }
1132 } // namespace
1133
1134 void BuildBuildGraphOptions(const RunStepRequestWrapper& req,
1135 const ConfigProto& config,
1136 BuildGraphOptions* opts) {
1137 CallableOptions* callable_opts = &opts->callable_options;
1138 CopyAndSortStrings(
1139 req.num_feeds(), [&req](size_t i) { return req.feed_name(i); },
1140 callable_opts->mutable_feed());
1141 CopyAndSortStrings(
1142 req.num_fetches(), [&req](size_t i) { return req.fetch_name(i); },
1143 callable_opts->mutable_fetch());
1144 CopyAndSortStrings(
1145 req.num_targets(), [&req](size_t i) { return req.target_name(i); },
1146 callable_opts->mutable_target());
1147
1148 if (!req.options().debug_options().debug_tensor_watch_opts().empty()) {
1149 *callable_opts->mutable_run_options()->mutable_debug_options() =
1150 req.options().debug_options();
1151 }
1152
1153 opts->collective_graph_key =
1154 req.options().experimental().collective_graph_key();
1155 if (config.experimental().collective_deterministic_sequential_execution()) {
1156 opts->collective_order = GraphCollectiveOrder::kEdges;
1157 } else if (config.experimental().collective_nccl()) {
1158 opts->collective_order = GraphCollectiveOrder::kAttrs;
1159 }
1160 }
1161
1162 void BuildBuildGraphOptions(const PartialRunSetupRequest& req,
1163 BuildGraphOptions* opts) {
1164 CallableOptions* callable_opts = &opts->callable_options;
1165 CopyAndSortStrings(
1166 req.feed_size(), [&req](size_t i) { return req.feed(i); },
1167 callable_opts->mutable_feed());
1168 CopyAndSortStrings(
1169 req.fetch_size(), [&req](size_t i) { return req.fetch(i); },
1170 callable_opts->mutable_fetch());
1171 CopyAndSortStrings(
1172 req.target_size(), [&req](size_t i) { return req.target(i); },
1173 callable_opts->mutable_target());
1174
1175 // TODO(cais): Add TFDBG support to partial runs.
1176 }
1177
1178 uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) {
1179 uint64 h = 0x2b992ddfa23249d6ull;
1180 for (const string& name : opts.callable_options.feed()) {
1181 h = Hash64(name.c_str(), name.size(), h);
1182 }
1183 for (const string& name : opts.callable_options.target()) {
1184 h = Hash64(name.c_str(), name.size(), h);
1185 }
1186 for (const string& name : opts.callable_options.fetch()) {
1187 h = Hash64(name.c_str(), name.size(), h);
1188 }
1189
1190 const DebugOptions& debug_options =
1191 opts.callable_options.run_options().debug_options();
1192 if (!debug_options.debug_tensor_watch_opts().empty()) {
1193 const string watch_summary =
1194 SummarizeDebugTensorWatches(debug_options.debug_tensor_watch_opts());
1195 h = Hash64(watch_summary.c_str(), watch_summary.size(), h);
1196 }
1197
1198 return h;
1199 }
1200
1201 string BuildGraphOptionsString(const BuildGraphOptions& opts) {
1202 string buf;
1203 for (const string& name : opts.callable_options.feed()) {
1204 strings::StrAppend(&buf, " FdE: ", name);
1205 }
1206 strings::StrAppend(&buf, "\n");
1207 for (const string& name : opts.callable_options.target()) {
1208 strings::StrAppend(&buf, " TN: ", name);
1209 }
1210 strings::StrAppend(&buf, "\n");
1211 for (const string& name : opts.callable_options.fetch()) {
1212 strings::StrAppend(&buf, " FeE: ", name);
1213 }
1214 if (opts.collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey) {
1215 strings::StrAppend(&buf, "\nGK: ", opts.collective_graph_key);
1216 }
1217 strings::StrAppend(&buf, "\n");
1218 return buf;
1219 }
1220
1221 MasterSession::MasterSession(
1222 const SessionOptions& opt, const MasterEnv* env,
1223 std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
1224 std::unique_ptr<WorkerCacheInterface> worker_cache,
1225 std::unique_ptr<DeviceSet> device_set,
1226 std::vector<string> filtered_worker_list,
1227 StatsPublisherFactory stats_publisher_factory)
1228 : session_opts_(opt),
1229 env_(env),
1230 handle_(strings::FpToString(random::New64())),
1231 remote_devs_(std::move(remote_devs)),
1232 worker_cache_(std::move(worker_cache)),
1233 devices_(std::move(device_set)),
1234 filtered_worker_list_(std::move(filtered_worker_list)),
1235 stats_publisher_factory_(std::move(stats_publisher_factory)),
1236 graph_version_(0),
1237 run_graphs_(5),
1238 partial_run_graphs_(5) {
1239 UpdateLastAccessTime();
1240 CHECK(devices_) << "device_set was null!";
1241
1242 VLOG(1) << "Session " << handle_ << " #local " << env->local_devices.size()
1243 << " #remote " << remote_devs_->size();
1244 VLOG(1) << "Start master session " << handle_
1245 << " with config: " << session_opts_.config.ShortDebugString();
1246 }
1247
1248 MasterSession::~MasterSession() {
1249 for (const auto& iter : run_graphs_) iter.second->Unref();
1250 for (const auto& iter : partial_run_graphs_) iter.second->Unref();
1251 }
1252
1253 void MasterSession::UpdateLastAccessTime() {
1254 last_access_time_usec_.store(Env::Default()->NowMicros());
1255 }
1256
1257 Status MasterSession::Create(GraphDef&& graph_def,
1258 const WorkerCacheFactoryOptions& options) {
1259 if (session_opts_.config.use_per_session_threads() ||
1260 session_opts_.config.session_inter_op_thread_pool_size() > 0) {
1261 return errors::InvalidArgument(
1262 "Distributed session does not support session thread pool options.");
1263 }
1264 if (session_opts_.config.graph_options().place_pruned_graph()) {
1265 // TODO(b/29900832): Fix this or remove the option.
1266 LOG(WARNING) << "Distributed session does not support the "
1267 "place_pruned_graph option.";
1268 session_opts_.config.mutable_graph_options()->set_place_pruned_graph(false);
1269 }
1270
1271 GraphExecutionStateOptions execution_options;
1272 execution_options.device_set = devices_.get();
1273 execution_options.session_options = &session_opts_;
1274 {
1275 mutex_lock l(mu_);
1276 TF_RETURN_IF_ERROR(GraphExecutionState::MakeForBaseGraph(
1277 std::move(graph_def), execution_options, &execution_state_));
1278 }
1279 should_delete_worker_sessions_ = true;
1280 return CreateWorkerSessions(options);
1281 }
1282
1283 Status MasterSession::CreateWorkerSessions(
1284 const WorkerCacheFactoryOptions& options) {
1285 const std::vector<string> worker_names = filtered_worker_list_;
1286 WorkerCacheInterface* worker_cache = get_worker_cache();
1287
1288 struct WorkerGroup {
1289 // The worker name. (Not owned.)
1290 const string* name;
1291
1292 // The worker referenced by name. (Not owned.)
1293 WorkerInterface* worker = nullptr;
1294
1295 // Request and responses used for a given worker.
1296 CreateWorkerSessionRequest request;
1297 CreateWorkerSessionResponse response;
1298 Status status = Status::OK();
1299 };
1300 BlockingCounter done(worker_names.size());
1301 std::vector<WorkerGroup> workers(worker_names.size());
1302
1303 // Release the workers.
1304 auto cleanup = gtl::MakeCleanup([&workers, worker_cache] {
1305 for (auto&& worker_group : workers) {
1306 if (worker_group.worker != nullptr) {
1307 worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker);
1308 }
1309 }
1310 });
1311
1312 string task_name;
1313 string local_device_name;
1314 DeviceNameUtils::SplitDeviceName(devices_->client_device()->name(),
1315 &task_name, &local_device_name);
1316 const int64_t client_device_incarnation =
1317 devices_->client_device()->attributes().incarnation();
1318
1319 Status status = Status::OK();
1320 // Create all the workers & kick off the computations.
1321 for (size_t i = 0; i < worker_names.size(); ++i) {
1322 workers[i].name = &worker_names[i];
1323 workers[i].worker = worker_cache->GetOrCreateWorker(worker_names[i]);
1324 workers[i].request.set_session_handle(handle_);
1325 workers[i].request.set_master_task(task_name);
1326 workers[i].request.set_master_incarnation(client_device_incarnation);
1327 if (session_opts_.config.share_cluster_devices_in_session() ||
1328 session_opts_.config.experimental()
1329 .share_cluster_devices_in_session()) {
1330 for (const auto& remote_dev : devices_->devices()) {
1331 *workers[i].request.add_cluster_device_attributes() =
1332 remote_dev->attributes();
1333 }
1334
1335 if (!session_opts_.config.share_cluster_devices_in_session() &&
1336 session_opts_.config.experimental()
1337 .share_cluster_devices_in_session()) {
1338 LOG(WARNING)
1339 << "ConfigProto.Experimental.share_cluster_devices_in_session has "
1340 "been promoted to a non-experimental API. Please use "
1341 "ConfigProto.share_cluster_devices_in_session instead. The "
1342 "experimental option will be removed in the future.";
1343 }
1344 }
1345
1346 DeviceNameUtils::ParsedName name;
1347 if (!DeviceNameUtils::ParseFullName(worker_names[i], &name)) {
1348 status = errors::Internal("Could not parse name ", worker_names[i]);
1349 LOG(WARNING) << status;
1350 return status;
1351 }
1352 if (!name.has_job || !name.has_task) {
1353 status = errors::Internal("Incomplete worker name ", worker_names[i]);
1354 LOG(WARNING) << status;
1355 return status;
1356 }
1357
1358 if (options.cluster_def) {
1359 *workers[i].request.mutable_server_def()->mutable_cluster() =
1360 *options.cluster_def;
1361 workers[i].request.mutable_server_def()->set_protocol(*options.protocol);
1362 workers[i].request.mutable_server_def()->set_job_name(name.job);
1363 workers[i].request.mutable_server_def()->set_task_index(name.task);
1364 // Session state is always isolated when ClusterSpec propagation
1365 // is in use.
1366 workers[i].request.set_isolate_session_state(true);
1367 } else {
1368 // NOTE(mrry): Do not set any component of the ServerDef,
1369 // because the worker will use its local configuration.
1370 workers[i].request.set_isolate_session_state(
1371 session_opts_.config.isolate_session_state());
1372 }
1373 if (session_opts_.config.experimental()
1374 .share_session_state_in_clusterspec_propagation()) {
1375 // In a dynamic cluster, the ClusterSpec info is usually propagated by
1376 // master sessions. However, in data parallel training with multiple
1377       // masters ("between-graph replication"), we need to disable isolation
1378       // so that different worker sessions can update the same variables in
1379       // PS tasks.
1380 workers[i].request.set_isolate_session_state(false);
1381 }
1382 }
1383
1384 for (size_t i = 0; i < worker_names.size(); ++i) {
1385 auto cb = [i, &workers, &done](const Status& s) {
1386 workers[i].status = s;
1387 done.DecrementCount();
1388 };
1389 workers[i].worker->CreateWorkerSessionAsync(&workers[i].request,
1390 &workers[i].response, cb);
1391 }
1392
1393 done.Wait();
1394 for (size_t i = 0; i < workers.size(); ++i) {
1395 status.Update(workers[i].status);
1396 }
1397 return status;
1398 }
1399
1400 Status MasterSession::DeleteWorkerSessions() {
1401 WorkerCacheInterface* worker_cache = get_worker_cache();
1402 const std::vector<string>& worker_names = filtered_worker_list_;
1403
1404 struct WorkerGroup {
1405 // The worker name. (Not owned.)
1406 const string* name;
1407
1408 // The worker referenced by name. (Not owned.)
1409 WorkerInterface* worker = nullptr;
1410
1411 CallOptions call_opts;
1412
1413 // Request and responses used for a given worker.
1414 DeleteWorkerSessionRequest request;
1415 DeleteWorkerSessionResponse response;
1416 Status status = Status::OK();
1417 };
1418 BlockingCounter done(worker_names.size());
1419 std::vector<WorkerGroup> workers(worker_names.size());
1420
1421 // Release the workers.
1422 auto cleanup = gtl::MakeCleanup([&workers, worker_cache] {
1423 for (auto&& worker_group : workers) {
1424 if (worker_group.worker != nullptr) {
1425 worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker);
1426 }
1427 }
1428 });
1429
1430 Status status = Status::OK();
1431 // Create all the workers & kick off the computations.
1432 for (size_t i = 0; i < worker_names.size(); ++i) {
1433 workers[i].name = &worker_names[i];
1434 workers[i].worker = worker_cache->GetOrCreateWorker(worker_names[i]);
1435 workers[i].request.set_session_handle(handle_);
1436 // Since the worker may have gone away, set a timeout to avoid blocking the
1437 // session-close operation.
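// Note: the timeout is specified in milliseconds, so this waits at most
// roughly ten seconds per worker before giving up on the RPC.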
1438 workers[i].call_opts.SetTimeout(10000);
1439 }
1440
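// Issue DeleteWorkerSession RPCs in parallel; per-worker failures are
// collected below and merged into the returned status.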
1441 for (size_t i = 0; i < worker_names.size(); ++i) {
1442 auto cb = [i, &workers, &done](const Status& s) {
1443 workers[i].status = s;
1444 done.DecrementCount();
1445 };
1446 workers[i].worker->DeleteWorkerSessionAsync(
1447 &workers[i].call_opts, &workers[i].request, &workers[i].response, cb);
1448 }
1449
1450 done.Wait();
1451 for (size_t i = 0; i < workers.size(); ++i) {
1452 status.Update(workers[i].status);
1453 }
1454 return status;
1455 }
1456
1457 Status MasterSession::ListDevices(ListDevicesResponse* resp) const {
1458 if (worker_cache_) {
1459 // This is a ClusterSpec-propagated session, and thus env_->local_devices
1460 // are invalid.
1461
1462 // Mark the "client_device" as the sole local device.
1463 const Device* client_device = devices_->client_device();
1464 for (const Device* dev : devices_->devices()) {
1465 if (dev != client_device) {
1466 *(resp->add_remote_device()) = dev->attributes();
1467 }
1468 }
1469 *(resp->add_local_device()) = client_device->attributes();
1470 } else {
1471 for (Device* dev : env_->local_devices) {
1472 *(resp->add_local_device()) = dev->attributes();
1473 }
1474 for (auto&& dev : *remote_devs_) {
1475 *(resp->add_local_device()) = dev->attributes();
1476 }
1477 }
1478 return Status::OK();
1479 }
1480
1481 Status MasterSession::Extend(const ExtendSessionRequest* req,
1482 ExtendSessionResponse* resp) {
1483 UpdateLastAccessTime();
1484 std::unique_ptr<GraphExecutionState> extended_execution_state;
1485 {
1486 mutex_lock l(mu_);
1487 if (closed_) {
1488 return errors::FailedPrecondition("Session is closed.");
1489 }
1490
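// Extend() is versioned: the caller must present the graph version returned
// by the previous call, and on success graph_version_ is incremented and the
// new value is sent back in ExtendSessionResponse.new_graph_version.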
1491 if (graph_version_ != req->current_graph_version()) {
1492 return errors::Aborted("Current version is ", graph_version_,
1493 " but caller expected ",
1494 req->current_graph_version(), ".");
1495 }
1496
1497 CHECK(execution_state_);
1498 TF_RETURN_IF_ERROR(
1499 execution_state_->Extend(req->graph_def(), &extended_execution_state));
1500
1501 CHECK(extended_execution_state);
1502 // The old execution state will be released outside the lock.
1503 execution_state_.swap(extended_execution_state);
1504 ++graph_version_;
1505 resp->set_new_graph_version(graph_version_);
1506 }
1507 return Status::OK();
1508 }
1509
1510 WorkerCacheInterface* MasterSession::get_worker_cache() const {
1511 if (worker_cache_) {
1512 return worker_cache_.get();
1513 }
1514 return env_->worker_cache;
1515 }
1516
1517 Status MasterSession::StartStep(const BuildGraphOptions& opts, bool is_partial,
1518 ReffedClientGraph** out_rcg, int64* out_count) {
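// The cache key is a hash of the BuildGraphOptions, so repeated Run() calls
// with an identical signature (same feeds, fetches, and targets) reuse the
// ReffedClientGraph compiled for the first such call.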
1519 const uint64 hash = HashBuildGraphOptions(opts);
1520 {
1521 mutex_lock l(mu_);
1522 // TODO(suharshs): We cache partial run graphs and run graphs separately
1523 // because there is preprocessing that needs to be run only for partial
1524 // run calls.
1525 RCGMap* m = is_partial ? &partial_run_graphs_ : &run_graphs_;
1526 auto iter = m->find(hash);
1527 if (iter == m->end()) {
1528 // We have not seen this subgraph before. Build the subgraph and
1529 // cache it.
1530 VLOG(1) << "Unseen hash " << hash << " for "
1531 << BuildGraphOptionsString(opts) << " is_partial = " << is_partial
1532 << "\n";
1533 std::unique_ptr<ClientGraph> client_graph;
1534 TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
1535 WorkerCacheInterface* worker_cache = get_worker_cache();
1536 auto entry = new ReffedClientGraph(
1537 handle_, opts, std::move(client_graph), session_opts_,
1538 stats_publisher_factory_, is_partial, worker_cache,
1539 !should_delete_worker_sessions_);
1540 iter = m->insert({hash, entry}).first;
1541 VLOG(1) << "Preparing to execute new graph";
1542 }
1543 *out_rcg = iter->second;
1544 (*out_rcg)->Ref();
1545 *out_count = (*out_rcg)->get_and_increment_execution_count();
1546 }
1547 return Status::OK();
1548 }
1549
1550 void MasterSession::ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
1551 RCGMap* rcg_map) {
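// When 'to_unref' is non-null, the final Unref() of each graph is deferred to
// the caller so it can happen outside of mu_ (see Close()).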
1552 VLOG(1) << "Discarding all reffed graphs";
1553 for (auto p : *rcg_map) {
1554 ReffedClientGraph* rcg = p.second;
1555 if (to_unref) {
1556 to_unref->push_back(rcg);
1557 } else {
1558 rcg->Unref();
1559 }
1560 }
1561 rcg_map->clear();
1562 }
1563
1564 uint64 MasterSession::NewStepId(int64_t graph_key) {
1565 if (graph_key == BuildGraphOptions::kNoCollectiveGraphKey) {
1566 // StepId must leave the most-significant 7 bits empty for future use.
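// The mask below equals (1 << 57) - 1: ((1uLL << 56) - 1) covers bits 0..55
// and (1uLL << 56) adds bit 56, so the top 7 bits of the random value are
// always cleared.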
1567 return random::New64() & (((1uLL << 56) - 1) | (1uLL << 56));
1568 } else {
1569 uint64 step_id = env_->collective_executor_mgr->NextStepId(graph_key);
1570 int32_t retry_count = 0;
1571 while (static_cast<int64>(step_id) == CollectiveExecutor::kInvalidId) {
1572 Notification note;
1573 Status status;
1574 env_->collective_executor_mgr->RefreshStepIdSequenceAsync(
1575 graph_key, [&status, ¬e](const Status& s) {
1576 status = s;
1577 note.Notify();
1578 });
1579 note.WaitForNotification();
1580 if (!status.ok()) {
1581 LOG(ERROR) << "Bad status from "
1582 "collective_executor_mgr->RefreshStepIdSequence: "
1583 << status << ". Retrying.";
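// Back off linearly between retries: 1s, 2s, 3s, ..., capped at 60 seconds.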
1584 int64_t delay_micros = std::min(60000000LL, 1000000LL * ++retry_count);
1585 Env::Default()->SleepForMicroseconds(delay_micros);
1586 } else {
1587 step_id = env_->collective_executor_mgr->NextStepId(graph_key);
1588 }
1589 }
1590 return step_id;
1591 }
1592 }
1593
1594 Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req,
1595 PartialRunSetupResponse* resp) {
1596 std::vector<string> inputs, outputs, targets;
1597 for (const auto& feed : req->feed()) {
1598 inputs.push_back(feed);
1599 }
1600 for (const auto& fetch : req->fetch()) {
1601 outputs.push_back(fetch);
1602 }
1603 for (const auto& target : req->target()) {
1604 targets.push_back(target);
1605 }
1606
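// The handle is an opaque, per-session-unique string; the client must echo it
// back as RunStepRequest.partial_run_handle on every step of this partial run
// (see DoPartialRun()).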
1607 string handle = std::to_string(partial_run_handle_counter_.fetch_add(1));
1608
1609 ReffedClientGraph* rcg = nullptr;
1610
1611 // Prepare.
1612 BuildGraphOptions opts;
1613 BuildBuildGraphOptions(*req, &opts);
1614 int64_t count = 0;
1615 TF_RETURN_IF_ERROR(StartStep(opts, true, &rcg, &count));
1616
1617 rcg->Ref();
1618 RunState* run_state =
1619 new RunState(inputs, outputs, rcg,
1620 NewStepId(BuildGraphOptions::kNoCollectiveGraphKey), count);
1621 {
1622 mutex_lock l(mu_);
1623 partial_runs_.emplace(
1624 std::make_pair(handle, std::unique_ptr<RunState>(run_state)));
1625 }
1626
1627 TF_RETURN_IF_ERROR(BuildAndRegisterPartitions(rcg));
1628
1629 resp->set_partial_run_handle(handle);
1630 return Status::OK();
1631 }
1632
1633 Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req,
1634 MutableRunStepResponseWrapper* resp) {
1635 UpdateLastAccessTime();
1636 {
1637 mutex_lock l(mu_);
1638 if (closed_) {
1639 return errors::FailedPrecondition("Session is closed.");
1640 }
1641 ++num_running_;
1642 // Note: all code paths must eventually call MarkRunCompletion()
1643 // in order to appropriately decrement the num_running_ counter.
1644 }
1645 Status status;
1646 if (!req.partial_run_handle().empty()) {
1647 status = DoPartialRun(opts, req, resp);
1648 } else {
1649 status = DoRunWithLocalExecution(opts, req, resp);
1650 }
1651 return status;
1652 }
1653
1654 // Decrements num_running_ and broadcasts if num_running_ is zero.
1655 void MasterSession::MarkRunCompletion() {
1656 mutex_lock l(mu_);
1657 --num_running_;
1658 if (num_running_ == 0) {
1659 num_running_is_zero_.notify_all();
1660 }
1661 }
1662
1663 Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
1664 // Registers subgraphs if we haven't done so already.
1665 PartitionOptions popts;
1666 popts.node_to_loc = SplitByWorker;
1667 // The closures popts.{new_name,get_incarnation} are called synchronously in
1668 // RegisterPartitions() below, so they do not need a Ref()/Unref() pair to
1669 // keep "this" alive during the closure.
1670 popts.new_name = [this](const string& prefix) {
1671 mutex_lock l(mu_);
1672 return strings::StrCat(prefix, "_S", next_node_id_++);
1673 };
1674 popts.get_incarnation = [this](const string& name) -> int64 {
1675 Device* d = devices_->FindDeviceByName(name);
1676 if (d == nullptr) {
1677 return PartitionOptions::kIllegalIncarnation;
1678 } else {
1679 return d->attributes().incarnation();
1680 }
1681 };
1682 popts.control_flow_added = false;
1683 const bool enable_bfloat16_sendrecv =
1684 session_opts_.config.graph_options().enable_bfloat16_sendrecv();
1685 popts.should_cast = [enable_bfloat16_sendrecv](const Edge* e) {
1686 if (e->IsControlEdge()) {
1687 return DT_FLOAT;
1688 }
1689 DataType dtype = BaseType(e->src()->output_type(e->src_output()));
1690 if (enable_bfloat16_sendrecv && dtype == DT_FLOAT) {
1691 return DT_BFLOAT16;
1692 } else {
1693 return dtype;
1694 }
1695 };
1696 if (session_opts_.config.graph_options().enable_recv_scheduling()) {
1697 popts.scheduling_for_recvs = true;
1698 popts.need_to_record_start_times = true;
1699 }
1700
1701 TF_RETURN_IF_ERROR(rcg->RegisterPartitions(std::move(popts)));
1702
1703 return Status::OK();
1704 }
1705
1706 Status MasterSession::DoPartialRun(CallOptions* opts,
1707 const RunStepRequestWrapper& req,
1708 MutableRunStepResponseWrapper* resp) {
1709 auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
1710 const string& prun_handle = req.partial_run_handle();
1711 RunState* run_state = nullptr;
1712 {
1713 mutex_lock l(mu_);
1714 auto it = partial_runs_.find(prun_handle);
1715 if (it == partial_runs_.end()) {
1716 return errors::InvalidArgument(
1717 "Must run PartialRunSetup before performing partial runs");
1718 }
1719 run_state = it->second.get();
1720 }
1721 // CollectiveOps are not supported in partial runs.
1722 if (req.options().experimental().collective_graph_key() !=
1723 BuildGraphOptions::kNoCollectiveGraphKey) {
1724 return errors::InvalidArgument(
1725 "PartialRun does not support Collective ops. collective_graph_key "
1726 "must be kNoCollectiveGraphKey.");
1727 }
1728
1729 // If this is the first partial run, initialize the PerStepState.
1730 if (!run_state->step_started) {
1731 run_state->step_started = true;
1732 PerStepState pss;
1733
1734 const auto count = run_state->count;
1735 pss.collect_timeline =
1736 req.options().trace_level() == RunOptions::FULL_TRACE;
1737 pss.collect_rpcs = req.options().trace_level() == RunOptions::FULL_TRACE;
1738 pss.report_tensor_allocations_upon_oom =
1739 req.options().report_tensor_allocations_upon_oom();
1740
1741 // Build the cost model every 'build_cost_model_every' steps after skipping
1742 // an initial 'build_cost_model_after' steps.
1744 const int64_t build_cost_model_after =
1745 session_opts_.config.graph_options().build_cost_model_after();
1746 const int64_t build_cost_model_every =
1747 session_opts_.config.graph_options().build_cost_model();
1748 pss.collect_costs =
1749 build_cost_model_every > 0 &&
1750 ((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
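// For illustration: with build_cost_model_after = 2 and build_cost_model_every
// = 3, costs are collected whenever (count + 1 - 2) is an exact multiple of 3,
// i.e. on execution counts 1, 4, 7, ...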
1751 pss.collect_partition_graphs = req.options().output_partition_graphs();
1752
1753 std::unique_ptr<ProfileHandler> ph = run_state->rcg->GetProfileHandler(
1754 run_state->step_id, count, req.options());
1755 if (ph) {
1756 pss.collect_timeline = true;
1757 pss.collect_rpcs = ph->should_collect_rpcs();
1758 }
1759
1760 run_state->pss = std::move(pss);
1761 run_state->ph = std::move(ph);
1762 }
1763
1764 // Make sure that this is a new set of feeds that are still pending.
1765 for (size_t i = 0; i < req.num_feeds(); ++i) {
1766 const string& feed = req.feed_name(i);
1767 auto it = run_state->pending_inputs.find(feed);
1768 if (it == run_state->pending_inputs.end()) {
1769 return errors::InvalidArgument(
1770 "The feed ", feed, " was not specified in partial_run_setup.");
1771 } else if (it->second) {
1772 return errors::InvalidArgument("The feed ", feed,
1773 " has already been fed.");
1774 }
1775 }
1776 // Check that this is a new set of fetches that are still pending.
1777 for (size_t i = 0; i < req.num_fetches(); ++i) {
1778 const string& fetch = req.fetch_name(i);
1779 auto it = run_state->pending_outputs.find(fetch);
1780 if (it == run_state->pending_outputs.end()) {
1781 return errors::InvalidArgument(
1782 "The fetch ", fetch, " was not specified in partial_run_setup.");
1783 } else if (it->second) {
1784 return errors::InvalidArgument("The fetch ", fetch,
1785 " has already been fetched.");
1786 }
1787 }
1788
1789 // Ensure that the requested fetches can be computed from the provided feeds.
1790 {
1791 mutex_lock l(mu_);
1792 TF_RETURN_IF_ERROR(
1793 run_state->rcg->CheckFetches(req, run_state, execution_state_.get()));
1794 }
1795
1796 // Determine if this partial run satisfies all the pending inputs and outputs.
1797 for (size_t i = 0; i < req.num_feeds(); ++i) {
1798 auto it = run_state->pending_inputs.find(req.feed_name(i));
1799 it->second = true;
1800 }
1801 for (size_t i = 0; i < req.num_fetches(); ++i) {
1802 auto it = run_state->pending_outputs.find(req.fetch_name(i));
1803 it->second = true;
1804 }
1805 bool is_last_partial_run = run_state->PendingDone();
1806
1807 Status s = run_state->rcg->RunPartitions(
1808 env_, run_state->step_id, run_state->count, &run_state->pss, opts, req,
1809 resp, &cancellation_manager_, is_last_partial_run);
1810
1811 // Delete the run state if there is an error or all fetches are done.
1812 if (!s.ok() || is_last_partial_run) {
1813 ReffedClientGraph* rcg = run_state->rcg;
1814 run_state->pss.end_micros = Env::Default()->NowMicros();
1815 // Schedule post-processing and cleanup to be done asynchronously.
1816 Ref();
1817 rcg->Ref();
1818 rcg->ProcessStats(run_state->step_id, &run_state->pss, run_state->ph.get(),
1819 req.options(), resp->mutable_metadata());
1820 cleanup.release(); // MarkRunCompletion called in done closure.
1821 rcg->CleanupPartitionsAsync(
1822 run_state->step_id, [this, rcg, prun_handle](const Status& s) {
1823 if (!s.ok()) {
1824 LOG(ERROR) << "Cleanup partition error: " << s;
1825 }
1826 rcg->Unref();
1827 MarkRunCompletion();
1828 Unref();
1829 });
1830 mutex_lock l(mu_);
1831 partial_runs_.erase(prun_handle);
1832 }
1833 return s;
1834 }
1835
1836 Status MasterSession::CreateDebuggerState(
1837 const DebugOptions& debug_options, const RunStepRequestWrapper& req,
1838 int64_t rcg_execution_count,
1839 std::unique_ptr<DebuggerStateInterface>* debugger_state) {
1840 TF_RETURN_IF_ERROR(
1841 DebuggerStateRegistry::CreateState(debug_options, debugger_state));
1842
1843 std::vector<string> input_names;
1844 for (size_t i = 0; i < req.num_feeds(); ++i) {
1845 input_names.push_back(req.feed_name(i));
1846 }
1847 std::vector<string> output_names;
1848 for (size_t i = 0; i < req.num_fetches(); ++i) {
1849 output_names.push_back(req.fetch_name(i));
1850 }
1851 std::vector<string> target_names;
1852 for (size_t i = 0; i < req.num_targets(); ++i) {
1853 target_names.push_back(req.target_name(i));
1854 }
1855
1856 // TODO(cais): We currently use -1 as a dummy value for session run count.
1857 // While this counter value is straightforward to define and obtain for
1858 // DirectSessions, it is less so for non-direct Sessions. Devise a better
1859 // way to get its value when the need arises.
1860 TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
1861 debug_options.global_step(), rcg_execution_count, rcg_execution_count,
1862 input_names, output_names, target_names));
1863
1864 return Status::OK();
1865 }
1866
1867 void MasterSession::FillPerStepState(MasterSession::ReffedClientGraph* rcg,
1868 const RunOptions& run_options,
1869 uint64 step_id, int64_t count,
1870 PerStepState* out_pss,
1871 std::unique_ptr<ProfileHandler>* out_ph) {
1872 out_pss->collect_timeline =
1873 run_options.trace_level() == RunOptions::FULL_TRACE;
1874 out_pss->collect_rpcs = run_options.trace_level() == RunOptions::FULL_TRACE;
1875 out_pss->report_tensor_allocations_upon_oom =
1876 run_options.report_tensor_allocations_upon_oom();
1877 // Build the cost model every 'build_cost_model_every' steps after skipping an
1878 // initial 'build_cost_model_after' steps.
1879 const int64_t build_cost_model_after =
1880 session_opts_.config.graph_options().build_cost_model_after();
1881 const int64_t build_cost_model_every =
1882 session_opts_.config.graph_options().build_cost_model();
1883 out_pss->collect_costs =
1884 build_cost_model_every > 0 &&
1885 ((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
1886 out_pss->collect_partition_graphs = run_options.output_partition_graphs();
1887
1888 *out_ph = rcg->GetProfileHandler(step_id, count, run_options);
1889 if (*out_ph) {
1890 out_pss->collect_timeline = true;
1891 out_pss->collect_rpcs = (*out_ph)->should_collect_rpcs();
1892 }
1893 }
1894
1895 Status MasterSession::PostRunCleanup(MasterSession::ReffedClientGraph* rcg,
1896 uint64 step_id,
1897 const RunOptions& run_options,
1898 PerStepState* pss,
1899 const std::unique_ptr<ProfileHandler>& ph,
1900 const Status& run_status,
1901 RunMetadata* out_run_metadata) {
1902 Status s = run_status;
1903 if (s.ok()) {
1904 pss->end_micros = Env::Default()->NowMicros();
1905 if (rcg->collective_graph_key() !=
1906 BuildGraphOptions::kNoCollectiveGraphKey) {
1907 env_->collective_executor_mgr->RetireStepId(rcg->collective_graph_key(),
1908 step_id);
1909 }
1910 // Schedule post-processing and cleanup to be done asynchronously.
1911 rcg->ProcessStats(step_id, pss, ph.get(), run_options, out_run_metadata);
1912 } else if (errors::IsCancelled(s)) {
1913 mutex_lock l(mu_);
1914 if (closed_) {
1915 if (garbage_collected_) {
1916 s = errors::Cancelled(
1917 "Step was cancelled because the session was garbage collected due "
1918 "to inactivity.");
1919 } else {
1920 s = errors::Cancelled(
1921 "Step was cancelled by an explicit call to `Session::Close()`.");
1922 }
1923 }
1924 }
1925 Ref();
1926 rcg->Ref();
1927 rcg->CleanupPartitionsAsync(step_id, [this, rcg](const Status& s) {
1928 if (!s.ok()) {
1929 LOG(ERROR) << "Cleanup partition error: " << s;
1930 }
1931 rcg->Unref();
1932 MarkRunCompletion();
1933 Unref();
1934 });
1935 return s;
1936 }
1937
1938 Status MasterSession::DoRunWithLocalExecution(
1939 CallOptions* opts, const RunStepRequestWrapper& req,
1940 MutableRunStepResponseWrapper* resp) {
1941 VLOG(2) << "DoRunWithLocalExecution req: " << req.DebugString();
1942 PerStepState pss;
1943 pss.start_micros = Env::Default()->NowMicros();
1944 auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
1945
1946 // Prepare.
1947 BuildGraphOptions bgopts;
1948 BuildBuildGraphOptions(req, session_opts_.config, &bgopts);
1949 ReffedClientGraph* rcg = nullptr;
1950 int64_t count;
1951 TF_RETURN_IF_ERROR(StartStep(bgopts, false, &rcg, &count));
1952
1953 // Unref "rcg" when out of scope.
1954 core::ScopedUnref unref(rcg);
1955
1956 std::unique_ptr<DebuggerStateInterface> debugger_state;
1957 const DebugOptions& debug_options = req.options().debug_options();
1958
1959 if (!debug_options.debug_tensor_watch_opts().empty()) {
1960 TF_RETURN_IF_ERROR(
1961 CreateDebuggerState(debug_options, req, count, &debugger_state));
1962 }
1963 TF_RETURN_IF_ERROR(BuildAndRegisterPartitions(rcg));
1964
1965 // NewStepId() keeps the most-significant bits of the step_id clear; they
1966 // are reserved for future use.
1967 uint64 step_id = NewStepId(rcg->collective_graph_key());
1968 TRACEPRINTF("stepid %llu", step_id);
1969
1970 std::unique_ptr<ProfileHandler> ph;
1971 FillPerStepState(rcg, req.options(), step_id, count, &pss, &ph);
1972
1973 if (pss.collect_partition_graphs &&
1974 session_opts_.config.experimental().disable_output_partition_graphs()) {
1975 return errors::InvalidArgument(
1976 "RunOptions.output_partition_graphs() is not supported when "
1977 "disable_output_partition_graphs is true.");
1978 }
1979
1980 Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp,
1981 &cancellation_manager_, false);
1982
1983 cleanup.release(); // MarkRunCompletion called in PostRunCleanup().
1984 return PostRunCleanup(rcg, step_id, req.options(), &pss, ph, s,
1985 resp->mutable_metadata());
1986 }
1987
1988 Status MasterSession::MakeCallable(const MakeCallableRequest& req,
1989 MakeCallableResponse* resp) {
1990 UpdateLastAccessTime();
1991
1992 BuildGraphOptions opts;
1993 opts.callable_options = req.options();
1994 opts.use_function_convention = false;
1995
1996 ReffedClientGraph* callable;
1997
1998 {
1999 mutex_lock l(mu_);
2000 if (closed_) {
2001 return errors::FailedPrecondition("Session is closed.");
2002 }
2003 std::unique_ptr<ClientGraph> client_graph;
2004 TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
2005 callable = new ReffedClientGraph(handle_, opts, std::move(client_graph),
2006 session_opts_, stats_publisher_factory_,
2007 false /* is_partial */, get_worker_cache(),
2008 !should_delete_worker_sessions_);
2009 }
2010
2011 Status s = BuildAndRegisterPartitions(callable);
2012 if (!s.ok()) {
2013 callable->Unref();
2014 return s;
2015 }
2016
2017 uint64 handle;
2018 {
2019 mutex_lock l(mu_);
2020 handle = next_callable_handle_++;
2021 callables_[handle] = callable;
2022 }
2023
2024 resp->set_handle(handle);
2025 return Status::OK();
2026 }
2027
2028 Status MasterSession::DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg,
2029 const RunCallableRequest& req,
2030 RunCallableResponse* resp) {
2031 VLOG(2) << "DoRunCallable req: " << req.DebugString();
2032 PerStepState pss;
2033 pss.start_micros = Env::Default()->NowMicros();
2034 auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
2035
2036 // Prepare.
2037 int64_t count = rcg->get_and_increment_execution_count();
2038
2039 const uint64 step_id = NewStepId(rcg->collective_graph_key());
2040 TRACEPRINTF("stepid %llu", step_id);
2041
2042 const RunOptions& run_options = rcg->callable_options().run_options();
2043
2044 if (run_options.timeout_in_ms() != 0) {
2045 opts->SetTimeout(run_options.timeout_in_ms());
2046 }
2047
2048 std::unique_ptr<ProfileHandler> ph;
2049 FillPerStepState(rcg, run_options, step_id, count, &pss, &ph);
2050 Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp,
2051 &cancellation_manager_);
2052 cleanup.release(); // MarkRunCompletion called in PostRunCleanup().
2053 return PostRunCleanup(rcg, step_id, run_options, &pss, ph, s,
2054 resp->mutable_metadata());
2055 }
2056
2057 Status MasterSession::RunCallable(CallOptions* opts,
2058 const RunCallableRequest& req,
2059 RunCallableResponse* resp) {
2060 UpdateLastAccessTime();
2061 ReffedClientGraph* callable;
2062 {
2063 mutex_lock l(mu_);
2064 if (closed_) {
2065 return errors::FailedPrecondition("Session is closed.");
2066 }
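// Two failure modes are distinguished below: a handle that was never issued
// (>= next_callable_handle_) and one that was issued but has since been
// released via ReleaseCallable() (no longer present in callables_).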
2067 int64_t handle = req.handle();
2068 if (handle >= next_callable_handle_) {
2069 return errors::InvalidArgument("No such callable handle: ", handle);
2070 }
2071 auto iter = callables_.find(req.handle());
2072 if (iter == callables_.end()) {
2073 return errors::InvalidArgument(
2074 "Attempted to run callable after handle was released: ", handle);
2075 }
2076 callable = iter->second;
2077 callable->Ref();
2078 ++num_running_;
2079 }
2080 core::ScopedUnref unref_callable(callable);
2081 return DoRunCallable(opts, callable, req, resp);
2082 }
2083
2084 Status MasterSession::ReleaseCallable(const ReleaseCallableRequest& req,
2085 ReleaseCallableResponse* resp) {
2086 UpdateLastAccessTime();
2087 ReffedClientGraph* to_unref = nullptr;
2088 {
2089 mutex_lock l(mu_);
2090 auto iter = callables_.find(req.handle());
2091 if (iter != callables_.end()) {
2092 to_unref = iter->second;
2093 callables_.erase(iter);
2094 }
2095 }
2096 if (to_unref != nullptr) {
2097 to_unref->Unref();
2098 }
2099 return Status::OK();
2100 }
2101
2102 Status MasterSession::Close() {
2103 {
2104 mutex_lock l(mu_);
2105 closed_ = true; // All subsequent calls to Run() or Extend() will fail.
2106 }
2107 cancellation_manager_.StartCancel();
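// Cancel any in-flight steps, then wait under mu_ for num_running_ to drain
// to zero before discarding the cached run graphs and callables.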
2108 std::vector<ReffedClientGraph*> to_unref;
2109 {
2110 mutex_lock l(mu_);
2111 while (num_running_ != 0) {
2112 num_running_is_zero_.wait(l);
2113 }
2114 ClearRunsTable(&to_unref, &run_graphs_);
2115 ClearRunsTable(&to_unref, &partial_run_graphs_);
2116 ClearRunsTable(&to_unref, &callables_);
2117 }
2118 for (ReffedClientGraph* rcg : to_unref) rcg->Unref();
2119 if (should_delete_worker_sessions_) {
2120 Status s = DeleteWorkerSessions();
2121 if (!s.ok()) {
2122 LOG(WARNING) << s;
2123 }
2124 }
2125 return Status::OK();
2126 }
2127
2128 void MasterSession::GarbageCollect() {
2129 {
2130 mutex_lock l(mu_);
2131 closed_ = true;
2132 garbage_collected_ = true;
2133 }
2134 cancellation_manager_.StartCancel();
2135 Unref();
2136 }
2137
2138 MasterSession::RunState::RunState(const std::vector<string>& input_names,
2139 const std::vector<string>& output_names,
2140 ReffedClientGraph* rcg, const uint64 step_id,
2141 const int64_t count)
2142 : rcg(rcg), step_id(step_id), count(count) {
2143 // Initially all the feeds and fetches are pending.
2144 for (auto& name : input_names) {
2145 pending_inputs[name] = false;
2146 }
2147 for (auto& name : output_names) {
2148 pending_outputs[name] = false;
2149 }
2150 }
2151
2152 MasterSession::RunState::~RunState() {
2153 if (rcg) rcg->Unref();
2154 }
2155
2156 bool MasterSession::RunState::PendingDone() const {
2157 for (const auto& it : pending_inputs) {
2158 if (!it.second) return false;
2159 }
2160 for (const auto& it : pending_outputs) {
2161 if (!it.second) return false;
2162 }
2163 return true;
2164 }
2165
2166 } // end namespace tensorflow
2167