1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/master_session.h"
17
18 #include <memory>
19 #include <unordered_map>
20 #include <unordered_set>
21 #include <vector>
22
23 #include "tensorflow/core/common_runtime/process_util.h"
24 #include "tensorflow/core/common_runtime/profile_handler.h"
25 #include "tensorflow/core/common_runtime/stats_publisher_interface.h"
26 #include "tensorflow/core/debug/debug_graph_utils.h"
27 #include "tensorflow/core/distributed_runtime/scheduler.h"
28 #include "tensorflow/core/distributed_runtime/worker_cache.h"
29 #include "tensorflow/core/distributed_runtime/worker_interface.h"
30 #include "tensorflow/core/framework/allocation_description.pb.h"
31 #include "tensorflow/core/framework/collective.h"
32 #include "tensorflow/core/framework/cost_graph.pb.h"
33 #include "tensorflow/core/framework/node_def.pb.h"
34 #include "tensorflow/core/framework/node_def_util.h"
35 #include "tensorflow/core/framework/tensor.h"
36 #include "tensorflow/core/framework/tensor_description.pb.h"
37 #include "tensorflow/core/graph/graph_partition.h"
38 #include "tensorflow/core/graph/tensor_id.h"
39 #include "tensorflow/core/lib/core/blocking_counter.h"
40 #include "tensorflow/core/lib/core/notification.h"
41 #include "tensorflow/core/lib/core/refcount.h"
42 #include "tensorflow/core/lib/core/status.h"
43 #include "tensorflow/core/lib/gtl/cleanup.h"
44 #include "tensorflow/core/lib/gtl/inlined_vector.h"
45 #include "tensorflow/core/lib/gtl/map_util.h"
46 #include "tensorflow/core/lib/random/random.h"
47 #include "tensorflow/core/lib/strings/numbers.h"
48 #include "tensorflow/core/lib/strings/str_util.h"
49 #include "tensorflow/core/lib/strings/strcat.h"
50 #include "tensorflow/core/lib/strings/stringprintf.h"
51 #include "tensorflow/core/platform/env.h"
52 #include "tensorflow/core/platform/logging.h"
53 #include "tensorflow/core/platform/macros.h"
54 #include "tensorflow/core/platform/mutex.h"
55 #include "tensorflow/core/platform/tracing.h"
56 #include "tensorflow/core/public/session_options.h"
57
58 namespace tensorflow {
59
60 // MasterSession wraps ClientGraph in a reference counted object.
61 // This way, MasterSession can clear up the cache mapping Run requests to
62 // compiled graphs while the compiled graph is still being used.
63 //
64 // TODO(zhifengc): Cleanup this class. It's becoming messy.
65 class MasterSession::ReffedClientGraph : public core::RefCounted {
66 public:
ReffedClientGraph(const string & handle,const BuildGraphOptions & bopts,std::unique_ptr<ClientGraph> client_graph,const SessionOptions & session_opts,const StatsPublisherFactory & stats_publisher_factory,bool is_partial,WorkerCacheInterface * worker_cache,bool should_deregister)67 ReffedClientGraph(const string& handle, const BuildGraphOptions& bopts,
68 std::unique_ptr<ClientGraph> client_graph,
69 const SessionOptions& session_opts,
70 const StatsPublisherFactory& stats_publisher_factory,
71 bool is_partial, WorkerCacheInterface* worker_cache,
72 bool should_deregister)
73 : session_handle_(handle),
74 bg_opts_(bopts),
75 client_graph_before_register_(std::move(client_graph)),
76 session_opts_(session_opts),
77 is_partial_(is_partial),
78 callable_opts_(bopts.callable_options),
79 worker_cache_(worker_cache),
80 should_deregister_(should_deregister),
81 collective_graph_key_(
82 client_graph_before_register_->collective_graph_key) {
83 VLOG(1) << "Created ReffedClientGraph for node with "
84 << client_graph_before_register_->graph.num_node_ids();
85
86 stats_publisher_ = stats_publisher_factory(handle, bopts, session_opts);
87
88 // Initialize a name to node map for processing device stats.
89 for (Node* n : client_graph_before_register_->graph.nodes()) {
90 name_to_node_details_.emplace(
91 n->name(),
92 NodeDetails(n->type_string(),
93 strings::StrCat(
94 "(", str_util::Join(n->requested_inputs(), ", "))));
95 }
96 }
97
~ReffedClientGraph()98 ~ReffedClientGraph() override {
99 if (should_deregister_) {
100 DeregisterPartitions();
101 } else {
102 for (Part& part : partitions_) {
103 worker_cache_->ReleaseWorker(part.name, part.worker);
104 }
105 }
106 }
107
callable_options()108 const CallableOptions& callable_options() { return callable_opts_; }
109
build_graph_options()110 const BuildGraphOptions& build_graph_options() { return bg_opts_; }
111
collective_graph_key()112 int64 collective_graph_key() { return collective_graph_key_; }
113
GetProfileHandler(uint64 step,int64 execution_count,const RunOptions & ropts)114 std::unique_ptr<ProfileHandler> GetProfileHandler(uint64 step,
115 int64 execution_count,
116 const RunOptions& ropts) {
117 return stats_publisher_->GetProfileHandler(step, execution_count, ropts);
118 }
119
get_and_increment_execution_count()120 int64 get_and_increment_execution_count() {
121 return execution_count_.fetch_add(1);
122 }
123
124 // Turn RPC logging on or off, both at the WorkerCache used by this
125 // master process, and at each remote worker in use for the current
126 // partitions.
SetRPCLogging(bool active)127 void SetRPCLogging(bool active) {
128 worker_cache_->SetLogging(active);
129 // Logging is a best-effort activity, so we make async calls to turn
130 // it on/off and don't make use of the responses.
131 for (auto& p : partitions_) {
132 LoggingRequest* req = new LoggingRequest;
133 if (active) {
134 req->set_enable_rpc_logging(true);
135 } else {
136 req->set_disable_rpc_logging(true);
137 }
138 LoggingResponse* resp = new LoggingResponse;
139 Ref();
140 p.worker->LoggingAsync(req, resp, [this, req, resp](const Status& s) {
141 delete req;
142 delete resp;
143 // ReffedClientGraph owns p.worker so we need to hold a ref to
144 // ensure that the method doesn't attempt to access p.worker after
145 // ReffedClient graph has deleted it.
146 // TODO(suharshs): Simplify this ownership model.
147 Unref();
148 });
149 }
150 }
151
152 // Retrieve all RPC logs data accumulated for the current step, both
153 // from the local WorkerCache in use by this master process and from
154 // all the remote workers executing the remote partitions.
RetrieveLogs(int64 step_id,StepStats * ss)155 void RetrieveLogs(int64 step_id, StepStats* ss) {
156 // Get the local data first, because it sets *ss without merging.
157 worker_cache_->RetrieveLogs(step_id, ss);
158
159 // Then merge in data from all the remote workers.
160 LoggingRequest req;
161 req.add_fetch_step_id(step_id);
162 int waiting_for = partitions_.size();
163 if (waiting_for > 0) {
164 mutex scoped_mu;
165 BlockingCounter all_done(waiting_for);
166 for (auto& p : partitions_) {
167 LoggingResponse* resp = new LoggingResponse;
168 p.worker->LoggingAsync(
169 &req, resp,
170 [step_id, ss, resp, &scoped_mu, &all_done](const Status& s) {
171 {
172 mutex_lock l(scoped_mu);
173 if (s.ok()) {
174 for (auto& lss : resp->step()) {
175 if (step_id != lss.step_id()) {
176 LOG(ERROR) << "Wrong step_id in LoggingResponse";
177 continue;
178 }
179 ss->MergeFrom(lss.step_stats());
180 }
181 }
182 delete resp;
183 }
184 // Must not decrement all_done until out of critical section where
185 // *ss is updated.
186 all_done.DecrementCount();
187 });
188 }
189 all_done.Wait();
190 }
191 }
192
193 // Local execution methods.
194
195 // Partitions the graph into subgraphs and registers them on
196 // workers.
197 Status RegisterPartitions(PartitionOptions popts);
198
199 // Runs one step of all partitions.
200 Status RunPartitions(const MasterEnv* env, int64 step_id,
201 int64 execution_count, PerStepState* pss,
202 CallOptions* opts, const RunStepRequestWrapper& req,
203 MutableRunStepResponseWrapper* resp,
204 CancellationManager* cm, const bool is_last_partial_run);
205 Status RunPartitions(const MasterEnv* env, int64 step_id,
206 int64 execution_count, PerStepState* pss,
207 CallOptions* call_opts, const RunCallableRequest& req,
208 RunCallableResponse* resp, CancellationManager* cm);
209
210 // Calls workers to cleanup states for the step "step_id". Calls
211 // `done` when all cleanup RPCs have completed.
212 void CleanupPartitionsAsync(int64 step_id, StatusCallback done);
213
214 // Post-processing of any runtime statistics gathered during execution.
215 void ProcessStats(int64 step_id, PerStepState* pss, ProfileHandler* ph,
216 const RunOptions& options, RunMetadata* resp);
217 void ProcessDeviceStats(ProfileHandler* ph, const DeviceStepStats& ds,
218 bool is_rpc);
219 // Checks that the requested fetches can be computed from the provided feeds.
220 Status CheckFetches(const RunStepRequestWrapper& req,
221 const RunState* run_state,
222 GraphExecutionState* execution_state);
223
224 private:
225 const string session_handle_;
226 const BuildGraphOptions bg_opts_;
227
228 // NOTE(mrry): This pointer will be null after `RegisterPartitions()` returns.
229 std::unique_ptr<ClientGraph> client_graph_before_register_ GUARDED_BY(mu_);
230 const SessionOptions session_opts_;
231 const bool is_partial_;
232 const CallableOptions callable_opts_;
233 WorkerCacheInterface* const worker_cache_; // Not owned.
234
235 struct NodeDetails {
NodeDetailstensorflow::MasterSession::ReffedClientGraph::NodeDetails236 explicit NodeDetails(string type_string, string detail_text)
237 : type_string(std::move(type_string)),
238 detail_text(std::move(detail_text)) {}
239 const string type_string;
240 const string detail_text;
241 };
242 std::unordered_map<string, NodeDetails> name_to_node_details_;
243
244 const bool should_deregister_;
245 const int64 collective_graph_key_;
246 std::atomic<int64> execution_count_ = {0};
247
248 // Graph partitioned into per-location subgraphs.
249 struct Part {
250 // Worker name.
251 string name;
252
253 // Maps feed names to rendezvous keys. Empty most of the time.
254 std::unordered_map<string, string> feed_key;
255
256 // Maps rendezvous keys to fetch names. Empty most of the time.
257 std::unordered_map<string, string> key_fetch;
258
259 // The interface to the worker. Owned.
260 WorkerInterface* worker = nullptr;
261
262 // After registeration with the worker, graph_handle identifies
263 // this partition on the worker.
264 string graph_handle;
265
Parttensorflow::MasterSession::ReffedClientGraph::Part266 Part() : feed_key(3), key_fetch(3) {}
267 };
268
269 // partitions_ is immutable after RegisterPartitions() call
270 // finishes. RunPartitions() can access partitions_ safely without
271 // acquiring locks.
272 std::vector<Part> partitions_;
273
274 mutable mutex mu_;
275
276 // Partition initialization and registration only needs to happen
277 // once. `!client_graph_before_register_ && !init_done_.HasBeenNotified()`
278 // indicates the initialization is ongoing.
279 Notification init_done_;
280
281 // init_result_ remembers the initialization error if any.
282 Status init_result_ GUARDED_BY(mu_);
283
284 std::unique_ptr<StatsPublisherInterface> stats_publisher_;
285
DetailText(const NodeDetails & details,const NodeExecStats & stats)286 string DetailText(const NodeDetails& details, const NodeExecStats& stats) {
287 int64 tot = 0;
288 for (auto& no : stats.output()) {
289 tot += no.tensor_description().allocation_description().requested_bytes();
290 }
291 string bytes;
292 if (tot >= 0.1 * 1048576.0) {
293 bytes = strings::Printf("[%.1fMB] ", tot / 1048576.0);
294 }
295 return strings::StrCat(bytes, stats.node_name(), " = ", details.type_string,
296 details.detail_text);
297 }
298
299 // Send/Recv nodes that are the result of client-added
300 // feeds and fetches must be tracked so that the tensors
301 // can be added to the local rendezvous.
302 static void TrackFeedsAndFetches(Part* part, const GraphDef& graph_def,
303 const PartitionOptions& popts);
304
305 // The actual graph partitioning and registration implementation.
306 Status DoBuildPartitions(
307 PartitionOptions popts, ClientGraph* client_graph,
308 std::unordered_map<string, GraphDef>* out_partitions);
309 Status DoRegisterPartitions(
310 const PartitionOptions& popts,
311 std::unordered_map<string, GraphDef> graph_partitions);
312
313 // Prepares a number of calls to workers. One call per partition.
314 // This is a generic method that handles Run, PartialRun, and RunCallable.
315 template <class FetchListType, class ClientRequestType,
316 class ClientResponseType>
317 Status RunPartitionsHelper(
318 const std::unordered_map<StringPiece, size_t, StringPieceHasher>& feeds,
319 const FetchListType& fetches, const MasterEnv* env, int64 step_id,
320 int64 execution_count, PerStepState* pss, CallOptions* call_opts,
321 const ClientRequestType& req, ClientResponseType* resp,
322 CancellationManager* cm, bool is_last_partial_run);
323
324 // Deregisters the partitions on the workers. Called in the
325 // destructor and does not wait for the rpc completion.
326 void DeregisterPartitions();
327
328 TF_DISALLOW_COPY_AND_ASSIGN(ReffedClientGraph);
329 };
330
RegisterPartitions(PartitionOptions popts)331 Status MasterSession::ReffedClientGraph::RegisterPartitions(
332 PartitionOptions popts) {
333 { // Ensure register once.
334 mu_.lock();
335 if (client_graph_before_register_) {
336 // The `ClientGraph` is no longer needed after partitions are registered.
337 // Since it can account for a large amount of memory, we consume it here,
338 // and it will be freed after concluding with registration.
339
340 std::unique_ptr<ClientGraph> client_graph;
341 std::swap(client_graph_before_register_, client_graph);
342 mu_.unlock();
343 std::unordered_map<string, GraphDef> graph_defs;
344 popts.flib_def = client_graph->flib_def.get();
345 Status s = DoBuildPartitions(popts, client_graph.get(), &graph_defs);
346 if (s.ok()) {
347 // NOTE(mrry): The pointers in `graph_defs_for_publishing` do not remain
348 // valid after the call to DoRegisterPartitions begins, so
349 // `stats_publisher_` must make a copy if it wants to retain the
350 // GraphDef objects.
351 std::vector<const GraphDef*> graph_defs_for_publishing;
352 graph_defs_for_publishing.reserve(partitions_.size());
353 for (const auto& name_def : graph_defs) {
354 graph_defs_for_publishing.push_back(&name_def.second);
355 }
356 stats_publisher_->PublishGraphProto(graph_defs_for_publishing);
357 s = DoRegisterPartitions(popts, std::move(graph_defs));
358 }
359 mu_.lock();
360 init_result_ = s;
361 init_done_.Notify();
362 } else {
363 mu_.unlock();
364 init_done_.WaitForNotification();
365 mu_.lock();
366 }
367 const Status result = init_result_;
368 mu_.unlock();
369 return result;
370 }
371 }
372
SplitByWorker(const Node * node)373 static string SplitByWorker(const Node* node) {
374 string task;
375 string device;
376 CHECK(DeviceNameUtils::SplitDeviceName(node->assigned_device_name(), &task,
377 &device))
378 << "node: " << node->name() << " dev: " << node->assigned_device_name();
379 return task;
380 }
381
TrackFeedsAndFetches(Part * part,const GraphDef & graph_def,const PartitionOptions & popts)382 void MasterSession::ReffedClientGraph::TrackFeedsAndFetches(
383 Part* part, const GraphDef& graph_def, const PartitionOptions& popts) {
384 for (int i = 0; i < graph_def.node_size(); ++i) {
385 const NodeDef& ndef = graph_def.node(i);
386 const bool is_recv = ndef.op() == "_Recv";
387 const bool is_send = ndef.op() == "_Send";
388
389 if (is_recv || is_send) {
390 // Only send/recv nodes that were added as feeds and fetches
391 // (client-terminated) should be tracked. Other send/recv nodes
392 // are for transferring data between partitions / memory spaces.
393 bool client_terminated;
394 TF_CHECK_OK(GetNodeAttr(ndef, "client_terminated", &client_terminated));
395 if (client_terminated) {
396 string name;
397 TF_CHECK_OK(GetNodeAttr(ndef, "tensor_name", &name));
398 string send_device;
399 TF_CHECK_OK(GetNodeAttr(ndef, "send_device", &send_device));
400 string recv_device;
401 TF_CHECK_OK(GetNodeAttr(ndef, "recv_device", &recv_device));
402 uint64 send_device_incarnation;
403 TF_CHECK_OK(
404 GetNodeAttr(ndef, "send_device_incarnation",
405 reinterpret_cast<int64*>(&send_device_incarnation)));
406 const string& key =
407 Rendezvous::CreateKey(send_device, send_device_incarnation,
408 recv_device, name, FrameAndIter(0, 0));
409
410 if (is_recv) {
411 part->feed_key.insert({name, key});
412 } else {
413 part->key_fetch.insert({key, name});
414 }
415 }
416 }
417 }
418 }
419
DoBuildPartitions(PartitionOptions popts,ClientGraph * client_graph,std::unordered_map<string,GraphDef> * out_partitions)420 Status MasterSession::ReffedClientGraph::DoBuildPartitions(
421 PartitionOptions popts, ClientGraph* client_graph,
422 std::unordered_map<string, GraphDef>* out_partitions) {
423 if (popts.need_to_record_start_times) {
424 CostModel cost_model(true);
425 cost_model.InitFromGraph(client_graph->graph);
426 // TODO(yuanbyu): Use the real cost model.
427 // execution_state_->MergeFromGlobal(&cost_model);
428 SlackAnalysis sa(&client_graph->graph, &cost_model);
429 sa.ComputeAsap(&popts.start_times);
430 }
431
432 // Partition the graph.
433 return Partition(popts, &client_graph->graph, out_partitions);
434 }
435
DoRegisterPartitions(const PartitionOptions & popts,std::unordered_map<string,GraphDef> graph_partitions)436 Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
437 const PartitionOptions& popts,
438 std::unordered_map<string, GraphDef> graph_partitions) {
439 partitions_.reserve(graph_partitions.size());
440 Status s;
441 for (auto& name_def : graph_partitions) {
442 partitions_.emplace_back();
443 Part* part = &partitions_.back();
444 part->name = name_def.first;
445 TrackFeedsAndFetches(part, name_def.second, popts);
446 part->worker = worker_cache_->CreateWorker(part->name);
447 if (part->worker == nullptr) {
448 s = errors::NotFound("worker ", part->name);
449 break;
450 }
451 }
452 if (!s.ok()) {
453 for (Part& part : partitions_) {
454 worker_cache_->ReleaseWorker(part.name, part.worker);
455 part.worker = nullptr;
456 }
457 return s;
458 }
459 struct Call {
460 RegisterGraphRequest req;
461 RegisterGraphResponse resp;
462 Status status;
463 };
464 const int num = partitions_.size();
465 gtl::InlinedVector<Call, 4> calls(num);
466 BlockingCounter done(num);
467 for (int i = 0; i < num; ++i) {
468 const Part& part = partitions_[i];
469 Call* c = &calls[i];
470 c->req.set_session_handle(session_handle_);
471 c->req.set_create_worker_session_called(!should_deregister_);
472 c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]);
473 *c->req.mutable_graph_options() = session_opts_.config.graph_options();
474 *c->req.mutable_debug_options() =
475 callable_opts_.run_options().debug_options();
476 c->req.set_collective_graph_key(collective_graph_key_);
477 VLOG(2) << "Register " << c->req.graph_def().DebugString();
478 auto cb = [c, &done](const Status& s) {
479 c->status = s;
480 done.DecrementCount();
481 };
482 part.worker->RegisterGraphAsync(&c->req, &c->resp, cb);
483 }
484 done.Wait();
485 for (int i = 0; i < num; ++i) {
486 Call* c = &calls[i];
487 s.Update(c->status);
488 partitions_[i].graph_handle = c->resp.graph_handle();
489 }
490 return s;
491 }
492
493 // Helper class to manage "num" parallel RunGraph calls.
494 class RunManyGraphs {
495 public:
RunManyGraphs(int num)496 explicit RunManyGraphs(int num) : calls_(num), pending_(num) {}
497
~RunManyGraphs()498 ~RunManyGraphs() {}
499
500 // Returns the index-th call.
501 struct Call {
502 CallOptions opts;
503 std::unique_ptr<MutableRunGraphRequestWrapper> req;
504 std::unique_ptr<MutableRunGraphResponseWrapper> resp;
505 };
get(int index)506 Call* get(int index) { return &calls_[index]; }
507
508 // When the index-th call is done, updates the overall status.
WhenDone(int index,const Status & s)509 void WhenDone(int index, const Status& s) {
510 TRACEPRINTF("Partition %d %s", index, s.ToString().c_str());
511 auto resp = get(index)->resp.get();
512 if (resp->status_code() != error::Code::OK) {
513 // resp->status_code will only be non-OK if s.ok().
514 mutex_lock l(mu_);
515 ReportBadStatus(
516 Status(resp->status_code(), resp->status_error_message()));
517 } else if (!s.ok()) {
518 mutex_lock l(mu_);
519 ReportBadStatus(s);
520 }
521 pending_.DecrementCount();
522 }
523
StartCancel()524 void StartCancel() {
525 mutex_lock l(mu_);
526 ReportBadStatus(errors::Cancelled("RunManyGraphs"));
527 }
528
Wait()529 void Wait() { pending_.Wait(); }
530
status() const531 Status status() const {
532 mutex_lock l(mu_);
533 return status_group_.as_status();
534 }
535
536 private:
537 gtl::InlinedVector<Call, 4> calls_;
538
539 BlockingCounter pending_;
540 mutable mutex mu_;
541 StatusGroup status_group_ GUARDED_BY(mu_);
542
ReportBadStatus(const Status & s)543 void ReportBadStatus(const Status& s) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
544 // Start cancellation if we aren't already in an error state.
545 if (status_group_.ok()) {
546 for (Call& call : calls_) {
547 call.opts.StartCancel();
548 }
549 }
550
551 status_group_.Update(s);
552 }
553
554 TF_DISALLOW_COPY_AND_ASSIGN(RunManyGraphs);
555 };
556
557 namespace {
AddSendFromClientRequest(const RunStepRequestWrapper & client_req,MutableRunGraphRequestWrapper * worker_req,size_t index,const string & send_key)558 Status AddSendFromClientRequest(const RunStepRequestWrapper& client_req,
559 MutableRunGraphRequestWrapper* worker_req,
560 size_t index, const string& send_key) {
561 return worker_req->AddSendFromRunStepRequest(client_req, index, send_key);
562 }
563
AddSendFromClientRequest(const RunCallableRequest & client_req,MutableRunGraphRequestWrapper * worker_req,size_t index,const string & send_key)564 Status AddSendFromClientRequest(const RunCallableRequest& client_req,
565 MutableRunGraphRequestWrapper* worker_req,
566 size_t index, const string& send_key) {
567 return worker_req->AddSendFromRunCallableRequest(client_req, index, send_key);
568 }
569
570 // TODO(mrry): Add a full-fledged wrapper that avoids TensorProto copies for
571 // in-process messages.
572 struct RunCallableResponseWrapper {
573 RunCallableResponse* resp; // Not owned.
574 std::unordered_map<string, TensorProto> fetch_key_to_protos;
575
mutable_metadatatensorflow::__anon26c900ce0411::RunCallableResponseWrapper576 RunMetadata* mutable_metadata() { return resp->mutable_metadata(); }
577
AddTensorFromRunGraphResponsetensorflow::__anon26c900ce0411::RunCallableResponseWrapper578 Status AddTensorFromRunGraphResponse(
579 const string& tensor_name, MutableRunGraphResponseWrapper* worker_resp,
580 size_t index) {
581 // TODO(b/74355905): Add a specialized implementation that avoids
582 // copying the tensor into the RunCallableResponse when at least
583 // two of the {client, master, worker} are in the same process.
584 return worker_resp->RecvValue(index, &fetch_key_to_protos[tensor_name]);
585 }
586 };
587 } // namespace
588
589 template <class FetchListType, class ClientRequestType,
590 class ClientResponseType>
RunPartitionsHelper(const std::unordered_map<StringPiece,size_t,StringPieceHasher> & feeds,const FetchListType & fetches,const MasterEnv * env,int64 step_id,int64 execution_count,PerStepState * pss,CallOptions * call_opts,const ClientRequestType & req,ClientResponseType * resp,CancellationManager * cm,bool is_last_partial_run)591 Status MasterSession::ReffedClientGraph::RunPartitionsHelper(
592 const std::unordered_map<StringPiece, size_t, StringPieceHasher>& feeds,
593 const FetchListType& fetches, const MasterEnv* env, int64 step_id,
594 int64 execution_count, PerStepState* pss, CallOptions* call_opts,
595 const ClientRequestType& req, ClientResponseType* resp,
596 CancellationManager* cm, bool is_last_partial_run) {
597 // Collect execution cost stats on a smoothly decreasing frequency.
598 ExecutorOpts exec_opts;
599 if (pss->report_tensor_allocations_upon_oom) {
600 exec_opts.set_report_tensor_allocations_upon_oom(true);
601 }
602 if (pss->collect_costs) {
603 exec_opts.set_record_costs(true);
604 }
605 if (pss->collect_timeline) {
606 exec_opts.set_record_timeline(true);
607 }
608 if (pss->collect_rpcs) {
609 SetRPCLogging(true);
610 }
611 if (pss->collect_partition_graphs) {
612 exec_opts.set_record_partition_graphs(true);
613 }
614 if (pss->collect_costs || pss->collect_timeline) {
615 pss->step_stats.resize(partitions_.size());
616 }
617
618 const int num = partitions_.size();
619 RunManyGraphs calls(num);
620
621 for (int i = 0; i < num; ++i) {
622 const Part& part = partitions_[i];
623 RunManyGraphs::Call* c = calls.get(i);
624 c->req.reset(part.worker->CreateRunGraphRequest());
625 c->resp.reset(part.worker->CreateRunGraphResponse());
626 if (is_partial_) {
627 c->req->set_is_partial(is_partial_);
628 c->req->set_is_last_partial_run(is_last_partial_run);
629 }
630 c->req->set_session_handle(session_handle_);
631 c->req->set_create_worker_session_called(!should_deregister_);
632 c->req->set_graph_handle(part.graph_handle);
633 c->req->set_step_id(step_id);
634 *c->req->mutable_exec_opts() = exec_opts;
635 c->req->set_store_errors_in_response_body(true);
636 // If any feeds are provided, send the feed values together
637 // in the RunGraph request.
638 // In the partial case, we only want to include feeds provided in the req.
639 // In the non-partial case, all feeds in the request are in the part.
640 // We keep these as separate paths for now, to ensure we aren't
641 // inadvertently slowing down the normal run path.
642 if (is_partial_) {
643 for (const auto& name_index : feeds) {
644 const auto iter = part.feed_key.find(string(name_index.first));
645 if (iter == part.feed_key.end()) {
646 // The provided feed must be for a different partition.
647 continue;
648 }
649 const string& key = iter->second;
650 TF_RETURN_IF_ERROR(AddSendFromClientRequest(req, c->req.get(),
651 name_index.second, key));
652 }
653 // TODO(suharshs): Make a map from feed to fetch_key to make this faster.
654 // For now, we just iterate through partitions to find the matching key.
655 for (const string& req_fetch : fetches) {
656 for (const auto& key_fetch : part.key_fetch) {
657 if (key_fetch.second == req_fetch) {
658 c->req->add_recv_key(key_fetch.first);
659 break;
660 }
661 }
662 }
663 } else {
664 for (const auto& feed_key : part.feed_key) {
665 const string& feed = feed_key.first;
666 const string& key = feed_key.second;
667 auto iter = feeds.find(feed);
668 if (iter == feeds.end()) {
669 return errors::Internal("No feed index found for feed: ", feed);
670 }
671 const int64 feed_index = iter->second;
672 TF_RETURN_IF_ERROR(
673 AddSendFromClientRequest(req, c->req.get(), feed_index, key));
674 }
675 for (const auto& key_fetch : part.key_fetch) {
676 const string& key = key_fetch.first;
677 c->req->add_recv_key(key);
678 }
679 }
680 }
681
682 // Issues RunGraph calls.
683 for (int i = 0; i < num; ++i) {
684 const Part& part = partitions_[i];
685 RunManyGraphs::Call* call = calls.get(i);
686 TRACEPRINTF("Partition %d %s", i, part.name.c_str());
687 part.worker->RunGraphAsync(
688 &call->opts, call->req.get(), call->resp.get(),
689 std::bind(&RunManyGraphs::WhenDone, &calls, i, std::placeholders::_1));
690 }
691
692 // Waits for the RunGraph calls.
693 call_opts->SetCancelCallback([&calls]() { calls.StartCancel(); });
694 auto token = cm->get_cancellation_token();
695 const bool success =
696 cm->RegisterCallback(token, [&calls]() { calls.StartCancel(); });
697 if (!success) {
698 calls.StartCancel();
699 }
700 calls.Wait();
701 call_opts->ClearCancelCallback();
702 if (success) {
703 cm->DeregisterCallback(token);
704 } else {
705 return errors::Cancelled("Step was cancelled");
706 }
707 TF_RETURN_IF_ERROR(calls.status());
708
709 // Collects fetches and metadata.
710 Status status;
711 for (int i = 0; i < num; ++i) {
712 const Part& part = partitions_[i];
713 MutableRunGraphResponseWrapper* run_graph_resp = calls.get(i)->resp.get();
714 for (size_t j = 0; j < run_graph_resp->num_recvs(); ++j) {
715 auto iter = part.key_fetch.find(run_graph_resp->recv_key(j));
716 if (iter == part.key_fetch.end()) {
717 status.Update(errors::Internal("Unexpected fetch key: ",
718 run_graph_resp->recv_key(j)));
719 break;
720 }
721 const string& fetch = iter->second;
722 status.Update(
723 resp->AddTensorFromRunGraphResponse(fetch, run_graph_resp, j));
724 if (!status.ok()) {
725 break;
726 }
727 }
728 if (pss->collect_timeline) {
729 pss->step_stats[i].Swap(run_graph_resp->mutable_step_stats());
730 }
731 if (pss->collect_costs) {
732 CostGraphDef* cost_graph = run_graph_resp->mutable_cost_graph();
733 for (int j = 0; j < cost_graph->node_size(); ++j) {
734 resp->mutable_metadata()->mutable_cost_graph()->add_node()->Swap(
735 cost_graph->mutable_node(j));
736 }
737 }
738 if (pss->collect_partition_graphs) {
739 protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
740 resp->mutable_metadata()->mutable_partition_graphs();
741 for (size_t i = 0; i < run_graph_resp->num_partition_graphs(); i++) {
742 partition_graph_defs->Add()->Swap(
743 run_graph_resp->mutable_partition_graph(i));
744 }
745 }
746 }
747 return status;
748 }
749
RunPartitions(const MasterEnv * env,int64 step_id,int64 execution_count,PerStepState * pss,CallOptions * call_opts,const RunStepRequestWrapper & req,MutableRunStepResponseWrapper * resp,CancellationManager * cm,const bool is_last_partial_run)750 Status MasterSession::ReffedClientGraph::RunPartitions(
751 const MasterEnv* env, int64 step_id, int64 execution_count,
752 PerStepState* pss, CallOptions* call_opts, const RunStepRequestWrapper& req,
753 MutableRunStepResponseWrapper* resp, CancellationManager* cm,
754 const bool is_last_partial_run) {
755 VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
756 << execution_count;
757 // Maps the names of fed tensors to their index in `req`.
758 std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3);
759 for (size_t i = 0; i < req.num_feeds(); ++i) {
760 if (!feeds.insert({req.feed_name(i), i}).second) {
761 return errors::InvalidArgument("Duplicated feeds: ", req.feed_name(i));
762 }
763 }
764
765 std::vector<string> fetches;
766 fetches.reserve(req.num_fetches());
767 for (size_t i = 0; i < req.num_fetches(); ++i) {
768 fetches.push_back(req.fetch_name(i));
769 }
770
771 return RunPartitionsHelper(feeds, fetches, env, step_id, execution_count, pss,
772 call_opts, req, resp, cm, is_last_partial_run);
773 }
774
RunPartitions(const MasterEnv * env,int64 step_id,int64 execution_count,PerStepState * pss,CallOptions * call_opts,const RunCallableRequest & req,RunCallableResponse * resp,CancellationManager * cm)775 Status MasterSession::ReffedClientGraph::RunPartitions(
776 const MasterEnv* env, int64 step_id, int64 execution_count,
777 PerStepState* pss, CallOptions* call_opts, const RunCallableRequest& req,
778 RunCallableResponse* resp, CancellationManager* cm) {
779 VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
780 << execution_count;
781 // Maps the names of fed tensors to their index in `req`.
782 std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3);
783 for (size_t i = 0; i < callable_opts_.feed_size(); ++i) {
784 if (!feeds.insert({callable_opts_.feed(i), i}).second) {
785 // MakeCallable will fail if there are two feeds with the same name.
786 return errors::Internal("Duplicated feeds in callable: ",
787 callable_opts_.feed(i));
788 }
789 }
790
791 // Create a wrapped response object to collect the fetched values and
792 // rearrange them for the RunCallableResponse.
793 RunCallableResponseWrapper wrapped_resp;
794 wrapped_resp.resp = resp;
795
796 TF_RETURN_IF_ERROR(RunPartitionsHelper(
797 feeds, callable_opts_.fetch(), env, step_id, execution_count, pss,
798 call_opts, req, &wrapped_resp, cm, false /* is_last_partial_run */));
799
800 // Collects fetches.
801 // TODO(b/74355905): Add a specialized implementation that avoids
802 // copying the tensor into the RunCallableResponse when at least
803 // two of the {client, master, worker} are in the same process.
804 for (const string& fetch : callable_opts_.fetch()) {
805 TensorProto* fetch_proto = resp->mutable_fetch()->Add();
806 auto iter = wrapped_resp.fetch_key_to_protos.find(fetch);
807 if (iter == wrapped_resp.fetch_key_to_protos.end()) {
808 return errors::Internal("Worker did not return a value for fetch: ",
809 fetch);
810 }
811 fetch_proto->Swap(&iter->second);
812 }
813 return Status::OK();
814 }
815
816 namespace {
817
818 class CleanupBroadcastHelper {
819 public:
CleanupBroadcastHelper(int64 step_id,int num_calls,StatusCallback done)820 CleanupBroadcastHelper(int64 step_id, int num_calls, StatusCallback done)
821 : resps_(num_calls), num_pending_(num_calls), done_(std::move(done)) {
822 req_.set_step_id(step_id);
823 }
824
825 // Returns a non-owned pointer to a request buffer for all calls.
request()826 CleanupGraphRequest* request() { return &req_; }
827
828 // Returns a non-owned pointer to a response buffer for the ith call.
response(int i)829 CleanupGraphResponse* response(int i) { return &resps_[i]; }
830
831 // Called when the ith response is received.
call_done(int i,const Status & s)832 void call_done(int i, const Status& s) {
833 bool run_callback = false;
834 Status status_copy;
835 {
836 mutex_lock l(mu_);
837 status_.Update(s);
838 if (--num_pending_ == 0) {
839 run_callback = true;
840 status_copy = status_;
841 }
842 }
843 if (run_callback) {
844 done_(status_copy);
845 // This is the last call, so delete the helper object.
846 delete this;
847 }
848 }
849
850 private:
851 // A single request shared between all workers.
852 CleanupGraphRequest req_;
853 // One response buffer for each worker.
854 gtl::InlinedVector<CleanupGraphResponse, 4> resps_;
855
856 mutex mu_;
857 // Number of requests remaining to be collected.
858 int num_pending_ GUARDED_BY(mu_);
859 // Aggregate status of the operation.
860 Status status_ GUARDED_BY(mu_);
861 // Callback to be called when all operations complete.
862 StatusCallback done_;
863
864 TF_DISALLOW_COPY_AND_ASSIGN(CleanupBroadcastHelper);
865 };
866
867 } // namespace
868
CleanupPartitionsAsync(int64 step_id,StatusCallback done)869 void MasterSession::ReffedClientGraph::CleanupPartitionsAsync(
870 int64 step_id, StatusCallback done) {
871 const int num = partitions_.size();
872 // Helper object will be deleted when the final call completes.
873 CleanupBroadcastHelper* helper =
874 new CleanupBroadcastHelper(step_id, num, std::move(done));
875 for (int i = 0; i < num; ++i) {
876 const Part& part = partitions_[i];
877 part.worker->CleanupGraphAsync(
878 helper->request(), helper->response(i),
879 [helper, i](const Status& s) { helper->call_done(i, s); });
880 }
881 }
882
ProcessStats(int64 step_id,PerStepState * pss,ProfileHandler * ph,const RunOptions & options,RunMetadata * resp)883 void MasterSession::ReffedClientGraph::ProcessStats(int64 step_id,
884 PerStepState* pss,
885 ProfileHandler* ph,
886 const RunOptions& options,
887 RunMetadata* resp) {
888 if (!pss->collect_costs && !pss->collect_timeline) return;
889
890 // Out-of-band logging data is collected now, during post-processing.
891 if (pss->collect_timeline) {
892 SetRPCLogging(false);
893 RetrieveLogs(step_id, &pss->rpc_stats);
894 }
895 for (size_t i = 0; i < partitions_.size(); ++i) {
896 const StepStats& ss = pss->step_stats[i];
897 if (ph) {
898 for (const auto& ds : ss.dev_stats()) {
899 ProcessDeviceStats(ph, ds, false /*is_rpc*/);
900 }
901 }
902 }
903 if (ph) {
904 for (const auto& ds : pss->rpc_stats.dev_stats()) {
905 ProcessDeviceStats(ph, ds, true /*is_rpc*/);
906 }
907 ph->StepDone(pss->start_micros, pss->end_micros,
908 Microseconds(0) /*cleanup_time*/, 0 /*total_runops*/,
909 Status::OK());
910 }
911 // Assemble all stats for this timeline into a merged StepStats.
912 if (pss->collect_timeline) {
913 StepStats step_stats_proto;
914 step_stats_proto.Swap(&pss->rpc_stats);
915 for (size_t i = 0; i < partitions_.size(); ++i) {
916 step_stats_proto.MergeFrom(pss->step_stats[i]);
917 pss->step_stats[i].Clear();
918 }
919 pss->step_stats.clear();
920 // Copy the stats back, but only for on-demand profiling to avoid slowing
921 // down calls that trigger the automatic profiling.
922 if (options.trace_level() == RunOptions::FULL_TRACE) {
923 resp->mutable_step_stats()->Swap(&step_stats_proto);
924 } else {
925 // If FULL_TRACE, it can be fetched from Session API, no need for
926 // duplicated publishing.
927 stats_publisher_->PublishStatsProto(step_stats_proto);
928 }
929 }
930 }
931
ProcessDeviceStats(ProfileHandler * ph,const DeviceStepStats & ds,bool is_rpc)932 void MasterSession::ReffedClientGraph::ProcessDeviceStats(
933 ProfileHandler* ph, const DeviceStepStats& ds, bool is_rpc) {
934 const string& dev_name = ds.device();
935 VLOG(1) << "Device " << dev_name << " reports stats for "
936 << ds.node_stats_size() << " nodes";
937 for (const auto& ns : ds.node_stats()) {
938 if (is_rpc) {
939 // We don't have access to a good Node pointer, so we rely on
940 // sufficient data being present in the NodeExecStats.
941 ph->RecordOneOp(dev_name, ns, true /*is_copy*/, "", ns.node_name(),
942 ns.timeline_label());
943 } else {
944 auto iter = name_to_node_details_.find(ns.node_name());
945 const bool found_node_in_graph = iter != name_to_node_details_.end();
946 if (!found_node_in_graph && ns.timeline_label().empty()) {
947 // The counter incrementing is not thread-safe. But we don't really
948 // care.
949 // TODO(zhengxq): we should implement a LOG_FIRST_N and LOG_EVERY_N for
950 // more general usage.
951 static int log_counter = 0;
952 if (log_counter < 10) {
953 log_counter++;
954 LOG(WARNING) << "Failed to find node " << ns.node_name()
955 << " for dev " << dev_name;
956 }
957 continue;
958 }
959 const string& optype =
960 found_node_in_graph ? iter->second.type_string : ns.node_name();
961 string details;
962 if (!ns.timeline_label().empty()) {
963 details = ns.timeline_label();
964 } else if (found_node_in_graph) {
965 details = DetailText(iter->second, ns);
966 } else {
967 // Leave details string empty
968 }
969 ph->RecordOneOp(dev_name, ns, false /*is_copy*/, ns.node_name(), optype,
970 details);
971 }
972 }
973 }
974
975 // TODO(suharshs): Merge with CheckFetches in DirectSession.
// TODO(suharsh,mrry): Build a map from fetch target to set of feeds it depends
// on once at setup time to prevent us from computing the dependencies
// every time.
CheckFetches(const RunStepRequestWrapper & req,const RunState * run_state,GraphExecutionState * execution_state)979 Status MasterSession::ReffedClientGraph::CheckFetches(
980 const RunStepRequestWrapper& req, const RunState* run_state,
981 GraphExecutionState* execution_state) {
982 // Build the set of pending feeds that we haven't seen.
983 std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
984 for (const auto& input : run_state->pending_inputs) {
985 // Skip if already fed.
986 if (input.second) continue;
987 TensorId id(ParseTensorName(input.first));
988 const Node* n = execution_state->get_node_by_name(string(id.first));
989 if (n == nullptr) {
990 return errors::NotFound("Feed ", input.first, ": not found");
991 }
992 pending_feeds.insert(id);
993 }
994 for (size_t i = 0; i < req.num_feeds(); ++i) {
995 const TensorId id(ParseTensorName(req.feed_name(i)));
996 pending_feeds.erase(id);
997 }
998
999 // Initialize the stack with the fetch nodes.
1000 std::vector<const Node*> stack;
1001 for (size_t i = 0; i < req.num_fetches(); ++i) {
1002 const string& fetch = req.fetch_name(i);
1003 const TensorId id(ParseTensorName(fetch));
1004 const Node* n = execution_state->get_node_by_name(string(id.first));
1005 if (n == nullptr) {
1006 return errors::NotFound("Fetch ", fetch, ": not found");
1007 }
1008 stack.push_back(n);
1009 }
1010
1011 // Any tensor needed for fetches can't be in pending_feeds.
1012 // We need to use the original full graph from execution state.
1013 const Graph* graph = execution_state->full_graph();
1014 std::vector<bool> visited(graph->num_node_ids(), false);
1015 while (!stack.empty()) {
1016 const Node* n = stack.back();
1017 stack.pop_back();
1018
1019 for (const Edge* in_edge : n->in_edges()) {
1020 const Node* in_node = in_edge->src();
1021 if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) {
1022 return errors::InvalidArgument("Fetch ", in_node->name(), ":",
1023 in_edge->src_output(),
1024 " can't be computed from the feeds"
1025 " that have been fed so far.");
1026 }
1027 if (!visited[in_node->id()]) {
1028 visited[in_node->id()] = true;
1029 stack.push_back(in_node);
1030 }
1031 }
1032 }
1033 return Status::OK();
1034 }
1035
1036 // Asynchronously deregisters subgraphs on the workers, without waiting for the
1037 // result.
DeregisterPartitions()1038 void MasterSession::ReffedClientGraph::DeregisterPartitions() {
1039 struct Call {
1040 DeregisterGraphRequest req;
1041 DeregisterGraphResponse resp;
1042 };
1043 for (Part& part : partitions_) {
1044 // The graph handle may be empty if we failed during partition registration.
1045 if (!part.graph_handle.empty()) {
1046 Call* c = new Call;
1047 c->req.set_session_handle(session_handle_);
1048 c->req.set_create_worker_session_called(!should_deregister_);
1049 c->req.set_graph_handle(part.graph_handle);
1050 // NOTE(mrry): We must capture `worker_cache_` since `this`
1051 // could be deleted before the callback is called.
1052 WorkerCacheInterface* worker_cache = worker_cache_;
1053 const string name = part.name;
1054 WorkerInterface* w = part.worker;
1055 CHECK_NOTNULL(w);
1056 auto cb = [worker_cache, c, name, w](const Status& s) {
1057 if (!s.ok()) {
1058 // This error is potentially benign, so we don't log at the
1059 // error level.
1060 LOG(INFO) << "DeregisterGraph error: " << s;
1061 }
1062 delete c;
1063 worker_cache->ReleaseWorker(name, w);
1064 };
1065 w->DeregisterGraphAsync(&c->req, &c->resp, cb);
1066 }
1067 }
1068 }
1069
1070 namespace {
CopyAndSortStrings(size_t size,const std::function<string (size_t)> & input_accessor,protobuf::RepeatedPtrField<string> * output)1071 void CopyAndSortStrings(size_t size,
1072 const std::function<string(size_t)>& input_accessor,
1073 protobuf::RepeatedPtrField<string>* output) {
1074 std::vector<string> temp;
1075 temp.reserve(size);
1076 for (size_t i = 0; i < size; ++i) {
1077 output->Add(input_accessor(i));
1078 }
1079 std::sort(output->begin(), output->end());
1080 }
1081 } // namespace
1082
BuildBuildGraphOptions(const RunStepRequestWrapper & req,const ConfigProto & config,BuildGraphOptions * opts)1083 void BuildBuildGraphOptions(const RunStepRequestWrapper& req,
1084 const ConfigProto& config,
1085 BuildGraphOptions* opts) {
1086 CallableOptions* callable_opts = &opts->callable_options;
1087 CopyAndSortStrings(
1088 req.num_feeds(), [&req](size_t i) { return req.feed_name(i); },
1089 callable_opts->mutable_feed());
1090 CopyAndSortStrings(
1091 req.num_fetches(), [&req](size_t i) { return req.fetch_name(i); },
1092 callable_opts->mutable_fetch());
1093 CopyAndSortStrings(
1094 req.num_targets(), [&req](size_t i) { return req.target_name(i); },
1095 callable_opts->mutable_target());
1096
1097 if (!req.options().debug_options().debug_tensor_watch_opts().empty()) {
1098 *callable_opts->mutable_run_options()->mutable_debug_options() =
1099 req.options().debug_options();
1100 }
1101
1102 opts->collective_graph_key =
1103 req.options().experimental().collective_graph_key();
1104 if (config.experimental().collective_deterministic_sequential_execution()) {
1105 opts->collective_order = GraphCollectiveOrder::kEdges;
1106 } else if (config.experimental().collective_nccl()) {
1107 opts->collective_order = GraphCollectiveOrder::kAttrs;
1108 }
1109 }
1110
BuildBuildGraphOptions(const PartialRunSetupRequest & req,BuildGraphOptions * opts)1111 void BuildBuildGraphOptions(const PartialRunSetupRequest& req,
1112 BuildGraphOptions* opts) {
1113 CallableOptions* callable_opts = &opts->callable_options;
1114 CopyAndSortStrings(
1115 req.feed_size(), [&req](size_t i) { return req.feed(i); },
1116 callable_opts->mutable_feed());
1117 CopyAndSortStrings(
1118 req.fetch_size(), [&req](size_t i) { return req.fetch(i); },
1119 callable_opts->mutable_fetch());
1120 CopyAndSortStrings(
1121 req.target_size(), [&req](size_t i) { return req.target(i); },
1122 callable_opts->mutable_target());
1123
1124 // TODO(cais): Add TFDBG support to partial runs.
1125 }
1126
HashBuildGraphOptions(const BuildGraphOptions & opts)1127 uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) {
1128 uint64 h = 0x2b992ddfa23249d6ull;
1129 for (const string& name : opts.callable_options.feed()) {
1130 h = Hash64(name.c_str(), name.size(), h);
1131 }
1132 for (const string& name : opts.callable_options.target()) {
1133 h = Hash64(name.c_str(), name.size(), h);
1134 }
1135 for (const string& name : opts.callable_options.fetch()) {
1136 h = Hash64(name.c_str(), name.size(), h);
1137 }
1138
1139 const DebugOptions& debug_options =
1140 opts.callable_options.run_options().debug_options();
1141 if (!debug_options.debug_tensor_watch_opts().empty()) {
1142 const string watch_summary =
1143 SummarizeDebugTensorWatches(debug_options.debug_tensor_watch_opts());
1144 h = Hash64(watch_summary.c_str(), watch_summary.size(), h);
1145 }
1146
1147 return h;
1148 }
1149
BuildGraphOptionsString(const BuildGraphOptions & opts)1150 string BuildGraphOptionsString(const BuildGraphOptions& opts) {
1151 string buf;
1152 for (const string& name : opts.callable_options.feed()) {
1153 strings::StrAppend(&buf, " FdE: ", name);
1154 }
1155 strings::StrAppend(&buf, "\n");
1156 for (const string& name : opts.callable_options.target()) {
1157 strings::StrAppend(&buf, " TN: ", name);
1158 }
1159 strings::StrAppend(&buf, "\n");
1160 for (const string& name : opts.callable_options.fetch()) {
1161 strings::StrAppend(&buf, " FeE: ", name);
1162 }
1163 if (opts.collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey) {
1164 strings::StrAppend(&buf, "\nGK: ", opts.collective_graph_key);
1165 }
1166 strings::StrAppend(&buf, "\n");
1167 return buf;
1168 }
1169
MasterSession(const SessionOptions & opt,const MasterEnv * env,std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,std::unique_ptr<WorkerCacheInterface> worker_cache,std::unique_ptr<DeviceSet> device_set,std::vector<string> filtered_worker_list,StatsPublisherFactory stats_publisher_factory)1170 MasterSession::MasterSession(
1171 const SessionOptions& opt, const MasterEnv* env,
1172 std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
1173 std::unique_ptr<WorkerCacheInterface> worker_cache,
1174 std::unique_ptr<DeviceSet> device_set,
1175 std::vector<string> filtered_worker_list,
1176 StatsPublisherFactory stats_publisher_factory)
1177 : session_opts_(opt),
1178 env_(env),
1179 handle_(strings::FpToString(random::New64())),
1180 remote_devs_(std::move(remote_devs)),
1181 worker_cache_(std::move(worker_cache)),
1182 devices_(std::move(device_set)),
1183 filtered_worker_list_(std::move(filtered_worker_list)),
1184 stats_publisher_factory_(std::move(stats_publisher_factory)),
1185 graph_version_(0),
1186 run_graphs_(5),
1187 partial_run_graphs_(5) {
1188 UpdateLastAccessTime();
1189 CHECK(devices_) << "device_set was null!";
1190
1191 VLOG(1) << "Session " << handle_ << " #local " << env->local_devices.size()
1192 << " #remote " << remote_devs_->size();
1193
1194 LOG(INFO) << "Start master session " << handle_
1195 << " with config: " << session_opts_.config.ShortDebugString();
1196 }
1197
~MasterSession()1198 MasterSession::~MasterSession() {
1199 for (const auto& iter : run_graphs_) iter.second->Unref();
1200 for (const auto& iter : partial_run_graphs_) iter.second->Unref();
1201 }
1202
UpdateLastAccessTime()1203 void MasterSession::UpdateLastAccessTime() {
1204 last_access_time_usec_.store(Env::Default()->NowMicros());
1205 }
1206
Create(GraphDef * graph_def,const WorkerCacheFactoryOptions & options)1207 Status MasterSession::Create(GraphDef* graph_def,
1208 const WorkerCacheFactoryOptions& options) {
1209 if (session_opts_.config.use_per_session_threads() ||
1210 session_opts_.config.session_inter_op_thread_pool_size() > 0) {
1211 return errors::InvalidArgument(
1212 "Distributed session does not support session thread pool options.");
1213 }
1214 if (session_opts_.config.graph_options().place_pruned_graph()) {
1215 // TODO(b/29900832): Fix this or remove the option.
1216 LOG(WARNING) << "Distributed session does not support the "
1217 "place_pruned_graph option.";
1218 session_opts_.config.mutable_graph_options()->set_place_pruned_graph(false);
1219 }
1220
1221 GraphExecutionStateOptions execution_options;
1222 execution_options.device_set = devices_.get();
1223 execution_options.session_options = &session_opts_;
1224 {
1225 mutex_lock l(mu_);
1226 TF_RETURN_IF_ERROR(GraphExecutionState::MakeForBaseGraph(
1227 graph_def, execution_options, &execution_state_));
1228 }
1229 should_delete_worker_sessions_ = true;
1230 return CreateWorkerSessions(options);
1231 }
1232
CreateWorkerSessions(const WorkerCacheFactoryOptions & options)1233 Status MasterSession::CreateWorkerSessions(
1234 const WorkerCacheFactoryOptions& options) {
1235 const std::vector<string> worker_names = filtered_worker_list_;
1236 WorkerCacheInterface* worker_cache = get_worker_cache();
1237
1238 struct WorkerGroup {
1239 // The worker name. (Not owned.)
1240 const string* name;
1241
1242 // The worker referenced by name. (Not owned.)
1243 WorkerInterface* worker = nullptr;
1244
1245 // Request and responses used for a given worker.
1246 CreateWorkerSessionRequest request;
1247 CreateWorkerSessionResponse response;
1248 Status status = Status::OK();
1249 };
1250 BlockingCounter done(worker_names.size());
1251 std::vector<WorkerGroup> workers(worker_names.size());
1252
1253 // Release the workers.
1254 auto cleanup = gtl::MakeCleanup([&workers, worker_cache] {
1255 for (auto&& worker_group : workers) {
1256 if (worker_group.worker != nullptr) {
1257 worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker);
1258 }
1259 }
1260 });
1261
1262 Status status = Status::OK();
1263 // Create all the workers & kick off the computations.
1264 for (size_t i = 0; i < worker_names.size(); ++i) {
1265 workers[i].name = &worker_names[i];
1266 workers[i].worker = worker_cache->CreateWorker(worker_names[i]);
1267 workers[i].request.set_session_handle(handle_);
1268
1269 DeviceNameUtils::ParsedName name;
1270 if (!DeviceNameUtils::ParseFullName(worker_names[i], &name)) {
1271 status = errors::Internal("Could not parse name ", worker_names[i]);
1272 LOG(WARNING) << status;
1273 return status;
1274 }
1275 if (!name.has_job || !name.has_task) {
1276 status = errors::Internal("Incomplete worker name ", worker_names[i]);
1277 LOG(WARNING) << status;
1278 return status;
1279 }
1280
1281 if (options.cluster_def) {
1282 *workers[i].request.mutable_server_def()->mutable_cluster() =
1283 *options.cluster_def;
1284 workers[i].request.mutable_server_def()->set_protocol(*options.protocol);
1285 workers[i].request.mutable_server_def()->set_job_name(name.job);
1286 workers[i].request.mutable_server_def()->set_task_index(name.task);
1287 // Session state is always isolated when ClusterSpec propagation
1288 // is in use.
1289 workers[i].request.set_isolate_session_state(true);
1290 } else {
1291 // NOTE(mrry): Do not set any component of the ServerDef,
1292 // because the worker will use its local configuration.
1293 workers[i].request.set_isolate_session_state(
1294 session_opts_.config.isolate_session_state());
1295 }
1296 }
1297
1298 for (size_t i = 0; i < worker_names.size(); ++i) {
1299 auto cb = [i, &workers, &done](const Status& s) {
1300 workers[i].status = s;
1301 done.DecrementCount();
1302 };
1303 workers[i].worker->CreateWorkerSessionAsync(&workers[i].request,
1304 &workers[i].response, cb);
1305 }
1306
1307 done.Wait();
1308 for (size_t i = 0; i < workers.size(); ++i) {
1309 status.Update(workers[i].status);
1310 }
1311 return status;
1312 }
1313
DeleteWorkerSessions()1314 Status MasterSession::DeleteWorkerSessions() {
1315 WorkerCacheInterface* worker_cache = get_worker_cache();
1316 const std::vector<string>& worker_names = filtered_worker_list_;
1317
1318 struct WorkerGroup {
1319 // The worker name. (Not owned.)
1320 const string* name;
1321
1322 // The worker referenced by name. (Not owned.)
1323 WorkerInterface* worker = nullptr;
1324
1325 CallOptions call_opts;
1326
1327 // Request and responses used for a given worker.
1328 DeleteWorkerSessionRequest request;
1329 DeleteWorkerSessionResponse response;
1330 Status status = Status::OK();
1331 };
1332 BlockingCounter done(worker_names.size());
1333 std::vector<WorkerGroup> workers(worker_names.size());
1334
1335 // Release the workers.
1336 auto cleanup = gtl::MakeCleanup([&workers, worker_cache] {
1337 for (auto&& worker_group : workers) {
1338 if (worker_group.worker != nullptr) {
1339 worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker);
1340 }
1341 }
1342 });
1343
1344 Status status = Status::OK();
1345 // Create all the workers & kick off the computations.
1346 for (size_t i = 0; i < worker_names.size(); ++i) {
1347 workers[i].name = &worker_names[i];
1348 workers[i].worker = worker_cache->CreateWorker(worker_names[i]);
1349 workers[i].request.set_session_handle(handle_);
1350 // Since the worker may have gone away, set a timeout to avoid blocking the
1351 // session-close operation.
1352 workers[i].call_opts.SetTimeout(10000);
1353 }
1354
1355 for (size_t i = 0; i < worker_names.size(); ++i) {
1356 auto cb = [i, &workers, &done](const Status& s) {
1357 workers[i].status = s;
1358 done.DecrementCount();
1359 };
1360 workers[i].worker->DeleteWorkerSessionAsync(
1361 &workers[i].call_opts, &workers[i].request, &workers[i].response, cb);
1362 }
1363
1364 done.Wait();
1365 for (size_t i = 0; i < workers.size(); ++i) {
1366 status.Update(workers[i].status);
1367 }
1368 return status;
1369 }
1370
ListDevices(ListDevicesResponse * resp) const1371 Status MasterSession::ListDevices(ListDevicesResponse* resp) const {
1372 if (worker_cache_) {
1373 // This is a ClusterSpec-propagated session, and thus env_->local_devices
1374 // are invalid.
1375
1376 // Mark the "client_device" as the sole local device.
1377 const Device* client_device = devices_->client_device();
1378 for (const Device* dev : devices_->devices()) {
1379 if (dev != client_device) {
1380 *(resp->add_remote_device()) = dev->attributes();
1381 }
1382 }
1383 *(resp->add_local_device()) = client_device->attributes();
1384 } else {
1385 for (Device* dev : env_->local_devices) {
1386 *(resp->add_local_device()) = dev->attributes();
1387 }
1388 for (auto&& dev : *remote_devs_) {
1389 *(resp->add_local_device()) = dev->attributes();
1390 }
1391 }
1392 return Status::OK();
1393 }
1394
Extend(const ExtendSessionRequest * req,ExtendSessionResponse * resp)1395 Status MasterSession::Extend(const ExtendSessionRequest* req,
1396 ExtendSessionResponse* resp) {
1397 UpdateLastAccessTime();
1398 std::unique_ptr<GraphExecutionState> extended_execution_state;
1399 {
1400 mutex_lock l(mu_);
1401 if (closed_) {
1402 return errors::FailedPrecondition("Session is closed.");
1403 }
1404
1405 if (graph_version_ != req->current_graph_version()) {
1406 return errors::Aborted("Current version is ", graph_version_,
1407 " but caller expected ",
1408 req->current_graph_version(), ".");
1409 }
1410
1411 CHECK(execution_state_);
1412 TF_RETURN_IF_ERROR(
1413 execution_state_->Extend(req->graph_def(), &extended_execution_state));
1414
1415 CHECK(extended_execution_state);
1416 // The old execution state will be released outside the lock.
1417 execution_state_.swap(extended_execution_state);
1418 ++graph_version_;
1419 resp->set_new_graph_version(graph_version_);
1420 }
1421 return Status::OK();
1422 }
1423
get_worker_cache() const1424 WorkerCacheInterface* MasterSession::get_worker_cache() const {
1425 if (worker_cache_) {
1426 return worker_cache_.get();
1427 }
1428 return env_->worker_cache;
1429 }
1430
StartStep(const BuildGraphOptions & opts,bool is_partial,ReffedClientGraph ** out_rcg,int64 * out_count)1431 Status MasterSession::StartStep(const BuildGraphOptions& opts, bool is_partial,
1432 ReffedClientGraph** out_rcg, int64* out_count) {
1433 const uint64 hash = HashBuildGraphOptions(opts);
1434 {
1435 mutex_lock l(mu_);
1436 // TODO(suharshs): We cache partial run graphs and run graphs separately
1437 // because there is preprocessing that needs to only be run for partial
1438 // run calls.
1439 RCGMap* m = is_partial ? &partial_run_graphs_ : &run_graphs_;
1440 auto iter = m->find(hash);
1441 if (iter == m->end()) {
1442 // We have not seen this subgraph before. Build the subgraph and
1443 // cache it.
1444 VLOG(1) << "Unseen hash " << hash << " for "
1445 << BuildGraphOptionsString(opts) << " is_partial = " << is_partial
1446 << "\n";
1447 std::unique_ptr<ClientGraph> client_graph;
1448 TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
1449 WorkerCacheInterface* worker_cache = get_worker_cache();
1450 auto entry = new ReffedClientGraph(
1451 handle_, opts, std::move(client_graph), session_opts_,
1452 stats_publisher_factory_, is_partial, worker_cache,
1453 !should_delete_worker_sessions_);
1454 iter = m->insert({hash, entry}).first;
1455 VLOG(1) << "Preparing to execute new graph";
1456 }
1457 *out_rcg = iter->second;
1458 (*out_rcg)->Ref();
1459 *out_count = (*out_rcg)->get_and_increment_execution_count();
1460 }
1461 return Status::OK();
1462 }
1463
ClearRunsTable(std::vector<ReffedClientGraph * > * to_unref,RCGMap * rcg_map)1464 void MasterSession::ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
1465 RCGMap* rcg_map) {
1466 VLOG(1) << "Discarding all reffed graphs";
1467 for (auto p : *rcg_map) {
1468 ReffedClientGraph* rcg = p.second;
1469 if (to_unref) {
1470 to_unref->push_back(rcg);
1471 } else {
1472 rcg->Unref();
1473 }
1474 }
1475 rcg_map->clear();
1476 }
1477
NewStepId(int64 graph_key)1478 uint64 MasterSession::NewStepId(int64 graph_key) {
1479 if (graph_key == BuildGraphOptions::kNoCollectiveGraphKey) {
1480 // StepId must leave the most-significant 7 bits empty for future use.
1481 return random::New64() & (((1uLL << 56) - 1) | (1uLL << 56));
1482 } else {
1483 uint64 step_id = env_->collective_executor_mgr->NextStepId(graph_key);
1484 int32 retry_count = 0;
1485 while (step_id == CollectiveExecutor::kInvalidId) {
1486 Notification note;
1487 Status status;
1488 env_->collective_executor_mgr->RefreshStepIdSequenceAsync(
1489 graph_key, [&status, ¬e](const Status& s) {
1490 status = s;
1491 note.Notify();
1492 });
1493 note.WaitForNotification();
1494 if (!status.ok()) {
1495 LOG(ERROR) << "Bad status from "
1496 "collective_executor_mgr->RefreshStepIdSequence: "
1497 << status << ". Retrying.";
1498 int64 delay_micros = std::min(60000000LL, 1000000LL * ++retry_count);
1499 Env::Default()->SleepForMicroseconds(delay_micros);
1500 } else {
1501 step_id = env_->collective_executor_mgr->NextStepId(graph_key);
1502 }
1503 }
1504 return step_id;
1505 }
1506 }
1507
PartialRunSetup(const PartialRunSetupRequest * req,PartialRunSetupResponse * resp)1508 Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req,
1509 PartialRunSetupResponse* resp) {
1510 std::vector<string> inputs, outputs, targets;
1511 for (const auto& feed : req->feed()) {
1512 inputs.push_back(feed);
1513 }
1514 for (const auto& fetch : req->fetch()) {
1515 outputs.push_back(fetch);
1516 }
1517 for (const auto& target : req->target()) {
1518 targets.push_back(target);
1519 }
1520
1521 string handle = std::to_string(partial_run_handle_counter_.fetch_add(1));
1522
1523 ReffedClientGraph* rcg = nullptr;
1524
1525 // Prepare.
1526 BuildGraphOptions opts;
1527 BuildBuildGraphOptions(*req, &opts);
1528 int64 count = 0;
1529 TF_RETURN_IF_ERROR(StartStep(opts, true, &rcg, &count));
1530
1531 rcg->Ref();
1532 RunState* run_state =
1533 new RunState(inputs, outputs, rcg,
1534 NewStepId(BuildGraphOptions::kNoCollectiveGraphKey), count);
1535 {
1536 mutex_lock l(mu_);
1537 partial_runs_.emplace(
1538 std::make_pair(handle, std::unique_ptr<RunState>(run_state)));
1539 }
1540
1541 TF_RETURN_IF_ERROR(BuildAndRegisterPartitions(rcg));
1542
1543 resp->set_partial_run_handle(handle);
1544 return Status::OK();
1545 }
1546
Run(CallOptions * opts,const RunStepRequestWrapper & req,MutableRunStepResponseWrapper * resp)1547 Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req,
1548 MutableRunStepResponseWrapper* resp) {
1549 UpdateLastAccessTime();
1550 {
1551 mutex_lock l(mu_);
1552 if (closed_) {
1553 return errors::FailedPrecondition("Session is closed.");
1554 }
1555 ++num_running_;
1556 // Note: all code paths must eventually call MarkRunCompletion()
1557 // in order to appropriate decrement the num_running_ counter.
1558 }
1559 Status status;
1560 if (!req.partial_run_handle().empty()) {
1561 status = DoPartialRun(opts, req, resp);
1562 } else {
1563 status = DoRunWithLocalExecution(opts, req, resp);
1564 }
1565 return status;
1566 }
1567
1568 // Decrements num_running_ and broadcasts if num_running_ is zero.
MarkRunCompletion()1569 void MasterSession::MarkRunCompletion() {
1570 mutex_lock l(mu_);
1571 --num_running_;
1572 if (num_running_ == 0) {
1573 num_running_is_zero_.notify_all();
1574 }
1575 }
1576
BuildAndRegisterPartitions(ReffedClientGraph * rcg)1577 Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
1578 // Registers subgraphs if haven't done so.
1579 PartitionOptions popts;
1580 popts.node_to_loc = SplitByWorker;
1581 // The closures popts.{new_name,get_incarnation} are called synchronously in
1582 // RegisterPartitions() below, so do not need a Ref()/Unref() pair to keep
1583 // "this" alive during the closure.
1584 popts.new_name = [this](const string& prefix) {
1585 mutex_lock l(mu_);
1586 return strings::StrCat(prefix, "_S", next_node_id_++);
1587 };
1588 popts.get_incarnation = [this](const string& name) -> int64 {
1589 Device* d = devices_->FindDeviceByName(name);
1590 if (d == nullptr) {
1591 return PartitionOptions::kIllegalIncarnation;
1592 } else {
1593 return d->attributes().incarnation();
1594 }
1595 };
1596 popts.control_flow_added = false;
1597 const bool enable_bfloat16_sendrecv =
1598 session_opts_.config.graph_options().enable_bfloat16_sendrecv();
1599 popts.should_cast = [enable_bfloat16_sendrecv](const Edge* e) {
1600 if (e->IsControlEdge()) {
1601 return DT_FLOAT;
1602 }
1603 DataType dtype = BaseType(e->src()->output_type(e->src_output()));
1604 if (enable_bfloat16_sendrecv && dtype == DT_FLOAT) {
1605 return DT_BFLOAT16;
1606 } else {
1607 return dtype;
1608 }
1609 };
1610 if (session_opts_.config.graph_options().enable_recv_scheduling()) {
1611 popts.scheduling_for_recvs = true;
1612 popts.need_to_record_start_times = true;
1613 }
1614
1615 TF_RETURN_IF_ERROR(rcg->RegisterPartitions(std::move(popts)));
1616
1617 return Status::OK();
1618 }
1619
DoPartialRun(CallOptions * opts,const RunStepRequestWrapper & req,MutableRunStepResponseWrapper * resp)1620 Status MasterSession::DoPartialRun(CallOptions* opts,
1621 const RunStepRequestWrapper& req,
1622 MutableRunStepResponseWrapper* resp) {
1623 auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
1624 const string& prun_handle = req.partial_run_handle();
1625 RunState* run_state = nullptr;
1626 {
1627 mutex_lock l(mu_);
1628 auto it = partial_runs_.find(prun_handle);
1629 if (it == partial_runs_.end()) {
1630 return errors::InvalidArgument(
1631 "Must run PartialRunSetup before performing partial runs");
1632 }
1633 run_state = it->second.get();
1634 }
1635 // CollectiveOps are not supported in partial runs.
1636 if (req.options().experimental().collective_graph_key() !=
1637 BuildGraphOptions::kNoCollectiveGraphKey) {
1638 return errors::InvalidArgument(
1639 "PartialRun does not support Collective ops. collective_graph_key "
1640 "must be kNoCollectiveGraphKey.");
1641 }
1642
1643 // If this is the first partial run, initialize the PerStepState.
1644 if (!run_state->step_started) {
1645 run_state->step_started = true;
1646 PerStepState pss;
1647
1648 const auto count = run_state->count;
1649 pss.collect_timeline =
1650 req.options().trace_level() == RunOptions::FULL_TRACE;
1651 pss.collect_rpcs = req.options().trace_level() == RunOptions::FULL_TRACE;
1652 pss.report_tensor_allocations_upon_oom =
1653 req.options().report_tensor_allocations_upon_oom();
1654
1655 // Build the cost model every 'build_cost_model_every' steps after skipping
1656 // an
1657 // initial 'build_cost_model_after' steps.
1658 const int64 build_cost_model_after =
1659 session_opts_.config.graph_options().build_cost_model_after();
1660 const int64 build_cost_model_every =
1661 session_opts_.config.graph_options().build_cost_model();
1662 pss.collect_costs =
1663 build_cost_model_every > 0 &&
1664 ((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
1665 pss.collect_partition_graphs = req.options().output_partition_graphs();
1666
1667 std::unique_ptr<ProfileHandler> ph = run_state->rcg->GetProfileHandler(
1668 run_state->step_id, count, req.options());
1669 if (ph) {
1670 pss.collect_timeline = true;
1671 pss.collect_rpcs = ph->should_collect_rpcs();
1672 }
1673
1674 run_state->pss = std::move(pss);
1675 run_state->ph = std::move(ph);
1676 }
1677
1678 // Make sure that this is a new set of feeds that are still pending.
1679 for (size_t i = 0; i < req.num_feeds(); ++i) {
1680 const string& feed = req.feed_name(i);
1681 auto it = run_state->pending_inputs.find(feed);
1682 if (it == run_state->pending_inputs.end()) {
1683 return errors::InvalidArgument(
1684 "The feed ", feed, " was not specified in partial_run_setup.");
1685 } else if (it->second) {
1686 return errors::InvalidArgument("The feed ", feed,
1687 " has already been fed.");
1688 }
1689 }
1690 // Check that this is a new set of fetches that are still pending.
1691 for (size_t i = 0; i < req.num_fetches(); ++i) {
1692 const string& fetch = req.fetch_name(i);
1693 auto it = run_state->pending_outputs.find(fetch);
1694 if (it == run_state->pending_outputs.end()) {
1695 return errors::InvalidArgument(
1696 "The fetch ", fetch, " was not specified in partial_run_setup.");
1697 } else if (it->second) {
1698 return errors::InvalidArgument("The fetch ", fetch,
1699 " has already been fetched.");
1700 }
1701 }
1702
1703 // Ensure that the requested fetches can be computed from the provided feeds.
1704 {
1705 mutex_lock l(mu_);
1706 TF_RETURN_IF_ERROR(
1707 run_state->rcg->CheckFetches(req, run_state, execution_state_.get()));
1708 }
1709
1710 // Determine if this partial run satisfies all the pending inputs and outputs.
1711 for (size_t i = 0; i < req.num_feeds(); ++i) {
1712 auto it = run_state->pending_inputs.find(req.feed_name(i));
1713 it->second = true;
1714 }
1715 for (size_t i = 0; i < req.num_fetches(); ++i) {
1716 auto it = run_state->pending_outputs.find(req.fetch_name(i));
1717 it->second = true;
1718 }
1719 bool is_last_partial_run = run_state->PendingDone();
1720
1721 Status s = run_state->rcg->RunPartitions(
1722 env_, run_state->step_id, run_state->count, &run_state->pss, opts, req,
1723 resp, &cancellation_manager_, is_last_partial_run);
1724
1725 // Delete the run state if there is an error or all fetches are done.
1726 if (!s.ok() || is_last_partial_run) {
1727 ReffedClientGraph* rcg = run_state->rcg;
1728 run_state->pss.end_micros = Env::Default()->NowMicros();
1729 // Schedule post-processing and cleanup to be done asynchronously.
1730 Ref();
1731 rcg->Ref();
1732 rcg->ProcessStats(run_state->step_id, &run_state->pss, run_state->ph.get(),
1733 req.options(), resp->mutable_metadata());
1734 cleanup.release(); // MarkRunCompletion called in done closure.
1735 rcg->CleanupPartitionsAsync(
1736 run_state->step_id, [this, rcg, prun_handle](const Status& s) {
1737 if (!s.ok()) {
1738 LOG(ERROR) << "Cleanup partition error: " << s;
1739 }
1740 rcg->Unref();
1741 MarkRunCompletion();
1742 Unref();
1743 });
1744 mutex_lock l(mu_);
1745 partial_runs_.erase(prun_handle);
1746 }
1747 return s;
1748 }
1749
CreateDebuggerState(const DebugOptions & debug_options,const RunStepRequestWrapper & req,int64 rcg_execution_count,std::unique_ptr<DebuggerStateInterface> * debugger_state)1750 Status MasterSession::CreateDebuggerState(
1751 const DebugOptions& debug_options, const RunStepRequestWrapper& req,
1752 int64 rcg_execution_count,
1753 std::unique_ptr<DebuggerStateInterface>* debugger_state) {
1754 TF_RETURN_IF_ERROR(
1755 DebuggerStateRegistry::CreateState(debug_options, debugger_state));
1756
1757 std::vector<string> input_names;
1758 for (size_t i = 0; i < req.num_feeds(); ++i) {
1759 input_names.push_back(req.feed_name(i));
1760 }
1761 std::vector<string> output_names;
1762 for (size_t i = 0; i < req.num_fetches(); ++i) {
1763 output_names.push_back(req.fetch_name(i));
1764 }
1765 std::vector<string> target_names;
1766 for (size_t i = 0; i < req.num_targets(); ++i) {
1767 target_names.push_back(req.target_name(i));
1768 }
1769
1770 // TODO(cais): We currently use -1 as a dummy value for session run count.
1771 // While this counter value is straightforward to define and obtain for
1772 // DirectSessions, it is less so for non-direct Sessions. Devise a better
1773 // way to get its value when the need arises.
1774 TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
1775 debug_options.global_step(), rcg_execution_count, rcg_execution_count,
1776 input_names, output_names, target_names));
1777
1778 return Status::OK();
1779 }
1780
FillPerStepState(MasterSession::ReffedClientGraph * rcg,const RunOptions & run_options,uint64 step_id,int64 count,PerStepState * out_pss,std::unique_ptr<ProfileHandler> * out_ph)1781 void MasterSession::FillPerStepState(MasterSession::ReffedClientGraph* rcg,
1782 const RunOptions& run_options,
1783 uint64 step_id, int64 count,
1784 PerStepState* out_pss,
1785 std::unique_ptr<ProfileHandler>* out_ph) {
1786 out_pss->collect_timeline =
1787 run_options.trace_level() == RunOptions::FULL_TRACE;
1788 out_pss->collect_rpcs = run_options.trace_level() == RunOptions::FULL_TRACE;
1789 out_pss->report_tensor_allocations_upon_oom =
1790 run_options.report_tensor_allocations_upon_oom();
1791 // Build the cost model every 'build_cost_model_every' steps after skipping an
1792 // initial 'build_cost_model_after' steps.
1793 const int64 build_cost_model_after =
1794 session_opts_.config.graph_options().build_cost_model_after();
1795 const int64 build_cost_model_every =
1796 session_opts_.config.graph_options().build_cost_model();
1797 out_pss->collect_costs =
1798 build_cost_model_every > 0 &&
1799 ((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
1800 out_pss->collect_partition_graphs = run_options.output_partition_graphs();
1801
1802 *out_ph = rcg->GetProfileHandler(step_id, count, run_options);
1803 if (*out_ph) {
1804 out_pss->collect_timeline = true;
1805 out_pss->collect_rpcs = (*out_ph)->should_collect_rpcs();
1806 }
1807 }
1808
PostRunCleanup(MasterSession::ReffedClientGraph * rcg,uint64 step_id,const RunOptions & run_options,PerStepState * pss,const std::unique_ptr<ProfileHandler> & ph,const Status & run_status,RunMetadata * out_run_metadata)1809 Status MasterSession::PostRunCleanup(MasterSession::ReffedClientGraph* rcg,
1810 uint64 step_id,
1811 const RunOptions& run_options,
1812 PerStepState* pss,
1813 const std::unique_ptr<ProfileHandler>& ph,
1814 const Status& run_status,
1815 RunMetadata* out_run_metadata) {
1816 Status s = run_status;
1817 if (s.ok()) {
1818 pss->end_micros = Env::Default()->NowMicros();
1819 if (rcg->collective_graph_key() !=
1820 BuildGraphOptions::kNoCollectiveGraphKey) {
1821 env_->collective_executor_mgr->RetireStepId(rcg->collective_graph_key(),
1822 step_id);
1823 }
1824 // Schedule post-processing and cleanup to be done asynchronously.
1825 rcg->ProcessStats(step_id, pss, ph.get(), run_options, out_run_metadata);
1826 } else if (errors::IsCancelled(s)) {
1827 mutex_lock l(mu_);
1828 if (closed_) {
1829 if (garbage_collected_) {
1830 s = errors::Cancelled(
1831 "Step was cancelled because the session was garbage collected due "
1832 "to inactivity.");
1833 } else {
1834 s = errors::Cancelled(
1835 "Step was cancelled by an explicit call to `Session::Close()`.");
1836 }
1837 }
1838 }
1839 Ref();
1840 rcg->Ref();
1841 rcg->CleanupPartitionsAsync(step_id, [this, rcg](const Status& s) {
1842 if (!s.ok()) {
1843 LOG(ERROR) << "Cleanup partition error: " << s;
1844 }
1845 rcg->Unref();
1846 MarkRunCompletion();
1847 Unref();
1848 });
1849 return s;
1850 }
1851
DoRunWithLocalExecution(CallOptions * opts,const RunStepRequestWrapper & req,MutableRunStepResponseWrapper * resp)1852 Status MasterSession::DoRunWithLocalExecution(
1853 CallOptions* opts, const RunStepRequestWrapper& req,
1854 MutableRunStepResponseWrapper* resp) {
1855 VLOG(2) << "DoRunWithLocalExecution req: " << req.DebugString();
1856 PerStepState pss;
1857 pss.start_micros = Env::Default()->NowMicros();
1858 auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
1859
1860 // Prepare.
1861 BuildGraphOptions bgopts;
1862 BuildBuildGraphOptions(req, session_opts_.config, &bgopts);
1863 ReffedClientGraph* rcg = nullptr;
1864 int64 count;
1865 TF_RETURN_IF_ERROR(StartStep(bgopts, false, &rcg, &count));
1866
1867 // Unref "rcg" when out of scope.
1868 core::ScopedUnref unref(rcg);
1869
1870 std::unique_ptr<DebuggerStateInterface> debugger_state;
1871 const DebugOptions& debug_options = req.options().debug_options();
1872
1873 if (!debug_options.debug_tensor_watch_opts().empty()) {
1874 TF_RETURN_IF_ERROR(
1875 CreateDebuggerState(debug_options, req, count, &debugger_state));
1876 }
1877 TF_RETURN_IF_ERROR(BuildAndRegisterPartitions(rcg));
1878
1879 // Keeps the highest 8 bits 0x01: we reserve some bits of the
1880 // step_id for future use.
1881 uint64 step_id = NewStepId(rcg->collective_graph_key());
1882 TRACEPRINTF("stepid %llu", step_id);
1883
1884 std::unique_ptr<ProfileHandler> ph;
1885 FillPerStepState(rcg, req.options(), step_id, count, &pss, &ph);
1886
1887 Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp,
1888 &cancellation_manager_, false);
1889
1890 cleanup.release(); // MarkRunCompletion called in PostRunCleanup().
1891 return PostRunCleanup(rcg, step_id, req.options(), &pss, ph, s,
1892 resp->mutable_metadata());
1893 }
1894
MakeCallable(const MakeCallableRequest & req,MakeCallableResponse * resp)1895 Status MasterSession::MakeCallable(const MakeCallableRequest& req,
1896 MakeCallableResponse* resp) {
1897 UpdateLastAccessTime();
1898
1899 BuildGraphOptions opts;
1900 opts.callable_options = req.options();
1901 opts.use_function_convention = false;
1902
1903 ReffedClientGraph* callable;
1904
1905 {
1906 mutex_lock l(mu_);
1907 if (closed_) {
1908 return errors::FailedPrecondition("Session is closed.");
1909 }
1910 std::unique_ptr<ClientGraph> client_graph;
1911 TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
1912 callable = new ReffedClientGraph(handle_, opts, std::move(client_graph),
1913 session_opts_, stats_publisher_factory_,
1914 false /* is_partial */, get_worker_cache(),
1915 !should_delete_worker_sessions_);
1916 }
1917
1918 Status s = BuildAndRegisterPartitions(callable);
1919 if (!s.ok()) {
1920 callable->Unref();
1921 return s;
1922 }
1923
1924 uint64 handle;
1925 {
1926 mutex_lock l(mu_);
1927 handle = next_callable_handle_++;
1928 callables_[handle] = callable;
1929 }
1930
1931 resp->set_handle(handle);
1932 return Status::OK();
1933 }
1934
DoRunCallable(CallOptions * opts,ReffedClientGraph * rcg,const RunCallableRequest & req,RunCallableResponse * resp)1935 Status MasterSession::DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg,
1936 const RunCallableRequest& req,
1937 RunCallableResponse* resp) {
1938 VLOG(2) << "DoRunCallable req: " << req.DebugString();
1939 PerStepState pss;
1940 pss.start_micros = Env::Default()->NowMicros();
1941 auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
1942
1943 // Prepare.
1944 int64 count = rcg->get_and_increment_execution_count();
1945
1946 const uint64 step_id = NewStepId(rcg->collective_graph_key());
1947 TRACEPRINTF("stepid %llu", step_id);
1948
1949 const RunOptions& run_options = rcg->callable_options().run_options();
1950
1951 if (run_options.timeout_in_ms() != 0) {
1952 opts->SetTimeout(run_options.timeout_in_ms());
1953 }
1954
1955 std::unique_ptr<ProfileHandler> ph;
1956 FillPerStepState(rcg, run_options, step_id, count, &pss, &ph);
1957 Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp,
1958 &cancellation_manager_);
1959 cleanup.release(); // MarkRunCompletion called in PostRunCleanup().
1960 return PostRunCleanup(rcg, step_id, run_options, &pss, ph, s,
1961 resp->mutable_metadata());
1962 }
1963
RunCallable(CallOptions * opts,const RunCallableRequest & req,RunCallableResponse * resp)1964 Status MasterSession::RunCallable(CallOptions* opts,
1965 const RunCallableRequest& req,
1966 RunCallableResponse* resp) {
1967 UpdateLastAccessTime();
1968 ReffedClientGraph* callable;
1969 {
1970 mutex_lock l(mu_);
1971 if (closed_) {
1972 return errors::FailedPrecondition("Session is closed.");
1973 }
1974 int64 handle = req.handle();
1975 if (handle >= next_callable_handle_) {
1976 return errors::InvalidArgument("No such callable handle: ", handle);
1977 }
1978 auto iter = callables_.find(req.handle());
1979 if (iter == callables_.end()) {
1980 return errors::InvalidArgument(
1981 "Attempted to run callable after handle was released: ", handle);
1982 }
1983 callable = iter->second;
1984 callable->Ref();
1985 ++num_running_;
1986 }
1987 core::ScopedUnref unref_callable(callable);
1988 return DoRunCallable(opts, callable, req, resp);
1989 }
1990
ReleaseCallable(const ReleaseCallableRequest & req,ReleaseCallableResponse * resp)1991 Status MasterSession::ReleaseCallable(const ReleaseCallableRequest& req,
1992 ReleaseCallableResponse* resp) {
1993 UpdateLastAccessTime();
1994 ReffedClientGraph* to_unref = nullptr;
1995 {
1996 mutex_lock l(mu_);
1997 auto iter = callables_.find(req.handle());
1998 if (iter != callables_.end()) {
1999 to_unref = iter->second;
2000 callables_.erase(iter);
2001 }
2002 }
2003 if (to_unref != nullptr) {
2004 to_unref->Unref();
2005 }
2006 return Status::OK();
2007 }
2008
Close()2009 Status MasterSession::Close() {
2010 {
2011 mutex_lock l(mu_);
2012 closed_ = true; // All subsequent calls to Run() or Extend() will fail.
2013 }
2014 cancellation_manager_.StartCancel();
2015 std::vector<ReffedClientGraph*> to_unref;
2016 {
2017 mutex_lock l(mu_);
2018 while (num_running_ != 0) {
2019 num_running_is_zero_.wait(l);
2020 }
2021 ClearRunsTable(&to_unref, &run_graphs_);
2022 ClearRunsTable(&to_unref, &partial_run_graphs_);
2023 ClearRunsTable(&to_unref, &callables_);
2024 }
2025 for (ReffedClientGraph* rcg : to_unref) rcg->Unref();
2026 if (should_delete_worker_sessions_) {
2027 Status s = DeleteWorkerSessions();
2028 if (!s.ok()) {
2029 LOG(WARNING) << s;
2030 }
2031 }
2032 return Status::OK();
2033 }
2034
GarbageCollect()2035 void MasterSession::GarbageCollect() {
2036 {
2037 mutex_lock l(mu_);
2038 closed_ = true;
2039 garbage_collected_ = true;
2040 }
2041 cancellation_manager_.StartCancel();
2042 Unref();
2043 }
2044
RunState(const std::vector<string> & input_names,const std::vector<string> & output_names,ReffedClientGraph * rcg,const uint64 step_id,const int64 count)2045 MasterSession::RunState::RunState(const std::vector<string>& input_names,
2046 const std::vector<string>& output_names,
2047 ReffedClientGraph* rcg, const uint64 step_id,
2048 const int64 count)
2049 : rcg(rcg), step_id(step_id), count(count) {
2050 // Initially all the feeds and fetches are pending.
2051 for (auto& name : input_names) {
2052 pending_inputs[name] = false;
2053 }
2054 for (auto& name : output_names) {
2055 pending_outputs[name] = false;
2056 }
2057 }
2058
~RunState()2059 MasterSession::RunState::~RunState() {
2060 if (rcg) rcg->Unref();
2061 }
2062
PendingDone() const2063 bool MasterSession::RunState::PendingDone() const {
2064 for (const auto& it : pending_inputs) {
2065 if (!it.second) return false;
2066 }
2067 for (const auto& it : pending_outputs) {
2068 if (!it.second) return false;
2069 }
2070 return true;
2071 }
2072
2073 } // end namespace tensorflow
2074