1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/master_session.h"
17 
18 #include <memory>
19 #include <unordered_map>
20 #include <unordered_set>
21 #include <vector>
22 
23 #include "tensorflow/core/common_runtime/process_util.h"
24 #include "tensorflow/core/common_runtime/profile_handler.h"
25 #include "tensorflow/core/common_runtime/stats_publisher_interface.h"
26 #include "tensorflow/core/debug/debug_graph_utils.h"
27 #include "tensorflow/core/distributed_runtime/request_id.h"
28 #include "tensorflow/core/distributed_runtime/scheduler.h"
29 #include "tensorflow/core/distributed_runtime/worker_cache.h"
30 #include "tensorflow/core/distributed_runtime/worker_interface.h"
31 #include "tensorflow/core/framework/allocation_description.pb.h"
32 #include "tensorflow/core/framework/collective.h"
33 #include "tensorflow/core/framework/cost_graph.pb.h"
34 #include "tensorflow/core/framework/graph_def_util.h"
35 #include "tensorflow/core/framework/node_def.pb.h"
36 #include "tensorflow/core/framework/node_def_util.h"
37 #include "tensorflow/core/framework/tensor.h"
38 #include "tensorflow/core/framework/tensor.pb.h"
39 #include "tensorflow/core/framework/tensor_description.pb.h"
40 #include "tensorflow/core/graph/graph_partition.h"
41 #include "tensorflow/core/graph/tensor_id.h"
42 #include "tensorflow/core/lib/core/blocking_counter.h"
43 #include "tensorflow/core/lib/core/notification.h"
44 #include "tensorflow/core/lib/core/refcount.h"
45 #include "tensorflow/core/lib/core/status.h"
46 #include "tensorflow/core/lib/gtl/cleanup.h"
47 #include "tensorflow/core/lib/gtl/inlined_vector.h"
48 #include "tensorflow/core/lib/gtl/map_util.h"
49 #include "tensorflow/core/lib/random/random.h"
50 #include "tensorflow/core/lib/strings/numbers.h"
51 #include "tensorflow/core/lib/strings/str_util.h"
52 #include "tensorflow/core/lib/strings/strcat.h"
53 #include "tensorflow/core/lib/strings/stringprintf.h"
54 #include "tensorflow/core/platform/env.h"
55 #include "tensorflow/core/platform/logging.h"
56 #include "tensorflow/core/platform/macros.h"
57 #include "tensorflow/core/platform/mutex.h"
58 #include "tensorflow/core/platform/tracing.h"
59 #include "tensorflow/core/public/session_options.h"
60 #include "tensorflow/core/util/device_name_utils.h"
61 
62 namespace tensorflow {
63 
64 // MasterSession wraps ClientGraph in a reference counted object.
65 // This way, MasterSession can clear up the cache mapping Run requests to
66 // compiled graphs while the compiled graph is still being used.
67 //
68 // TODO(zhifengc): Cleanup this class. It's becoming messy.
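// Roughly how callers are expected to use it (an illustrative sketch, not a
// verbatim excerpt from this file): the session caches a ReffedClientGraph per
// compiled graph and pins it with an extra reference for the duration of each
// step, e.g.
//
//   ReffedClientGraph* rcg = ...;  // looked up in run_graphs_ or created
//   rcg->Ref();                    // pin for this step
//   Status s = rcg->RunPartitions(...);
//   rcg->Unref();                  // release the per-step pin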
69 class MasterSession::ReffedClientGraph : public core::RefCounted {
70  public:
71   ReffedClientGraph(const string& handle, const BuildGraphOptions& bopts,
72                     std::unique_ptr<ClientGraph> client_graph,
73                     const SessionOptions& session_opts,
74                     const StatsPublisherFactory& stats_publisher_factory,
75                     bool is_partial, WorkerCacheInterface* worker_cache,
76                     bool should_deregister)
77       : session_handle_(handle),
78         bg_opts_(bopts),
79         client_graph_before_register_(std::move(client_graph)),
80         session_opts_(session_opts),
81         is_partial_(is_partial),
82         callable_opts_(bopts.callable_options),
83         worker_cache_(worker_cache),
84         should_deregister_(should_deregister),
85         collective_graph_key_(
86             client_graph_before_register_->collective_graph_key) {
87     VLOG(1) << "Created ReffedClientGraph for node with "
88             << client_graph_before_register_->graph.num_node_ids();
89 
90     stats_publisher_ = stats_publisher_factory(handle, bopts, session_opts);
91 
92     // Initialize a name to node map for processing device stats.
93     for (Node* n : client_graph_before_register_->graph.nodes()) {
94       name_to_node_details_.emplace(
95           n->name(),
96           NodeDetails(n->type_string(),
97                       strings::StrCat(
98                           "(", absl::StrJoin(n->requested_inputs(), ", "))));
99     }
100   }
101 
102   ~ReffedClientGraph() override {
103     if (should_deregister_) {
104       DeregisterPartitions();
105     } else {
106       for (Part& part : partitions_) {
107         worker_cache_->ReleaseWorker(part.name, part.worker);
108       }
109     }
110   }
111 
112   const CallableOptions& callable_options() { return callable_opts_; }
113 
114   const BuildGraphOptions& build_graph_options() { return bg_opts_; }
115 
116   int64 collective_graph_key() { return collective_graph_key_; }
117 
118   std::unique_ptr<ProfileHandler> GetProfileHandler(uint64 step,
119                                                     int64 execution_count,
120                                                     const RunOptions& ropts) {
121     return stats_publisher_->GetProfileHandler(step, execution_count, ropts);
122   }
123 
124   int64 get_and_increment_execution_count() {
125     return execution_count_.fetch_add(1);
126   }
127 
128   // Turn RPC logging on or off, both at the WorkerCache used by this
129   // master process, and at each remote worker in use for the current
130   // partitions.
131   void SetRPCLogging(bool active) {
132     worker_cache_->SetLogging(active);
133     // Logging is a best-effort activity, so we make async calls to turn
134     // it on/off and don't make use of the responses.
135     for (auto& p : partitions_) {
136       LoggingRequest* req = new LoggingRequest;
137       if (active) {
138         req->set_enable_rpc_logging(true);
139       } else {
140         req->set_disable_rpc_logging(true);
141       }
142       LoggingResponse* resp = new LoggingResponse;
143       Ref();
144       p.worker->LoggingAsync(req, resp, [this, req, resp](const Status& s) {
145         delete req;
146         delete resp;
147         // ReffedClientGraph owns p.worker so we need to hold a ref to
148         // ensure that the method doesn't attempt to access p.worker after
150         // ReffedClientGraph has deleted it.
150         // TODO(suharshs): Simplify this ownership model.
151         Unref();
152       });
153     }
154   }
155 
156   // Retrieve all RPC logs data accumulated for the current step, both
157   // from the local WorkerCache in use by this master process and from
158   // all the remote workers executing the remote partitions.
159   void RetrieveLogs(int64 step_id, StepStats* ss) {
160     // Get the local data first, because it sets *ss without merging.
161     worker_cache_->RetrieveLogs(step_id, ss);
162 
163     // Then merge in data from all the remote workers.
164     LoggingRequest req;
165     req.add_fetch_step_id(step_id);
166     int waiting_for = partitions_.size();
167     if (waiting_for > 0) {
168       mutex scoped_mu;
169       BlockingCounter all_done(waiting_for);
170       for (auto& p : partitions_) {
171         LoggingResponse* resp = new LoggingResponse;
172         p.worker->LoggingAsync(
173             &req, resp,
174             [step_id, ss, resp, &scoped_mu, &all_done](const Status& s) {
175               {
176                 mutex_lock l(scoped_mu);
177                 if (s.ok()) {
178                   for (auto& lss : resp->step()) {
179                     if (step_id != lss.step_id()) {
180                       LOG(ERROR) << "Wrong step_id in LoggingResponse";
181                       continue;
182                     }
183                     ss->MergeFrom(lss.step_stats());
184                   }
185                 }
186                 delete resp;
187               }
188               // Must not decrement all_done until out of critical section where
189               // *ss is updated.
190               all_done.DecrementCount();
191             });
192       }
193       all_done.Wait();
194     }
195   }
196 
197   // Local execution methods.
198 
199   // Partitions the graph into subgraphs and registers them on
200   // workers.
201   Status RegisterPartitions(PartitionOptions popts);
202 
203   // Runs one step of all partitions.
204   Status RunPartitions(const MasterEnv* env, int64 step_id,
205                        int64 execution_count, PerStepState* pss,
206                        CallOptions* opts, const RunStepRequestWrapper& req,
207                        MutableRunStepResponseWrapper* resp,
208                        CancellationManager* cm, const bool is_last_partial_run);
209   Status RunPartitions(const MasterEnv* env, int64 step_id,
210                        int64 execution_count, PerStepState* pss,
211                        CallOptions* call_opts, const RunCallableRequest& req,
212                        RunCallableResponse* resp, CancellationManager* cm);
213 
214   // Calls workers to cleanup states for the step "step_id".  Calls
215   // `done` when all cleanup RPCs have completed.
216   void CleanupPartitionsAsync(int64 step_id, StatusCallback done);
217 
218   // Post-processing of any runtime statistics gathered during execution.
219   void ProcessStats(int64 step_id, PerStepState* pss, ProfileHandler* ph,
220                     const RunOptions& options, RunMetadata* resp);
221   void ProcessDeviceStats(ProfileHandler* ph, const DeviceStepStats& ds,
222                           bool is_rpc);
223   // Checks that the requested fetches can be computed from the provided feeds.
224   Status CheckFetches(const RunStepRequestWrapper& req,
225                       const RunState* run_state,
226                       GraphExecutionState* execution_state);
227 
228  private:
229   const string session_handle_;
230   const BuildGraphOptions bg_opts_;
231 
232   // NOTE(mrry): This pointer will be null after `RegisterPartitions()` returns.
233   std::unique_ptr<ClientGraph> client_graph_before_register_ TF_GUARDED_BY(mu_);
234   const SessionOptions session_opts_;
235   const bool is_partial_;
236   const CallableOptions callable_opts_;
237   WorkerCacheInterface* const worker_cache_;  // Not owned.
238 
239   struct NodeDetails {
240     explicit NodeDetails(string type_string, string detail_text)
241         : type_string(std::move(type_string)),
242           detail_text(std::move(detail_text)) {}
243     const string type_string;
244     const string detail_text;
245   };
246   std::unordered_map<string, NodeDetails> name_to_node_details_;
247 
248   const bool should_deregister_;
249   const int64 collective_graph_key_;
250   std::atomic<int64> execution_count_ = {0};
251 
252   // Graph partitioned into per-location subgraphs.
253   struct Part {
254     // Worker name.
255     string name;
256 
257     // Maps feed names to rendezvous keys. Empty most of the time.
258     std::unordered_map<string, string> feed_key;
259 
260     // Maps rendezvous keys to fetch names. Empty most of the time.
261     std::unordered_map<string, string> key_fetch;
262 
263     // The interface to the worker. Owned.
264     WorkerInterface* worker = nullptr;
265 
266     // After registration with the worker, graph_handle identifies
267     // this partition on the worker.
268     string graph_handle;
269 
270     Part() : feed_key(3), key_fetch(3) {}
271   };
272 
273   // partitions_ is immutable after RegisterPartitions() call
274   // finishes.  RunPartitions() can access partitions_ safely without
275   // acquiring locks.
276   std::vector<Part> partitions_;
277 
278   mutable mutex mu_;
279 
280   // Partition initialization and registration only needs to happen
281   // once. `!client_graph_before_register_ && !init_done_.HasBeenNotified()`
282   // indicates the initialization is ongoing.
283   Notification init_done_;
284 
285   // init_result_ remembers the initialization error if any.
286   Status init_result_ TF_GUARDED_BY(mu_);
287 
288   std::unique_ptr<StatsPublisherInterface> stats_publisher_;
289 
290   string DetailText(const NodeDetails& details, const NodeExecStats& stats) {
291     int64 tot = 0;
292     for (auto& no : stats.output()) {
293       tot += no.tensor_description().allocation_description().requested_bytes();
294     }
295     string bytes;
296     if (tot >= 0.1 * 1048576.0) {
297       bytes = strings::Printf("[%.1fMB] ", tot / 1048576.0);
298     }
299     return strings::StrCat(bytes, stats.node_name(), " = ", details.type_string,
300                            details.detail_text);
301   }
302 
303   // Send/Recv nodes that are the result of client-added
304   // feeds and fetches must be tracked so that the tensors
305   // can be added to the local rendezvous.
306   static void TrackFeedsAndFetches(Part* part, const GraphDef& graph_def,
307                                    const PartitionOptions& popts);
308 
309   // The actual graph partitioning and registration implementation.
310   Status DoBuildPartitions(
311       PartitionOptions popts, ClientGraph* client_graph,
312       std::unordered_map<string, GraphDef>* out_partitions);
313   Status DoRegisterPartitions(
314       const PartitionOptions& popts,
315       std::unordered_map<string, GraphDef> graph_partitions);
316 
317   // Prepares a number of calls to workers. One call per partition.
318   // This is a generic method that handles Run, PartialRun, and RunCallable.
319   template <class FetchListType, class ClientRequestType,
320             class ClientResponseType>
321   Status RunPartitionsHelper(
322       const std::unordered_map<StringPiece, size_t, StringPieceHasher>& feeds,
323       const FetchListType& fetches, const MasterEnv* env, int64 step_id,
324       int64 execution_count, PerStepState* pss, CallOptions* call_opts,
325       const ClientRequestType& req, ClientResponseType* resp,
326       CancellationManager* cm, bool is_last_partial_run);
327 
328   // Deregisters the partitions on the workers.  Called in the
329   // destructor and does not wait for the rpc completion.
330   void DeregisterPartitions();
331 
332   TF_DISALLOW_COPY_AND_ASSIGN(ReffedClientGraph);
333 };
334 
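// Note on the registration protocol below: the first caller to
// RegisterPartitions() consumes `client_graph_before_register_`, builds and
// registers the partitions, records the outcome in `init_result_`, and
// notifies `init_done_`; any concurrent or later caller skips the work, waits
// on `init_done_`, and returns the cached `init_result_`.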
335 Status MasterSession::ReffedClientGraph::RegisterPartitions(
336     PartitionOptions popts) {
337   {  // Ensure register once.
338     mu_.lock();
339     if (client_graph_before_register_) {
340       // The `ClientGraph` is no longer needed after partitions are registered.
341       // Since it can account for a large amount of memory, we consume it here,
342       // and it will be freed after concluding with registration.
343 
344       std::unique_ptr<ClientGraph> client_graph;
345       std::swap(client_graph_before_register_, client_graph);
346       mu_.unlock();
347       std::unordered_map<string, GraphDef> graph_defs;
348       popts.flib_def = client_graph->flib_def.get();
349       Status s = DoBuildPartitions(popts, client_graph.get(), &graph_defs);
350       if (s.ok()) {
351         // NOTE(mrry): The pointers in `graph_defs_for_publishing` do not remain
352         // valid after the call to DoRegisterPartitions begins, so
353         // `stats_publisher_` must make a copy if it wants to retain the
354         // GraphDef objects.
355         std::vector<const GraphDef*> graph_defs_for_publishing;
356         graph_defs_for_publishing.reserve(partitions_.size());
357         for (const auto& name_def : graph_defs) {
358           graph_defs_for_publishing.push_back(&name_def.second);
359         }
360         stats_publisher_->PublishGraphProto(graph_defs_for_publishing);
361         s = DoRegisterPartitions(popts, std::move(graph_defs));
362       }
363       mu_.lock();
364       init_result_ = s;
365       init_done_.Notify();
366     } else {
367       mu_.unlock();
368       init_done_.WaitForNotification();
369       mu_.lock();
370     }
371     const Status result = init_result_;
372     mu_.unlock();
373     return result;
374   }
375 }
376 
377 static string SplitByWorker(const Node* node) {
378   string task;
379   string device;
380   CHECK(DeviceNameUtils::SplitDeviceName(node->assigned_device_name(), &task,
381                                          &device))
382       << "node: " << node->name() << " dev: " << node->assigned_device_name();
383   return task;
384 }
385 
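// For each client-terminated _Recv node (a feed) the partition records
// feed name -> rendezvous key, and for each client-terminated _Send node (a
// fetch) it records rendezvous key -> fetch name. The key is built with
// Rendezvous::CreateKey() from the send/recv device names, the sender's
// incarnation, and the tensor name.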
386 void MasterSession::ReffedClientGraph::TrackFeedsAndFetches(
387     Part* part, const GraphDef& graph_def, const PartitionOptions& popts) {
388   for (int i = 0; i < graph_def.node_size(); ++i) {
389     const NodeDef& ndef = graph_def.node(i);
390     const bool is_recv = ndef.op() == "_Recv";
391     const bool is_send = ndef.op() == "_Send";
392 
393     if (is_recv || is_send) {
394       // Only send/recv nodes that were added as feeds and fetches
395       // (client-terminated) should be tracked.  Other send/recv nodes
396       // are for transferring data between partitions / memory spaces.
397       bool client_terminated;
398       TF_CHECK_OK(GetNodeAttr(ndef, "client_terminated", &client_terminated));
399       if (client_terminated) {
400         string name;
401         TF_CHECK_OK(GetNodeAttr(ndef, "tensor_name", &name));
402         string send_device;
403         TF_CHECK_OK(GetNodeAttr(ndef, "send_device", &send_device));
404         string recv_device;
405         TF_CHECK_OK(GetNodeAttr(ndef, "recv_device", &recv_device));
406         uint64 send_device_incarnation;
407         TF_CHECK_OK(
408             GetNodeAttr(ndef, "send_device_incarnation",
409                         reinterpret_cast<int64*>(&send_device_incarnation)));
410         const string& key =
411             Rendezvous::CreateKey(send_device, send_device_incarnation,
412                                   recv_device, name, FrameAndIter(0, 0));
413 
414         if (is_recv) {
415           part->feed_key.insert({name, key});
416         } else {
417           part->key_fetch.insert({key, name});
418         }
419       }
420     }
421   }
422 }
423 
424 Status MasterSession::ReffedClientGraph::DoBuildPartitions(
425     PartitionOptions popts, ClientGraph* client_graph,
426     std::unordered_map<string, GraphDef>* out_partitions) {
427   if (popts.need_to_record_start_times) {
428     CostModel cost_model(true);
429     cost_model.InitFromGraph(client_graph->graph);
430     // TODO(yuanbyu): Use the real cost model.
431     // execution_state_->MergeFromGlobal(&cost_model);
432     SlackAnalysis sa(&client_graph->graph, &cost_model);
433     sa.ComputeAsap(&popts.start_times);
434   }
435 
436   // Partition the graph.
437   return Partition(popts, &client_graph->graph, out_partitions);
438 }
439 
440 Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
441     const PartitionOptions& popts,
442     std::unordered_map<string, GraphDef> graph_partitions) {
443   partitions_.reserve(graph_partitions.size());
444   Status s;
445   for (auto& name_def : graph_partitions) {
446     partitions_.emplace_back();
447     Part* part = &partitions_.back();
448     part->name = name_def.first;
449     TrackFeedsAndFetches(part, name_def.second, popts);
450     part->worker = worker_cache_->GetOrCreateWorker(part->name);
451     if (part->worker == nullptr) {
452       s = errors::NotFound("worker ", part->name);
453       break;
454     }
455   }
456   if (!s.ok()) {
457     for (Part& part : partitions_) {
458       worker_cache_->ReleaseWorker(part.name, part.worker);
459       part.worker = nullptr;
460     }
461     return s;
462   }
463   struct Call {
464     RegisterGraphRequest req;
465     RegisterGraphResponse resp;
466     Status status;
467   };
468   const int num = partitions_.size();
469   gtl::InlinedVector<Call, 4> calls(num);
470   BlockingCounter done(num);
471   for (int i = 0; i < num; ++i) {
472     const Part& part = partitions_[i];
473     Call* c = &calls[i];
474     c->req.set_session_handle(session_handle_);
475     c->req.set_create_worker_session_called(!should_deregister_);
476     c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]);
477     StripDefaultAttributes(*OpRegistry::Global(),
478                            c->req.mutable_graph_def()->mutable_node());
479     *c->req.mutable_config_proto() = session_opts_.config;
480     *c->req.mutable_graph_options() = session_opts_.config.graph_options();
481     *c->req.mutable_debug_options() =
482         callable_opts_.run_options().debug_options();
483     c->req.set_collective_graph_key(collective_graph_key_);
484     VLOG(2) << "Register " << c->req.graph_def().DebugString();
485     auto cb = [c, &done](const Status& s) {
486       c->status = s;
487       done.DecrementCount();
488     };
489     part.worker->RegisterGraphAsync(&c->req, &c->resp, cb);
490   }
491   done.Wait();
492   for (int i = 0; i < num; ++i) {
493     Call* c = &calls[i];
494     s.Update(c->status);
495     partitions_[i].graph_handle = c->resp.graph_handle();
496   }
497   return s;
498 }
499 
500 namespace {
501 // Helper class to manage "num" parallel RunGraph calls.
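// Typical use (an illustrative sketch; the real call site below uses
// std::bind rather than a lambda):
//
//   RunManyGraphs calls(num_partitions);
//   for (int i = 0; i < num_partitions; ++i) {
//     RunManyGraphs::Call* c = calls.get(i);
//     // ...fill in c->req, c->resp and c->worker_name...
//     workers[i]->RunGraphAsync(&c->opts, c->req.get(), c->resp.get(),
//                               [&calls, i](const Status& s) {
//                                 calls.WhenDone(i, s);
//                               });
//   }
//   calls.Wait();               // blocks until every WhenDone() has run
//   Status s = calls.status();  // aggregated across all partitions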
502 class RunManyGraphs {
503  public:
504   explicit RunManyGraphs(int num) : calls_(num), pending_(num) {}
505 
506   ~RunManyGraphs() {}
507 
508   // Returns the index-th call.
509   struct Call {
510     CallOptions opts;
511     const string* worker_name;
512     std::atomic<bool> done{false};
513     std::unique_ptr<MutableRunGraphRequestWrapper> req;
514     std::unique_ptr<MutableRunGraphResponseWrapper> resp;
515   };
516   Call* get(int index) { return &calls_[index]; }
517 
518   // When the index-th call is done, updates the overall status.
519   void WhenDone(int index, const Status& s) {
520     TRACEPRINTF("Partition %d %s", index, s.ToString().c_str());
521     Call* call = get(index);
522     call->done = true;
523     auto resp = call->resp.get();
524     if (resp->status_code() != error::Code::OK) {
525       // resp->status_code will only be non-OK if s.ok().
526       mutex_lock l(mu_);
527       ReportBadStatus(Status(resp->status_code(),
528                              strings::StrCat("From ", *call->worker_name, ":\n",
529                                              resp->status_error_message())));
530     } else if (!s.ok()) {
531       mutex_lock l(mu_);
532       ReportBadStatus(
533           Status(s.code(), strings::StrCat("From ", *call->worker_name, ":\n",
534                                            s.error_message())));
535     }
536     pending_.DecrementCount();
537   }
538 
539   void StartCancel() {
540     mutex_lock l(mu_);
541     ReportBadStatus(errors::Cancelled("RunManyGraphs"));
542   }
543 
544   void Wait() {
545     // Check the error status every 60 seconds in order to print a log message
546     // in the event of a hang.
547     const std::chrono::milliseconds kCheckErrorPeriod(1000 * 60);
548     while (true) {
549       if (pending_.WaitFor(kCheckErrorPeriod)) {
550         return;
551       }
552       if (!status().ok()) {
553         break;
554       }
555     }
556 
557     // The step has failed. Wait for another 60 seconds before diagnosing a
558     // hang.
559     DCHECK(!status().ok());
560     if (pending_.WaitFor(kCheckErrorPeriod)) {
561       return;
562     }
563     LOG(ERROR)
564         << "RunStep still blocked after 60 seconds. Failed with error status: "
565         << status();
566     for (const Call& call : calls_) {
567       if (!call.done) {
568         LOG(ERROR) << "- No response from RunGraph call to worker: "
569                    << *call.worker_name;
570       }
571     }
572     pending_.Wait();
573   }
574 
575   Status status() const {
576     mutex_lock l(mu_);
577     // Concat status objects in this StatusGroup to get the aggregated status,
578     // as each status in status_group_ is already a summarized status.
579     return status_group_.as_concatenated_status();
580   }
581 
582  private:
583   gtl::InlinedVector<Call, 4> calls_;
584 
585   BlockingCounter pending_;
586   mutable mutex mu_;
587   StatusGroup status_group_ TF_GUARDED_BY(mu_);
588   bool cancel_issued_ TF_GUARDED_BY(mu_) = false;
589 
590   void ReportBadStatus(const Status& s) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
591     VLOG(1) << "Master received error status " << s;
592     if (!cancel_issued_ && !StatusGroup::IsDerived(s)) {
593       // Only start cancelling other workers upon receiving a non-derived
594       // error
595       cancel_issued_ = true;
596 
597       VLOG(1) << "Master received error report. Cancelling remaining workers.";
598       for (Call& call : calls_) {
599         call.opts.StartCancel();
600       }
601     }
602 
603     status_group_.Update(s);
604   }
605 
606   TF_DISALLOW_COPY_AND_ASSIGN(RunManyGraphs);
607 };
608 
609 Status AddSendFromClientRequest(const RunStepRequestWrapper& client_req,
610                                 MutableRunGraphRequestWrapper* worker_req,
611                                 size_t index, const string& send_key) {
612   return worker_req->AddSendFromRunStepRequest(client_req, index, send_key);
613 }
614 
615 Status AddSendFromClientRequest(const RunCallableRequest& client_req,
616                                 MutableRunGraphRequestWrapper* worker_req,
617                                 size_t index, const string& send_key) {
618   return worker_req->AddSendFromRunCallableRequest(client_req, index, send_key);
619 }
620 
621 // TODO(mrry): Add a full-fledged wrapper that avoids TensorProto copies for
622 // in-process messages.
623 struct RunCallableResponseWrapper {
624   RunCallableResponse* resp;  // Not owned.
625   std::unordered_map<string, TensorProto> fetch_key_to_protos;
626 
627   RunMetadata* mutable_metadata() { return resp->mutable_metadata(); }
628 
629   Status AddTensorFromRunGraphResponse(
630       const string& tensor_name, MutableRunGraphResponseWrapper* worker_resp,
631       size_t index) {
632     // TODO(b/74355905): Add a specialized implementation that avoids
633     // copying the tensor into the RunCallableResponse when at least
634     // two of the {client, master, worker} are in the same process.
635     return worker_resp->RecvValue(index, &fetch_key_to_protos[tensor_name]);
636   }
637 };
638 }  // namespace
639 
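// Issues one RunGraph call per registered partition: each request carries the
// step id, executor options, the feeds belonging to that partition, and the
// rendezvous keys to recv back. The calls run in parallel via RunManyGraphs,
// honor client- and master-side cancellation, and on success the fetched
// tensors, step stats, cost graphs, and partition graphs are folded into the
// client response.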
640 template <class FetchListType, class ClientRequestType,
641           class ClientResponseType>
642 Status MasterSession::ReffedClientGraph::RunPartitionsHelper(
643     const std::unordered_map<StringPiece, size_t, StringPieceHasher>& feeds,
644     const FetchListType& fetches, const MasterEnv* env, int64 step_id,
645     int64 execution_count, PerStepState* pss, CallOptions* call_opts,
646     const ClientRequestType& req, ClientResponseType* resp,
647     CancellationManager* cm, bool is_last_partial_run) {
648   // Collect execution cost stats on a smoothly decreasing frequency.
649   ExecutorOpts exec_opts;
650   if (pss->report_tensor_allocations_upon_oom) {
651     exec_opts.set_report_tensor_allocations_upon_oom(true);
652   }
653   if (pss->collect_costs) {
654     exec_opts.set_record_costs(true);
655   }
656   if (pss->collect_timeline) {
657     exec_opts.set_record_timeline(true);
658   }
659   if (pss->collect_rpcs) {
660     SetRPCLogging(true);
661   }
662   if (pss->collect_partition_graphs) {
663     exec_opts.set_record_partition_graphs(true);
664   }
665   if (pss->collect_costs || pss->collect_timeline) {
666     pss->step_stats.resize(partitions_.size());
667   }
668 
669   const int num = partitions_.size();
670   RunManyGraphs calls(num);
671 
672   for (int i = 0; i < num; ++i) {
673     const Part& part = partitions_[i];
674     RunManyGraphs::Call* c = calls.get(i);
675     c->worker_name = &part.name;
676     c->req.reset(part.worker->CreateRunGraphRequest());
677     c->resp.reset(part.worker->CreateRunGraphResponse());
678     if (is_partial_) {
679       c->req->set_is_partial(is_partial_);
680       c->req->set_is_last_partial_run(is_last_partial_run);
681     }
682     c->req->set_session_handle(session_handle_);
683     c->req->set_create_worker_session_called(!should_deregister_);
684     c->req->set_graph_handle(part.graph_handle);
685     c->req->set_step_id(step_id);
686     *c->req->mutable_exec_opts() = exec_opts;
687     c->req->set_store_errors_in_response_body(true);
688     c->req->set_request_id(GetUniqueRequestId());
689     // If any feeds are provided, send the feed values together
690     // in the RunGraph request.
691     // In the partial case, we only want to include feeds provided in the req.
692     // In the non-partial case, all feeds in the request are in the part.
693     // We keep these as separate paths for now, to ensure we aren't
694     // inadvertently slowing down the normal run path.
695     if (is_partial_) {
696       for (const auto& name_index : feeds) {
697         const auto iter = part.feed_key.find(string(name_index.first));
698         if (iter == part.feed_key.end()) {
699           // The provided feed must be for a different partition.
700           continue;
701         }
702         const string& key = iter->second;
703         TF_RETURN_IF_ERROR(AddSendFromClientRequest(req, c->req.get(),
704                                                     name_index.second, key));
705       }
706       // TODO(suharshs): Make a map from feed to fetch_key to make this faster.
707       // For now, we just iterate through partitions to find the matching key.
708       for (const string& req_fetch : fetches) {
709         for (const auto& key_fetch : part.key_fetch) {
710           if (key_fetch.second == req_fetch) {
711             c->req->add_recv_key(key_fetch.first);
712             break;
713           }
714         }
715       }
716     } else {
717       for (const auto& feed_key : part.feed_key) {
718         const string& feed = feed_key.first;
719         const string& key = feed_key.second;
720         auto iter = feeds.find(feed);
721         if (iter == feeds.end()) {
722           return errors::Internal("No feed index found for feed: ", feed);
723         }
724         const int64 feed_index = iter->second;
725         TF_RETURN_IF_ERROR(
726             AddSendFromClientRequest(req, c->req.get(), feed_index, key));
727       }
728       for (const auto& key_fetch : part.key_fetch) {
729         const string& key = key_fetch.first;
730         c->req->add_recv_key(key);
731       }
732     }
733   }
734 
735   // Issues RunGraph calls.
736   for (int i = 0; i < num; ++i) {
737     const Part& part = partitions_[i];
738     RunManyGraphs::Call* call = calls.get(i);
739     TRACEPRINTF("Partition %d %s", i, part.name.c_str());
740     part.worker->RunGraphAsync(
741         &call->opts, call->req.get(), call->resp.get(),
742         std::bind(&RunManyGraphs::WhenDone, &calls, i, std::placeholders::_1));
743   }
744 
745   // Waits for the RunGraph calls.
746   call_opts->SetCancelCallback([&calls]() {
747     LOG(INFO) << "Client requested cancellation for RunStep, cancelling "
748                  "worker operations.";
749     calls.StartCancel();
750   });
751   auto token = cm->get_cancellation_token();
752   const bool success =
753       cm->RegisterCallback(token, [&calls]() { calls.StartCancel(); });
754   if (!success) {
755     calls.StartCancel();
756   }
757   calls.Wait();
758   call_opts->ClearCancelCallback();
759   if (success) {
760     cm->DeregisterCallback(token);
761   } else {
762     return errors::Cancelled("Step was cancelled");
763   }
764   TF_RETURN_IF_ERROR(calls.status());
765 
766   // Collects fetches and metadata.
767   Status status;
768   for (int i = 0; i < num; ++i) {
769     const Part& part = partitions_[i];
770     MutableRunGraphResponseWrapper* run_graph_resp = calls.get(i)->resp.get();
771     for (size_t j = 0; j < run_graph_resp->num_recvs(); ++j) {
772       auto iter = part.key_fetch.find(run_graph_resp->recv_key(j));
773       if (iter == part.key_fetch.end()) {
774         status.Update(errors::Internal("Unexpected fetch key: ",
775                                        run_graph_resp->recv_key(j)));
776         break;
777       }
778       const string& fetch = iter->second;
779       status.Update(
780           resp->AddTensorFromRunGraphResponse(fetch, run_graph_resp, j));
781       if (!status.ok()) {
782         break;
783       }
784     }
785     if (pss->collect_timeline) {
786       pss->step_stats[i].Swap(run_graph_resp->mutable_step_stats());
787     }
788     if (pss->collect_costs) {
789       CostGraphDef* cost_graph = run_graph_resp->mutable_cost_graph();
790       for (int j = 0; j < cost_graph->node_size(); ++j) {
791         resp->mutable_metadata()->mutable_cost_graph()->add_node()->Swap(
792             cost_graph->mutable_node(j));
793       }
794     }
795     if (pss->collect_partition_graphs) {
796       protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
797           resp->mutable_metadata()->mutable_partition_graphs();
798       for (size_t i = 0; i < run_graph_resp->num_partition_graphs(); i++) {
799         partition_graph_defs->Add()->Swap(
800             run_graph_resp->mutable_partition_graph(i));
801       }
802     }
803   }
804   return status;
805 }
806 
807 Status MasterSession::ReffedClientGraph::RunPartitions(
808     const MasterEnv* env, int64 step_id, int64 execution_count,
809     PerStepState* pss, CallOptions* call_opts, const RunStepRequestWrapper& req,
810     MutableRunStepResponseWrapper* resp, CancellationManager* cm,
811     const bool is_last_partial_run) {
812   VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
813           << execution_count;
814   // Maps the names of fed tensors to their index in `req`.
815   std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3);
816   for (size_t i = 0; i < req.num_feeds(); ++i) {
817     if (!feeds.insert({req.feed_name(i), i}).second) {
818       return errors::InvalidArgument("Duplicated feeds: ", req.feed_name(i));
819     }
820   }
821 
822   std::vector<string> fetches;
823   fetches.reserve(req.num_fetches());
824   for (size_t i = 0; i < req.num_fetches(); ++i) {
825     fetches.push_back(req.fetch_name(i));
826   }
827 
828   return RunPartitionsHelper(feeds, fetches, env, step_id, execution_count, pss,
829                              call_opts, req, resp, cm, is_last_partial_run);
830 }
831 
832 Status MasterSession::ReffedClientGraph::RunPartitions(
833     const MasterEnv* env, int64 step_id, int64 execution_count,
834     PerStepState* pss, CallOptions* call_opts, const RunCallableRequest& req,
835     RunCallableResponse* resp, CancellationManager* cm) {
836   VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
837           << execution_count;
838   // Maps the names of fed tensors to their index in `req`.
839   std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3);
840   for (size_t i = 0, end = callable_opts_.feed_size(); i < end; ++i) {
841     if (!feeds.insert({callable_opts_.feed(i), i}).second) {
842       // MakeCallable will fail if there are two feeds with the same name.
843       return errors::Internal("Duplicated feeds in callable: ",
844                               callable_opts_.feed(i));
845     }
846   }
847 
848   // Create a wrapped response object to collect the fetched values and
849   // rearrange them for the RunCallableResponse.
850   RunCallableResponseWrapper wrapped_resp;
851   wrapped_resp.resp = resp;
852 
853   TF_RETURN_IF_ERROR(RunPartitionsHelper(
854       feeds, callable_opts_.fetch(), env, step_id, execution_count, pss,
855       call_opts, req, &wrapped_resp, cm, false /* is_last_partial_run */));
856 
857   // Collects fetches.
858   // TODO(b/74355905): Add a specialized implementation that avoids
859   // copying the tensor into the RunCallableResponse when at least
860   // two of the {client, master, worker} are in the same process.
861   for (const string& fetch : callable_opts_.fetch()) {
862     TensorProto* fetch_proto = resp->mutable_fetch()->Add();
863     auto iter = wrapped_resp.fetch_key_to_protos.find(fetch);
864     if (iter == wrapped_resp.fetch_key_to_protos.end()) {
865       return errors::Internal("Worker did not return a value for fetch: ",
866                               fetch);
867     }
868     fetch_proto->Swap(&iter->second);
869   }
870   return Status::OK();
871 }
872 
873 namespace {
874 
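// Fans a single CleanupGraph request out to every partition's worker and
// invokes `done` once with the aggregated status after the last response
// arrives. The helper deletes itself inside call_done(), so it must be
// heap-allocated and never touched after the final call completes.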
875 class CleanupBroadcastHelper {
876  public:
877   CleanupBroadcastHelper(int64 step_id, int num_calls, StatusCallback done)
878       : resps_(num_calls), num_pending_(num_calls), done_(std::move(done)) {
879     req_.set_step_id(step_id);
880   }
881 
882   // Returns a non-owned pointer to a request buffer for all calls.
883   CleanupGraphRequest* request() { return &req_; }
884 
885   // Returns a non-owned pointer to a response buffer for the ith call.
886   CleanupGraphResponse* response(int i) { return &resps_[i]; }
887 
888   // Called when the ith response is received.
889   void call_done(int i, const Status& s) {
890     bool run_callback = false;
891     Status status_copy;
892     {
893       mutex_lock l(mu_);
894       status_.Update(s);
895       if (--num_pending_ == 0) {
896         run_callback = true;
897         status_copy = status_;
898       }
899     }
900     if (run_callback) {
901       done_(status_copy);
902       // This is the last call, so delete the helper object.
903       delete this;
904     }
905   }
906 
907  private:
908   // A single request shared between all workers.
909   CleanupGraphRequest req_;
910   // One response buffer for each worker.
911   gtl::InlinedVector<CleanupGraphResponse, 4> resps_;
912 
913   mutex mu_;
914   // Number of requests remaining to be collected.
915   int num_pending_ TF_GUARDED_BY(mu_);
916   // Aggregate status of the operation.
917   Status status_ TF_GUARDED_BY(mu_);
918   // Callback to be called when all operations complete.
919   StatusCallback done_;
920 
921   TF_DISALLOW_COPY_AND_ASSIGN(CleanupBroadcastHelper);
922 };
923 
924 }  // namespace
925 
926 void MasterSession::ReffedClientGraph::CleanupPartitionsAsync(
927     int64 step_id, StatusCallback done) {
928   const int num = partitions_.size();
929   // Helper object will be deleted when the final call completes.
930   CleanupBroadcastHelper* helper =
931       new CleanupBroadcastHelper(step_id, num, std::move(done));
932   for (int i = 0; i < num; ++i) {
933     const Part& part = partitions_[i];
934     part.worker->CleanupGraphAsync(
935         helper->request(), helper->response(i),
936         [helper, i](const Status& s) { helper->call_done(i, s); });
937   }
938 }
939 
940 void MasterSession::ReffedClientGraph::ProcessStats(int64 step_id,
941                                                     PerStepState* pss,
942                                                     ProfileHandler* ph,
943                                                     const RunOptions& options,
944                                                     RunMetadata* resp) {
945   if (!pss->collect_costs && !pss->collect_timeline) return;
946 
947   // Out-of-band logging data is collected now, during post-processing.
948   if (pss->collect_timeline) {
949     SetRPCLogging(false);
950     RetrieveLogs(step_id, &pss->rpc_stats);
951   }
952   for (size_t i = 0; i < partitions_.size(); ++i) {
953     const StepStats& ss = pss->step_stats[i];
954     if (ph) {
955       for (const auto& ds : ss.dev_stats()) {
956         ProcessDeviceStats(ph, ds, false /*is_rpc*/);
957       }
958     }
959   }
960   if (ph) {
961     for (const auto& ds : pss->rpc_stats.dev_stats()) {
962       ProcessDeviceStats(ph, ds, true /*is_rpc*/);
963     }
964     ph->StepDone(pss->start_micros, pss->end_micros,
965                  Microseconds(0) /*cleanup_time*/, 0 /*total_runops*/,
966                  Status::OK());
967   }
968   // Assemble all stats for this timeline into a merged StepStats.
969   if (pss->collect_timeline) {
970     StepStats step_stats_proto;
971     step_stats_proto.Swap(&pss->rpc_stats);
972     for (size_t i = 0; i < partitions_.size(); ++i) {
973       step_stats_proto.MergeFrom(pss->step_stats[i]);
974       pss->step_stats[i].Clear();
975     }
976     pss->step_stats.clear();
977     // Copy the stats back, but only for on-demand profiling to avoid slowing
978     // down calls that trigger the automatic profiling.
979     if (options.trace_level() == RunOptions::FULL_TRACE) {
980       resp->mutable_step_stats()->Swap(&step_stats_proto);
981     } else {
982       // If FULL_TRACE, it can be fetched from Session API, no need for
983       // duplicated publishing.
984       stats_publisher_->PublishStatsProto(step_stats_proto);
985     }
986   }
987 }
988 
989 void MasterSession::ReffedClientGraph::ProcessDeviceStats(
990     ProfileHandler* ph, const DeviceStepStats& ds, bool is_rpc) {
991   const string& dev_name = ds.device();
992   VLOG(1) << "Device " << dev_name << " reports stats for "
993           << ds.node_stats_size() << " nodes";
994   for (const auto& ns : ds.node_stats()) {
995     if (is_rpc) {
996       // We don't have access to a good Node pointer, so we rely on
997       // sufficient data being present in the NodeExecStats.
998       ph->RecordOneOp(dev_name, ns, true /*is_copy*/, "", ns.node_name(),
999                       ns.timeline_label());
1000     } else {
1001       auto iter = name_to_node_details_.find(ns.node_name());
1002       const bool found_node_in_graph = iter != name_to_node_details_.end();
1003       if (!found_node_in_graph && ns.timeline_label().empty()) {
1004         // The counter incrementing is not thread-safe. But we don't really
1005         // care.
1006         // TODO(zhengxq): we should implement a LOG_FIRST_N and LOG_EVERY_N for
1007         // more general usage.
1008         static int log_counter = 0;
1009         if (log_counter < 10) {
1010           log_counter++;
1011           LOG(WARNING) << "Failed to find node " << ns.node_name()
1012                        << " for dev " << dev_name;
1013         }
1014         continue;
1015       }
1016       const string& optype =
1017           found_node_in_graph ? iter->second.type_string : ns.node_name();
1018       string details;
1019       if (!ns.timeline_label().empty()) {
1020         details = ns.timeline_label();
1021       } else if (found_node_in_graph) {
1022         details = DetailText(iter->second, ns);
1023       } else {
1024         // Leave details string empty
1025       }
1026       ph->RecordOneOp(dev_name, ns, false /*is_copy*/, ns.node_name(), optype,
1027                       details);
1028     }
1029   }
1030 }
1031 
1032 // TODO(suharshs): Merge with CheckFetches in DirectSession.
1033 // TODO(suharsh,mrry): Build a map from fetch target to set of feeds it depends
1034 // on once at setup time to prevent us from computing the dependencies
1035 // every time.
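// The check below works as a reverse reachability test: collect the feeds
// that are still pending (declared for the partial run but not yet fed, and
// not fed by this request), then walk backwards from the requested fetch
// nodes over the full graph; if any fetch transitively depends on a pending
// feed, the request is rejected with InvalidArgument.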
1036 Status MasterSession::ReffedClientGraph::CheckFetches(
1037     const RunStepRequestWrapper& req, const RunState* run_state,
1038     GraphExecutionState* execution_state) {
1039   // Build the set of pending feeds that we haven't seen.
1040   std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
1041   for (const auto& input : run_state->pending_inputs) {
1042     // Skip if already fed.
1043     if (input.second) continue;
1044     TensorId id(ParseTensorName(input.first));
1045     const Node* n = execution_state->get_node_by_name(string(id.first));
1046     if (n == nullptr) {
1047       return errors::NotFound("Feed ", input.first, ": not found");
1048     }
1049     pending_feeds.insert(id);
1050   }
1051   for (size_t i = 0; i < req.num_feeds(); ++i) {
1052     const TensorId id(ParseTensorName(req.feed_name(i)));
1053     pending_feeds.erase(id);
1054   }
1055 
1056   // Initialize the stack with the fetch nodes.
1057   std::vector<const Node*> stack;
1058   for (size_t i = 0; i < req.num_fetches(); ++i) {
1059     const string& fetch = req.fetch_name(i);
1060     const TensorId id(ParseTensorName(fetch));
1061     const Node* n = execution_state->get_node_by_name(string(id.first));
1062     if (n == nullptr) {
1063       return errors::NotFound("Fetch ", fetch, ": not found");
1064     }
1065     stack.push_back(n);
1066   }
1067 
1068   // Any tensor needed for fetches can't be in pending_feeds.
1069   // We need to use the original full graph from execution state.
1070   const Graph* graph = execution_state->full_graph();
1071   std::vector<bool> visited(graph->num_node_ids(), false);
1072   while (!stack.empty()) {
1073     const Node* n = stack.back();
1074     stack.pop_back();
1075 
1076     for (const Edge* in_edge : n->in_edges()) {
1077       const Node* in_node = in_edge->src();
1078       if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) {
1079         return errors::InvalidArgument("Fetch ", in_node->name(), ":",
1080                                        in_edge->src_output(),
1081                                        " can't be computed from the feeds"
1082                                        " that have been fed so far.");
1083       }
1084       if (!visited[in_node->id()]) {
1085         visited[in_node->id()] = true;
1086         stack.push_back(in_node);
1087       }
1088     }
1089   }
1090   return Status::OK();
1091 }
1092 
1093 // Asynchronously deregisters subgraphs on the workers, without waiting for the
1094 // result.
1095 void MasterSession::ReffedClientGraph::DeregisterPartitions() {
1096   struct Call {
1097     DeregisterGraphRequest req;
1098     DeregisterGraphResponse resp;
1099   };
1100   for (Part& part : partitions_) {
1101     // The graph handle may be empty if we failed during partition registration.
1102     if (!part.graph_handle.empty()) {
1103       Call* c = new Call;
1104       c->req.set_session_handle(session_handle_);
1105       c->req.set_create_worker_session_called(!should_deregister_);
1106       c->req.set_graph_handle(part.graph_handle);
1107       // NOTE(mrry): We must capture `worker_cache_` since `this`
1108       // could be deleted before the callback is called.
1109       WorkerCacheInterface* worker_cache = worker_cache_;
1110       const string name = part.name;
1111       WorkerInterface* w = part.worker;
1112       CHECK_NOTNULL(w);
1113       auto cb = [worker_cache, c, name, w](const Status& s) {
1114         if (!s.ok()) {
1115           // This error is potentially benign, so we don't log at the
1116           // error level.
1117           LOG(INFO) << "DeregisterGraph error: " << s;
1118         }
1119         delete c;
1120         worker_cache->ReleaseWorker(name, w);
1121       };
1122       w->DeregisterGraphAsync(&c->req, &c->resp, cb);
1123     }
1124   }
1125 }
1126 
1127 namespace {
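// Copies the names into the output field and sorts them so that the resulting
// CallableOptions (and hence HashBuildGraphOptions() below) do not depend on
// the order in which the client listed its feeds, fetches, and targets.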
1128 void CopyAndSortStrings(size_t size,
1129                         const std::function<string(size_t)>& input_accessor,
1130                         protobuf::RepeatedPtrField<string>* output) {
1131   std::vector<string> temp;
1132   temp.reserve(size);
1133   for (size_t i = 0; i < size; ++i) {
1134     output->Add(input_accessor(i));
1135   }
1136   std::sort(output->begin(), output->end());
1137 }
1138 }  // namespace
1139 
1140 void BuildBuildGraphOptions(const RunStepRequestWrapper& req,
1141                             const ConfigProto& config,
1142                             BuildGraphOptions* opts) {
1143   CallableOptions* callable_opts = &opts->callable_options;
1144   CopyAndSortStrings(
1145       req.num_feeds(), [&req](size_t i) { return req.feed_name(i); },
1146       callable_opts->mutable_feed());
1147   CopyAndSortStrings(
1148       req.num_fetches(), [&req](size_t i) { return req.fetch_name(i); },
1149       callable_opts->mutable_fetch());
1150   CopyAndSortStrings(
1151       req.num_targets(), [&req](size_t i) { return req.target_name(i); },
1152       callable_opts->mutable_target());
1153 
1154   if (!req.options().debug_options().debug_tensor_watch_opts().empty()) {
1155     *callable_opts->mutable_run_options()->mutable_debug_options() =
1156         req.options().debug_options();
1157   }
1158 
1159   opts->collective_graph_key =
1160       req.options().experimental().collective_graph_key();
1161   if (config.experimental().collective_deterministic_sequential_execution()) {
1162     opts->collective_order = GraphCollectiveOrder::kEdges;
1163   } else if (config.experimental().collective_nccl()) {
1164     opts->collective_order = GraphCollectiveOrder::kAttrs;
1165   }
1166 }
1167 
1168 void BuildBuildGraphOptions(const PartialRunSetupRequest& req,
1169                             BuildGraphOptions* opts) {
1170   CallableOptions* callable_opts = &opts->callable_options;
1171   CopyAndSortStrings(
1172       req.feed_size(), [&req](size_t i) { return req.feed(i); },
1173       callable_opts->mutable_feed());
1174   CopyAndSortStrings(
1175       req.fetch_size(), [&req](size_t i) { return req.fetch(i); },
1176       callable_opts->mutable_fetch());
1177   CopyAndSortStrings(
1178       req.target_size(), [&req](size_t i) { return req.target(i); },
1179       callable_opts->mutable_target());
1180 
1181   // TODO(cais): Add TFDBG support to partial runs.
1182 }
1183 
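// Produces a stable fingerprint of the feed/target/fetch names (plus any
// debug tensor watches) in a BuildGraphOptions. Equivalent Run requests hash
// to the same value, which is presumably how they come to share a cached
// compiled graph (see the run_graphs_ / partial_run_graphs_ caches).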
1184 uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) {
1185   uint64 h = 0x2b992ddfa23249d6ull;
1186   for (const string& name : opts.callable_options.feed()) {
1187     h = Hash64(name.c_str(), name.size(), h);
1188   }
1189   for (const string& name : opts.callable_options.target()) {
1190     h = Hash64(name.c_str(), name.size(), h);
1191   }
1192   for (const string& name : opts.callable_options.fetch()) {
1193     h = Hash64(name.c_str(), name.size(), h);
1194   }
1195 
1196   const DebugOptions& debug_options =
1197       opts.callable_options.run_options().debug_options();
1198   if (!debug_options.debug_tensor_watch_opts().empty()) {
1199     const string watch_summary =
1200         SummarizeDebugTensorWatches(debug_options.debug_tensor_watch_opts());
1201     h = Hash64(watch_summary.c_str(), watch_summary.size(), h);
1202   }
1203 
1204   return h;
1205 }
1206 
1207 string BuildGraphOptionsString(const BuildGraphOptions& opts) {
1208   string buf;
1209   for (const string& name : opts.callable_options.feed()) {
1210     strings::StrAppend(&buf, " FdE: ", name);
1211   }
1212   strings::StrAppend(&buf, "\n");
1213   for (const string& name : opts.callable_options.target()) {
1214     strings::StrAppend(&buf, " TN: ", name);
1215   }
1216   strings::StrAppend(&buf, "\n");
1217   for (const string& name : opts.callable_options.fetch()) {
1218     strings::StrAppend(&buf, " FeE: ", name);
1219   }
1220   if (opts.collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey) {
1221     strings::StrAppend(&buf, "\nGK: ", opts.collective_graph_key);
1222   }
1223   strings::StrAppend(&buf, "\n");
1224   return buf;
1225 }
1226 
1227 MasterSession::MasterSession(
1228     const SessionOptions& opt, const MasterEnv* env,
1229     std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
1230     std::unique_ptr<WorkerCacheInterface> worker_cache,
1231     std::unique_ptr<DeviceSet> device_set,
1232     std::vector<string> filtered_worker_list,
1233     StatsPublisherFactory stats_publisher_factory)
1234     : session_opts_(opt),
1235       env_(env),
1236       handle_(strings::FpToString(random::New64())),
1237       remote_devs_(std::move(remote_devs)),
1238       worker_cache_(std::move(worker_cache)),
1239       devices_(std::move(device_set)),
1240       filtered_worker_list_(std::move(filtered_worker_list)),
1241       stats_publisher_factory_(std::move(stats_publisher_factory)),
1242       graph_version_(0),
1243       run_graphs_(5),
1244       partial_run_graphs_(5) {
1245   UpdateLastAccessTime();
1246   CHECK(devices_) << "device_set was null!";
1247 
1248   VLOG(1) << "Session " << handle_ << " #local " << env->local_devices.size()
1249           << " #remote " << remote_devs_->size();
1250   VLOG(1) << "Start master session " << handle_
1251           << " with config: " << session_opts_.config.ShortDebugString();
1252 }
1253 
1254 MasterSession::~MasterSession() {
1255   for (const auto& iter : run_graphs_) iter.second->Unref();
1256   for (const auto& iter : partial_run_graphs_) iter.second->Unref();
1257 }
1258 
1259 void MasterSession::UpdateLastAccessTime() {
1260   last_access_time_usec_.store(Env::Default()->NowMicros());
1261 }
1262 
1263 Status MasterSession::Create(GraphDef&& graph_def,
1264                              const WorkerCacheFactoryOptions& options) {
1265   if (session_opts_.config.use_per_session_threads() ||
1266       session_opts_.config.session_inter_op_thread_pool_size() > 0) {
1267     return errors::InvalidArgument(
1268         "Distributed session does not support session thread pool options.");
1269   }
1270   if (session_opts_.config.graph_options().place_pruned_graph()) {
1271     // TODO(b/29900832): Fix this or remove the option.
1272     LOG(WARNING) << "Distributed session does not support the "
1273                     "place_pruned_graph option.";
1274     session_opts_.config.mutable_graph_options()->set_place_pruned_graph(false);
1275   }
1276 
1277   GraphExecutionStateOptions execution_options;
1278   execution_options.device_set = devices_.get();
1279   execution_options.session_options = &session_opts_;
1280   {
1281     mutex_lock l(mu_);
1282     TF_RETURN_IF_ERROR(GraphExecutionState::MakeForBaseGraph(
1283         std::move(graph_def), execution_options, &execution_state_));
1284   }
1285   should_delete_worker_sessions_ = true;
1286   return CreateWorkerSessions(options);
1287 }
1288 
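// A sketch of the flow implemented below: build one CreateWorkerSessionRequest
// per worker in filtered_worker_list_ (optionally embedding the cluster device
// attributes and a ServerDef when ClusterSpec propagation is in use), issue
// the RPCs asynchronously via CreateWorkerSessionAsync, and block on a
// BlockingCounter until every response arrives, keeping the first non-OK
// status.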
1289 Status MasterSession::CreateWorkerSessions(
1290     const WorkerCacheFactoryOptions& options) {
1291   const std::vector<string> worker_names = filtered_worker_list_;
1292   WorkerCacheInterface* worker_cache = get_worker_cache();
1293 
1294   struct WorkerGroup {
1295     // The worker name. (Not owned.)
1296     const string* name;
1297 
1298     // The worker referenced by name. (Not owned.)
1299     WorkerInterface* worker = nullptr;
1300 
1301     // Request and response used for a given worker.
1302     CreateWorkerSessionRequest request;
1303     CreateWorkerSessionResponse response;
1304     Status status = Status::OK();
1305   };
1306   BlockingCounter done(worker_names.size());
1307   std::vector<WorkerGroup> workers(worker_names.size());
1308 
1309   // Release the workers.
1310   auto cleanup = gtl::MakeCleanup([&workers, worker_cache] {
1311     for (auto&& worker_group : workers) {
1312       if (worker_group.worker != nullptr) {
1313         worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker);
1314       }
1315     }
1316   });
1317 
1318   string task_name;
1319   string local_device_name;
1320   DeviceNameUtils::SplitDeviceName(devices_->client_device()->name(),
1321                                    &task_name, &local_device_name);
1322   const int64 client_device_incarnation =
1323       devices_->client_device()->attributes().incarnation();
1324 
1325   Status status = Status::OK();
1326   // Get handles to all the workers and build their CreateWorkerSession requests.
1327   for (size_t i = 0; i < worker_names.size(); ++i) {
1328     workers[i].name = &worker_names[i];
1329     workers[i].worker = worker_cache->GetOrCreateWorker(worker_names[i]);
1330     workers[i].request.set_session_handle(handle_);
1331     workers[i].request.set_master_task(task_name);
1332     workers[i].request.set_master_incarnation(client_device_incarnation);
1333     if (session_opts_.config.share_cluster_devices_in_session() ||
1334         session_opts_.config.experimental()
1335             .share_cluster_devices_in_session()) {
1336       for (const auto& remote_dev : devices_->devices()) {
1337         *workers[i].request.add_cluster_device_attributes() =
1338             remote_dev->attributes();
1339       }
1340 
1341       if (!session_opts_.config.share_cluster_devices_in_session() &&
1342           session_opts_.config.experimental()
1343               .share_cluster_devices_in_session()) {
1344         LOG(WARNING)
1345             << "ConfigProto.Experimental.share_cluster_devices_in_session has "
1346                "been promoted to a non-experimental API. Please use "
1347                "ConfigProto.share_cluster_devices_in_session instead. The "
1348                "experimental option will be removed in the future.";
1349       }
1350     }
1351 
1352     DeviceNameUtils::ParsedName name;
1353     if (!DeviceNameUtils::ParseFullName(worker_names[i], &name)) {
1354       status = errors::Internal("Could not parse name ", worker_names[i]);
1355       LOG(WARNING) << status;
1356       return status;
1357     }
1358     if (!name.has_job || !name.has_task) {
1359       status = errors::Internal("Incomplete worker name ", worker_names[i]);
1360       LOG(WARNING) << status;
1361       return status;
1362     }
1363 
1364     if (options.cluster_def) {
1365       *workers[i].request.mutable_server_def()->mutable_cluster() =
1366           *options.cluster_def;
1367       workers[i].request.mutable_server_def()->set_protocol(*options.protocol);
1368       workers[i].request.mutable_server_def()->set_job_name(name.job);
1369       workers[i].request.mutable_server_def()->set_task_index(name.task);
1370       // Session state is always isolated when ClusterSpec propagation
1371       // is in use.
1372       workers[i].request.set_isolate_session_state(true);
1373     } else {
1374       // NOTE(mrry): Do not set any component of the ServerDef,
1375       // because the worker will use its local configuration.
1376       workers[i].request.set_isolate_session_state(
1377           session_opts_.config.isolate_session_state());
1378     }
1379     if (session_opts_.config.experimental()
1380             .share_session_state_in_clusterspec_propagation()) {
1381       // In a dynamic cluster, the ClusterSpec info is usually propagated by
1382       // master sessions. However, in data parallel training with multiple
1383       // masters ("between-graph replication"), we need to disable isolation
1384       // so that different worker sessions can update the same variables in
1385       // PS tasks.
1386       workers[i].request.set_isolate_session_state(false);
1387     }
1388   }
1389 
1390   for (size_t i = 0; i < worker_names.size(); ++i) {
1391     auto cb = [i, &workers, &done](const Status& s) {
1392       workers[i].status = s;
1393       done.DecrementCount();
1394     };
1395     workers[i].worker->CreateWorkerSessionAsync(&workers[i].request,
1396                                                 &workers[i].response, cb);
1397   }
1398 
1399   done.Wait();
1400   for (size_t i = 0; i < workers.size(); ++i) {
1401     status.Update(workers[i].status);
1402   }
1403   return status;
1404 }
1405 
1406 Status MasterSession::DeleteWorkerSessions() {
1407   WorkerCacheInterface* worker_cache = get_worker_cache();
1408   const std::vector<string>& worker_names = filtered_worker_list_;
1409 
1410   struct WorkerGroup {
1411     // The worker name. (Not owned.)
1412     const string* name;
1413 
1414     // The worker referenced by name. (Not owned.)
1415     WorkerInterface* worker = nullptr;
1416 
1417     CallOptions call_opts;
1418 
1419     // Request and response used for a given worker.
1420     DeleteWorkerSessionRequest request;
1421     DeleteWorkerSessionResponse response;
1422     Status status = Status::OK();
1423   };
1424   BlockingCounter done(worker_names.size());
1425   std::vector<WorkerGroup> workers(worker_names.size());
1426 
1427   // Release the workers.
1428   auto cleanup = gtl::MakeCleanup([&workers, worker_cache] {
1429     for (auto&& worker_group : workers) {
1430       if (worker_group.worker != nullptr) {
1431         worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker);
1432       }
1433     }
1434   });
1435 
1436   Status status = Status::OK();
1437   // Get handles to all the workers and build their DeleteWorkerSession requests.
1438   for (size_t i = 0; i < worker_names.size(); ++i) {
1439     workers[i].name = &worker_names[i];
1440     workers[i].worker = worker_cache->GetOrCreateWorker(worker_names[i]);
1441     workers[i].request.set_session_handle(handle_);
1442     // Since the worker may have gone away, set a timeout to avoid blocking the
1443     // session-close operation.
1444     workers[i].call_opts.SetTimeout(10000);
1445   }
1446 
1447   for (size_t i = 0; i < worker_names.size(); ++i) {
1448     auto cb = [i, &workers, &done](const Status& s) {
1449       workers[i].status = s;
1450       done.DecrementCount();
1451     };
1452     workers[i].worker->DeleteWorkerSessionAsync(
1453         &workers[i].call_opts, &workers[i].request, &workers[i].response, cb);
1454   }
1455 
1456   done.Wait();
1457   for (size_t i = 0; i < workers.size(); ++i) {
1458     status.Update(workers[i].status);
1459   }
1460   return status;
1461 }
1462 
1463 Status MasterSession::ListDevices(ListDevicesResponse* resp) const {
1464   if (worker_cache_) {
1465     // This is a ClusterSpec-propagated session, and thus env_->local_devices
1466     // are invalid.
1467 
1468     // Mark the "client_device" as the sole local device.
1469     const Device* client_device = devices_->client_device();
1470     for (const Device* dev : devices_->devices()) {
1471       if (dev != client_device) {
1472         *(resp->add_remote_device()) = dev->attributes();
1473       }
1474     }
1475     *(resp->add_local_device()) = client_device->attributes();
1476   } else {
1477     for (Device* dev : env_->local_devices) {
1478       *(resp->add_local_device()) = dev->attributes();
1479     }
1480     for (auto&& dev : *remote_devs_) {
1481       *(resp->add_local_device()) = dev->attributes();
1482     }
1483   }
1484   return Status::OK();
1485 }
1486 
1487 Status MasterSession::Extend(const ExtendSessionRequest* req,
1488                              ExtendSessionResponse* resp) {
1489   UpdateLastAccessTime();
1490   std::unique_ptr<GraphExecutionState> extended_execution_state;
1491   {
1492     mutex_lock l(mu_);
1493     if (closed_) {
1494       return errors::FailedPrecondition("Session is closed.");
1495     }
1496 
1497     if (graph_version_ != req->current_graph_version()) {
1498       return errors::Aborted("Current version is ", graph_version_,
1499                              " but caller expected ",
1500                              req->current_graph_version(), ".");
1501     }
1502 
1503     CHECK(execution_state_);
1504     TF_RETURN_IF_ERROR(
1505         execution_state_->Extend(req->graph_def(), &extended_execution_state));
1506 
1507     CHECK(extended_execution_state);
1508     // The old execution state will be released outside the lock.
1509     execution_state_.swap(extended_execution_state);
1510     ++graph_version_;
1511     resp->set_new_graph_version(graph_version_);
1512   }
1513   return Status::OK();
1514 }
1515 
1516 WorkerCacheInterface* MasterSession::get_worker_cache() const {
1517   if (worker_cache_) {
1518     return worker_cache_.get();
1519   }
1520   return env_->worker_cache;
1521 }
1522 
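// Looks up (or builds and caches) the compiled client graph for `opts`. The
// cache key is HashBuildGraphOptions(opts), partial-run graphs live in a
// separate table from regular run graphs, and the returned ReffedClientGraph
// carries an extra reference that the caller must release.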
1523 Status MasterSession::StartStep(const BuildGraphOptions& opts, bool is_partial,
1524                                 ReffedClientGraph** out_rcg, int64* out_count) {
1525   const uint64 hash = HashBuildGraphOptions(opts);
1526   {
1527     mutex_lock l(mu_);
1528     // TODO(suharshs): We cache partial run graphs and run graphs separately
1529     // because there is preprocessing that needs to only be run for partial
1530     // run calls.
1531     RCGMap* m = is_partial ? &partial_run_graphs_ : &run_graphs_;
1532     auto iter = m->find(hash);
1533     if (iter == m->end()) {
1534       // We have not seen this subgraph before. Build the subgraph and
1535       // cache it.
1536       VLOG(1) << "Unseen hash " << hash << " for "
1537               << BuildGraphOptionsString(opts) << " is_partial = " << is_partial
1538               << "\n";
1539       std::unique_ptr<ClientGraph> client_graph;
1540       TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
1541       WorkerCacheInterface* worker_cache = get_worker_cache();
1542       auto entry = new ReffedClientGraph(
1543           handle_, opts, std::move(client_graph), session_opts_,
1544           stats_publisher_factory_, is_partial, worker_cache,
1545           !should_delete_worker_sessions_);
1546       iter = m->insert({hash, entry}).first;
1547       VLOG(1) << "Preparing to execute new graph";
1548     }
1549     *out_rcg = iter->second;
1550     (*out_rcg)->Ref();
1551     *out_count = (*out_rcg)->get_and_increment_execution_count();
1552   }
1553   return Status::OK();
1554 }
1555 
1556 void MasterSession::ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
1557                                    RCGMap* rcg_map) {
1558   VLOG(1) << "Discarding all reffed graphs";
1559   for (auto p : *rcg_map) {
1560     ReffedClientGraph* rcg = p.second;
1561     if (to_unref) {
1562       to_unref->push_back(rcg);
1563     } else {
1564       rcg->Unref();
1565     }
1566   }
1567   rcg_map->clear();
1568 }
1569 
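// Picks a step id. For non-collective steps the id is random::New64() masked
// with (((1uLL << 56) - 1) | (1uLL << 56)) == 0x01FFFFFFFFFFFFFF, i.e. only
// the low 57 bits may be set and the most-significant 7 bits stay clear. For
// collective steps the id comes from the collective executor manager's
// per-graph-key sequence, retrying with a linearly growing backoff (capped at
// 60 seconds) until a valid id is obtained.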
1570 uint64 MasterSession::NewStepId(int64 graph_key) {
1571   if (graph_key == BuildGraphOptions::kNoCollectiveGraphKey) {
1572     // StepId must leave the most-significant 7 bits empty for future use.
1573     return random::New64() & (((1uLL << 56) - 1) | (1uLL << 56));
1574   } else {
1575     uint64 step_id = env_->collective_executor_mgr->NextStepId(graph_key);
1576     int32 retry_count = 0;
1577     while (static_cast<int64>(step_id) == CollectiveExecutor::kInvalidId) {
1578       Notification note;
1579       Status status;
1580       env_->collective_executor_mgr->RefreshStepIdSequenceAsync(
1581           graph_key, [&status, &note](const Status& s) {
1582             status = s;
1583             note.Notify();
1584           });
1585       note.WaitForNotification();
1586       if (!status.ok()) {
1587         LOG(ERROR) << "Bad status from "
1588                       "collective_executor_mgr->RefreshStepIdSequence: "
1589                    << status << ".  Retrying.";
1590         int64 delay_micros = std::min(60000000LL, 1000000LL * ++retry_count);
1591         Env::Default()->SleepForMicroseconds(delay_micros);
1592       } else {
1593         step_id = env_->collective_executor_mgr->NextStepId(graph_key);
1594       }
1595     }
1596     return step_id;
1597   }
1598 }
1599 
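// Sets up a partial run: the handle is a monotonically increasing counter
// rendered as a string, and the associated RunState records every feed and
// fetch as "pending" until later DoPartialRun calls supply them.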
1600 Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req,
1601                                       PartialRunSetupResponse* resp) {
1602   std::vector<string> inputs, outputs, targets;
1603   for (const auto& feed : req->feed()) {
1604     inputs.push_back(feed);
1605   }
1606   for (const auto& fetch : req->fetch()) {
1607     outputs.push_back(fetch);
1608   }
1609   for (const auto& target : req->target()) {
1610     targets.push_back(target);
1611   }
1612 
1613   string handle = std::to_string(partial_run_handle_counter_.fetch_add(1));
1614 
1615   ReffedClientGraph* rcg = nullptr;
1616 
1617   // Prepare.
1618   BuildGraphOptions opts;
1619   BuildBuildGraphOptions(*req, &opts);
1620   int64 count = 0;
1621   TF_RETURN_IF_ERROR(StartStep(opts, true, &rcg, &count));
1622 
1623   rcg->Ref();
1624   RunState* run_state =
1625       new RunState(inputs, outputs, rcg,
1626                    NewStepId(BuildGraphOptions::kNoCollectiveGraphKey), count);
1627   {
1628     mutex_lock l(mu_);
1629     partial_runs_.emplace(
1630         std::make_pair(handle, std::unique_ptr<RunState>(run_state)));
1631   }
1632 
1633   TF_RETURN_IF_ERROR(BuildAndRegisterPartitions(rcg));
1634 
1635   resp->set_partial_run_handle(handle);
1636   return Status::OK();
1637 }
1638 
1639 Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req,
1640                           MutableRunStepResponseWrapper* resp) {
1641   UpdateLastAccessTime();
1642   {
1643     mutex_lock l(mu_);
1644     if (closed_) {
1645       return errors::FailedPrecondition("Session is closed.");
1646     }
1647     ++num_running_;
1648     // Note: all code paths must eventually call MarkRunCompletion()
1649     // in order to appropriately decrement the num_running_ counter.
1650   }
1651   Status status;
1652   if (!req.partial_run_handle().empty()) {
1653     status = DoPartialRun(opts, req, resp);
1654   } else {
1655     status = DoRunWithLocalExecution(opts, req, resp);
1656   }
1657   return status;
1658 }
1659 
1660 // Decrements num_running_ and broadcasts if num_running_ is zero.
1661 void MasterSession::MarkRunCompletion() {
1662   mutex_lock l(mu_);
1663   --num_running_;
1664   if (num_running_ == 0) {
1665     num_running_is_zero_.notify_all();
1666   }
1667 }
1668 
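// Partitions the client graph by worker (SplitByWorker) and registers the
// resulting partitions on the corresponding workers. When
// enable_bfloat16_sendrecv is set, DT_FLOAT tensors crossing a partition
// boundary are cast to DT_BFLOAT16 for the Send/Recv pair; control edges are
// reported as DT_FLOAT by should_cast but carry no tensor data.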
1669 Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
1670   // Registers subgraphs if we haven't done so already.
1671   PartitionOptions popts;
1672   popts.node_to_loc = SplitByWorker;
1673   // The closures popts.{new_name,get_incarnation} are called synchronously in
1674   // RegisterPartitions() below, so do not need a Ref()/Unref() pair to keep
1675   // "this" alive during the closure.
1676   popts.new_name = [this](const string& prefix) {
1677     mutex_lock l(mu_);
1678     return strings::StrCat(prefix, "_S", next_node_id_++);
1679   };
1680   popts.get_incarnation = [this](const string& name) -> int64 {
1681     Device* d = devices_->FindDeviceByName(name);
1682     if (d == nullptr) {
1683       return PartitionOptions::kIllegalIncarnation;
1684     } else {
1685       return d->attributes().incarnation();
1686     }
1687   };
1688   popts.control_flow_added = false;
1689   const bool enable_bfloat16_sendrecv =
1690       session_opts_.config.graph_options().enable_bfloat16_sendrecv();
1691   popts.should_cast = [enable_bfloat16_sendrecv](const Edge* e) {
1692     if (e->IsControlEdge()) {
1693       return DT_FLOAT;
1694     }
1695     DataType dtype = BaseType(e->src()->output_type(e->src_output()));
1696     if (enable_bfloat16_sendrecv && dtype == DT_FLOAT) {
1697       return DT_BFLOAT16;
1698     } else {
1699       return dtype;
1700     }
1701   };
1702   if (session_opts_.config.graph_options().enable_recv_scheduling()) {
1703     popts.scheduling_for_recvs = true;
1704     popts.need_to_record_start_times = true;
1705   }
1706 
1707   TF_RETURN_IF_ERROR(rcg->RegisterPartitions(std::move(popts)));
1708 
1709   return Status::OK();
1710 }
1711 
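// Executes one step of a partial run. The first call for a handle initializes
// the PerStepState; every call marks the supplied feeds and requested fetches
// as no-longer-pending, and once all of them have been satisfied
// (RunState::PendingDone()) the final RunPartitions call is issued with
// is_last_partial_run = true and the run state is erased.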
1712 Status MasterSession::DoPartialRun(CallOptions* opts,
1713                                    const RunStepRequestWrapper& req,
1714                                    MutableRunStepResponseWrapper* resp) {
1715   auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
1716   const string& prun_handle = req.partial_run_handle();
1717   RunState* run_state = nullptr;
1718   {
1719     mutex_lock l(mu_);
1720     auto it = partial_runs_.find(prun_handle);
1721     if (it == partial_runs_.end()) {
1722       return errors::InvalidArgument(
1723           "Must run PartialRunSetup before performing partial runs");
1724     }
1725     run_state = it->second.get();
1726   }
1727   // CollectiveOps are not supported in partial runs.
1728   if (req.options().experimental().collective_graph_key() !=
1729       BuildGraphOptions::kNoCollectiveGraphKey) {
1730     return errors::InvalidArgument(
1731         "PartialRun does not support Collective ops.  collective_graph_key "
1732         "must be kNoCollectiveGraphKey.");
1733   }
1734 
1735   // If this is the first partial run, initialize the PerStepState.
1736   if (!run_state->step_started) {
1737     run_state->step_started = true;
1738     PerStepState pss;
1739 
1740     const auto count = run_state->count;
1741     pss.collect_timeline =
1742         req.options().trace_level() == RunOptions::FULL_TRACE;
1743     pss.collect_rpcs = req.options().trace_level() == RunOptions::FULL_TRACE;
1744     pss.report_tensor_allocations_upon_oom =
1745         req.options().report_tensor_allocations_upon_oom();
1746 
1747     // Build the cost model every 'build_cost_model_every' steps after
1748     // skipping an initial 'build_cost_model_after' steps.
1750     const int64 build_cost_model_after =
1751         session_opts_.config.graph_options().build_cost_model_after();
1752     const int64 build_cost_model_every =
1753         session_opts_.config.graph_options().build_cost_model();
1754     pss.collect_costs =
1755         build_cost_model_every > 0 &&
1756         ((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
1757     pss.collect_partition_graphs = req.options().output_partition_graphs();
1758 
1759     std::unique_ptr<ProfileHandler> ph = run_state->rcg->GetProfileHandler(
1760         run_state->step_id, count, req.options());
1761     if (ph) {
1762       pss.collect_timeline = true;
1763       pss.collect_rpcs = ph->should_collect_rpcs();
1764     }
1765 
1766     run_state->pss = std::move(pss);
1767     run_state->ph = std::move(ph);
1768   }
1769 
1770   // Make sure that this is a new set of feeds that are still pending.
1771   for (size_t i = 0; i < req.num_feeds(); ++i) {
1772     const string& feed = req.feed_name(i);
1773     auto it = run_state->pending_inputs.find(feed);
1774     if (it == run_state->pending_inputs.end()) {
1775       return errors::InvalidArgument(
1776           "The feed ", feed, " was not specified in partial_run_setup.");
1777     } else if (it->second) {
1778       return errors::InvalidArgument("The feed ", feed,
1779                                      " has already been fed.");
1780     }
1781   }
1782   // Check that this is a new set of fetches that are still pending.
1783   for (size_t i = 0; i < req.num_fetches(); ++i) {
1784     const string& fetch = req.fetch_name(i);
1785     auto it = run_state->pending_outputs.find(fetch);
1786     if (it == run_state->pending_outputs.end()) {
1787       return errors::InvalidArgument(
1788           "The fetch ", fetch, " was not specified in partial_run_setup.");
1789     } else if (it->second) {
1790       return errors::InvalidArgument("The fetch ", fetch,
1791                                      " has already been fetched.");
1792     }
1793   }
1794 
1795   // Ensure that the requested fetches can be computed from the provided feeds.
1796   {
1797     mutex_lock l(mu_);
1798     TF_RETURN_IF_ERROR(
1799         run_state->rcg->CheckFetches(req, run_state, execution_state_.get()));
1800   }
1801 
1802   // Determine if this partial run satisfies all the pending inputs and outputs.
1803   for (size_t i = 0; i < req.num_feeds(); ++i) {
1804     auto it = run_state->pending_inputs.find(req.feed_name(i));
1805     it->second = true;
1806   }
1807   for (size_t i = 0; i < req.num_fetches(); ++i) {
1808     auto it = run_state->pending_outputs.find(req.fetch_name(i));
1809     it->second = true;
1810   }
1811   bool is_last_partial_run = run_state->PendingDone();
1812 
1813   Status s = run_state->rcg->RunPartitions(
1814       env_, run_state->step_id, run_state->count, &run_state->pss, opts, req,
1815       resp, &cancellation_manager_, is_last_partial_run);
1816 
1817   // Delete the run state if there is an error or all fetches are done.
1818   if (!s.ok() || is_last_partial_run) {
1819     ReffedClientGraph* rcg = run_state->rcg;
1820     run_state->pss.end_micros = Env::Default()->NowMicros();
1821     // Schedule post-processing and cleanup to be done asynchronously.
1822     Ref();
1823     rcg->Ref();
1824     rcg->ProcessStats(run_state->step_id, &run_state->pss, run_state->ph.get(),
1825                       req.options(), resp->mutable_metadata());
1826     cleanup.release();  // MarkRunCompletion called in done closure.
1827     rcg->CleanupPartitionsAsync(
1828         run_state->step_id, [this, rcg, prun_handle](const Status& s) {
1829           if (!s.ok()) {
1830             LOG(ERROR) << "Cleanup partition error: " << s;
1831           }
1832           rcg->Unref();
1833           MarkRunCompletion();
1834           Unref();
1835         });
1836     mutex_lock l(mu_);
1837     partial_runs_.erase(prun_handle);
1838   }
1839   return s;
1840 }
1841 
1842 Status MasterSession::CreateDebuggerState(
1843     const DebugOptions& debug_options, const RunStepRequestWrapper& req,
1844     int64 rcg_execution_count,
1845     std::unique_ptr<DebuggerStateInterface>* debugger_state) {
1846   TF_RETURN_IF_ERROR(
1847       DebuggerStateRegistry::CreateState(debug_options, debugger_state));
1848 
1849   std::vector<string> input_names;
1850   for (size_t i = 0; i < req.num_feeds(); ++i) {
1851     input_names.push_back(req.feed_name(i));
1852   }
1853   std::vector<string> output_names;
1854   for (size_t i = 0; i < req.num_fetches(); ++i) {
1855     output_names.push_back(req.fetch_name(i));
1856   }
1857   std::vector<string> target_names;
1858   for (size_t i = 0; i < req.num_targets(); ++i) {
1859     target_names.push_back(req.target_name(i));
1860   }
1861 
1862   // TODO(cais): We currently use -1 as a dummy value for session run count.
1863   // While this counter value is straightforward to define and obtain for
1864   // DirectSessions, it is less so for non-direct Sessions. Devise a better
1865   // way to get its value when the need arises.
1866   TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
1867       debug_options.global_step(), rcg_execution_count, rcg_execution_count,
1868       input_names, output_names, target_names));
1869 
1870   return Status::OK();
1871 }
1872 
1873 void MasterSession::FillPerStepState(MasterSession::ReffedClientGraph* rcg,
1874                                      const RunOptions& run_options,
1875                                      uint64 step_id, int64 count,
1876                                      PerStepState* out_pss,
1877                                      std::unique_ptr<ProfileHandler>* out_ph) {
1878   out_pss->collect_timeline =
1879       run_options.trace_level() == RunOptions::FULL_TRACE;
1880   out_pss->collect_rpcs = run_options.trace_level() == RunOptions::FULL_TRACE;
1881   out_pss->report_tensor_allocations_upon_oom =
1882       run_options.report_tensor_allocations_upon_oom();
1883   // Build the cost model every 'build_cost_model_every' steps after skipping an
1884   // initial 'build_cost_model_after' steps.
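  // For example (illustrative numbers, not defaults): with
  // build_cost_model_after = 40 and build_cost_model = 10, costs are collected
  // whenever (count + 1 - 40) is a multiple of 10, e.g. at counts 39, 49, 59, ...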
1885   const int64 build_cost_model_after =
1886       session_opts_.config.graph_options().build_cost_model_after();
1887   const int64 build_cost_model_every =
1888       session_opts_.config.graph_options().build_cost_model();
1889   out_pss->collect_costs =
1890       build_cost_model_every > 0 &&
1891       ((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
1892   out_pss->collect_partition_graphs = run_options.output_partition_graphs();
1893 
1894   *out_ph = rcg->GetProfileHandler(step_id, count, run_options);
1895   if (*out_ph) {
1896     out_pss->collect_timeline = true;
1897     out_pss->collect_rpcs = (*out_ph)->should_collect_rpcs();
1898   }
1899 }
1900 
1901 Status MasterSession::PostRunCleanup(MasterSession::ReffedClientGraph* rcg,
1902                                      uint64 step_id,
1903                                      const RunOptions& run_options,
1904                                      PerStepState* pss,
1905                                      const std::unique_ptr<ProfileHandler>& ph,
1906                                      const Status& run_status,
1907                                      RunMetadata* out_run_metadata) {
1908   Status s = run_status;
1909   if (s.ok()) {
1910     pss->end_micros = Env::Default()->NowMicros();
1911     if (rcg->collective_graph_key() !=
1912         BuildGraphOptions::kNoCollectiveGraphKey) {
1913       env_->collective_executor_mgr->RetireStepId(rcg->collective_graph_key(),
1914                                                   step_id);
1915     }
1916     // Schedule post-processing and cleanup to be done asynchronously.
1917     rcg->ProcessStats(step_id, pss, ph.get(), run_options, out_run_metadata);
1918   } else if (errors::IsCancelled(s)) {
1919     mutex_lock l(mu_);
1920     if (closed_) {
1921       if (garbage_collected_) {
1922         s = errors::Cancelled(
1923             "Step was cancelled because the session was garbage collected due "
1924             "to inactivity.");
1925       } else {
1926         s = errors::Cancelled(
1927             "Step was cancelled by an explicit call to `Session::Close()`.");
1928       }
1929     }
1930   }
1931   Ref();
1932   rcg->Ref();
1933   rcg->CleanupPartitionsAsync(step_id, [this, rcg](const Status& s) {
1934     if (!s.ok()) {
1935       LOG(ERROR) << "Cleanup partition error: " << s;
1936     }
1937     rcg->Unref();
1938     MarkRunCompletion();
1939     Unref();
1940   });
1941   return s;
1942 }
1943 
1944 Status MasterSession::DoRunWithLocalExecution(
1945     CallOptions* opts, const RunStepRequestWrapper& req,
1946     MutableRunStepResponseWrapper* resp) {
1947   VLOG(2) << "DoRunWithLocalExecution req: " << req.DebugString();
1948   PerStepState pss;
1949   pss.start_micros = Env::Default()->NowMicros();
1950   auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
1951 
1952   // Prepare.
1953   BuildGraphOptions bgopts;
1954   BuildBuildGraphOptions(req, session_opts_.config, &bgopts);
1955   ReffedClientGraph* rcg = nullptr;
1956   int64 count;
1957   TF_RETURN_IF_ERROR(StartStep(bgopts, false, &rcg, &count));
1958 
1959   // Unref "rcg" when out of scope.
1960   core::ScopedUnref unref(rcg);
1961 
1962   std::unique_ptr<DebuggerStateInterface> debugger_state;
1963   const DebugOptions& debug_options = req.options().debug_options();
1964 
1965   if (!debug_options.debug_tensor_watch_opts().empty()) {
1966     TF_RETURN_IF_ERROR(
1967         CreateDebuggerState(debug_options, req, count, &debugger_state));
1968   }
1969   TF_RETURN_IF_ERROR(BuildAndRegisterPartitions(rcg));
1970 
1971   // NewStepId() reserves the most-significant bits of the step_id for
1972   // future use.
1973   uint64 step_id = NewStepId(rcg->collective_graph_key());
1974   TRACEPRINTF("stepid %llu", step_id);
1975 
1976   std::unique_ptr<ProfileHandler> ph;
1977   FillPerStepState(rcg, req.options(), step_id, count, &pss, &ph);
1978 
1979   if (pss.collect_partition_graphs &&
1980       session_opts_.config.experimental().disable_output_partition_graphs()) {
1981     return errors::InvalidArgument(
1982         "RunOptions.output_partition_graphs() is not supported when "
1983         "disable_output_partition_graphs is true.");
1984   }
1985 
1986   Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp,
1987                                 &cancellation_manager_, false);
1988 
1989   cleanup.release();  // MarkRunCompletion called in PostRunCleanup().
1990   return PostRunCleanup(rcg, step_id, req.options(), &pss, ph, s,
1991                         resp->mutable_metadata());
1992 }
1993 
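// Callable lifecycle: MakeCallable() compiles and registers a graph for the
// given CallableOptions and hands back an integer handle; RunCallable() looks
// the handle up in callables_ and executes it via DoRunCallable(); and
// ReleaseCallable() drops the cached graph. Handles are never reused, which is
// how RunCallable() can distinguish "no such handle" from "already released".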
1994 Status MasterSession::MakeCallable(const MakeCallableRequest& req,
1995                                    MakeCallableResponse* resp) {
1996   UpdateLastAccessTime();
1997 
1998   BuildGraphOptions opts;
1999   opts.callable_options = req.options();
2000   opts.use_function_convention = false;
2001 
2002   ReffedClientGraph* callable;
2003 
2004   {
2005     mutex_lock l(mu_);
2006     if (closed_) {
2007       return errors::FailedPrecondition("Session is closed.");
2008     }
2009     std::unique_ptr<ClientGraph> client_graph;
2010     TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
2011     callable = new ReffedClientGraph(handle_, opts, std::move(client_graph),
2012                                      session_opts_, stats_publisher_factory_,
2013                                      false /* is_partial */, get_worker_cache(),
2014                                      !should_delete_worker_sessions_);
2015   }
2016 
2017   Status s = BuildAndRegisterPartitions(callable);
2018   if (!s.ok()) {
2019     callable->Unref();
2020     return s;
2021   }
2022 
2023   uint64 handle;
2024   {
2025     mutex_lock l(mu_);
2026     handle = next_callable_handle_++;
2027     callables_[handle] = callable;
2028   }
2029 
2030   resp->set_handle(handle);
2031   return Status::OK();
2032 }
2033 
2034 Status MasterSession::DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg,
2035                                     const RunCallableRequest& req,
2036                                     RunCallableResponse* resp) {
2037   VLOG(2) << "DoRunCallable req: " << req.DebugString();
2038   PerStepState pss;
2039   pss.start_micros = Env::Default()->NowMicros();
2040   auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
2041 
2042   // Prepare.
2043   int64 count = rcg->get_and_increment_execution_count();
2044 
2045   const uint64 step_id = NewStepId(rcg->collective_graph_key());
2046   TRACEPRINTF("stepid %llu", step_id);
2047 
2048   const RunOptions& run_options = rcg->callable_options().run_options();
2049 
2050   if (run_options.timeout_in_ms() != 0) {
2051     opts->SetTimeout(run_options.timeout_in_ms());
2052   }
2053 
2054   std::unique_ptr<ProfileHandler> ph;
2055   FillPerStepState(rcg, run_options, step_id, count, &pss, &ph);
2056   Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp,
2057                                 &cancellation_manager_);
2058   cleanup.release();  // MarkRunCompletion called in PostRunCleanup().
2059   return PostRunCleanup(rcg, step_id, run_options, &pss, ph, s,
2060                         resp->mutable_metadata());
2061 }
2062 
2063 Status MasterSession::RunCallable(CallOptions* opts,
2064                                   const RunCallableRequest& req,
2065                                   RunCallableResponse* resp) {
2066   UpdateLastAccessTime();
2067   ReffedClientGraph* callable;
2068   {
2069     mutex_lock l(mu_);
2070     if (closed_) {
2071       return errors::FailedPrecondition("Session is closed.");
2072     }
2073     int64 handle = req.handle();
2074     if (handle >= next_callable_handle_) {
2075       return errors::InvalidArgument("No such callable handle: ", handle);
2076     }
2077     auto iter = callables_.find(req.handle());
2078     if (iter == callables_.end()) {
2079       return errors::InvalidArgument(
2080           "Attempted to run callable after handle was released: ", handle);
2081     }
2082     callable = iter->second;
2083     callable->Ref();
2084     ++num_running_;
2085   }
2086   core::ScopedUnref unref_callable(callable);
2087   return DoRunCallable(opts, callable, req, resp);
2088 }
2089 
2090 Status MasterSession::ReleaseCallable(const ReleaseCallableRequest& req,
2091                                       ReleaseCallableResponse* resp) {
2092   UpdateLastAccessTime();
2093   ReffedClientGraph* to_unref = nullptr;
2094   {
2095     mutex_lock l(mu_);
2096     auto iter = callables_.find(req.handle());
2097     if (iter != callables_.end()) {
2098       to_unref = iter->second;
2099       callables_.erase(iter);
2100     }
2101   }
2102   if (to_unref != nullptr) {
2103     to_unref->Unref();
2104   }
2105   return Status::OK();
2106 }
2107 
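// Shuts the session down: marks it closed so new Run()/Extend() calls fail,
// cancels any in-flight steps, waits for num_running_ to reach zero, releases
// every cached ReffedClientGraph (run, partial-run, and callable tables), and
// finally deletes the per-worker sessions if this session created them.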
2108 Status MasterSession::Close() {
2109   {
2110     mutex_lock l(mu_);
2111     closed_ = true;  // All subsequent calls to Run() or Extend() will fail.
2112   }
2113   cancellation_manager_.StartCancel();
2114   std::vector<ReffedClientGraph*> to_unref;
2115   {
2116     mutex_lock l(mu_);
2117     while (num_running_ != 0) {
2118       num_running_is_zero_.wait(l);
2119     }
2120     ClearRunsTable(&to_unref, &run_graphs_);
2121     ClearRunsTable(&to_unref, &partial_run_graphs_);
2122     ClearRunsTable(&to_unref, &callables_);
2123   }
2124   for (ReffedClientGraph* rcg : to_unref) rcg->Unref();
2125   if (should_delete_worker_sessions_) {
2126     Status s = DeleteWorkerSessions();
2127     if (!s.ok()) {
2128       LOG(WARNING) << s;
2129     }
2130   }
2131   return Status::OK();
2132 }
2133 
2134 void MasterSession::GarbageCollect() {
2135   {
2136     mutex_lock l(mu_);
2137     closed_ = true;
2138     garbage_collected_ = true;
2139   }
2140   cancellation_manager_.StartCancel();
2141   Unref();
2142 }
2143 
2144 MasterSession::RunState::RunState(const std::vector<string>& input_names,
2145                                   const std::vector<string>& output_names,
2146                                   ReffedClientGraph* rcg, const uint64 step_id,
2147                                   const int64 count)
2148     : rcg(rcg), step_id(step_id), count(count) {
2149   // Initially all the feeds and fetches are pending.
2150   for (auto& name : input_names) {
2151     pending_inputs[name] = false;
2152   }
2153   for (auto& name : output_names) {
2154     pending_outputs[name] = false;
2155   }
2156 }
2157 
2158 MasterSession::RunState::~RunState() {
2159   if (rcg) rcg->Unref();
2160 }
2161 
2162 bool MasterSession::RunState::PendingDone() const {
2163   for (const auto& it : pending_inputs) {
2164     if (!it.second) return false;
2165   }
2166   for (const auto& it : pending_outputs) {
2167     if (!it.second) return false;
2168   }
2169   return true;
2170 }
2171 
2172 }  // end namespace tensorflow
2173