1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/master_session.h"
17 
18 #include <memory>
19 #include <unordered_map>
20 #include <unordered_set>
21 #include <vector>
22 
23 #include "tensorflow/core/common_runtime/process_util.h"
24 #include "tensorflow/core/common_runtime/profile_handler.h"
25 #include "tensorflow/core/common_runtime/stats_publisher_interface.h"
26 #include "tensorflow/core/debug/debug_graph_utils.h"
27 #include "tensorflow/core/distributed_runtime/request_id.h"
28 #include "tensorflow/core/distributed_runtime/scheduler.h"
29 #include "tensorflow/core/distributed_runtime/worker_cache.h"
30 #include "tensorflow/core/distributed_runtime/worker_interface.h"
31 #include "tensorflow/core/framework/allocation_description.pb.h"
32 #include "tensorflow/core/framework/collective.h"
33 #include "tensorflow/core/framework/cost_graph.pb.h"
34 #include "tensorflow/core/framework/graph_def_util.h"
35 #include "tensorflow/core/framework/node_def.pb.h"
36 #include "tensorflow/core/framework/node_def_util.h"
37 #include "tensorflow/core/framework/tensor.h"
38 #include "tensorflow/core/framework/tensor.pb.h"
39 #include "tensorflow/core/framework/tensor_description.pb.h"
40 #include "tensorflow/core/graph/graph_partition.h"
41 #include "tensorflow/core/graph/tensor_id.h"
42 #include "tensorflow/core/lib/core/blocking_counter.h"
43 #include "tensorflow/core/lib/core/notification.h"
44 #include "tensorflow/core/lib/core/refcount.h"
45 #include "tensorflow/core/lib/core/status.h"
46 #include "tensorflow/core/lib/gtl/cleanup.h"
47 #include "tensorflow/core/lib/gtl/inlined_vector.h"
48 #include "tensorflow/core/lib/gtl/map_util.h"
49 #include "tensorflow/core/lib/random/random.h"
50 #include "tensorflow/core/lib/strings/numbers.h"
51 #include "tensorflow/core/lib/strings/str_util.h"
52 #include "tensorflow/core/lib/strings/strcat.h"
53 #include "tensorflow/core/lib/strings/stringprintf.h"
54 #include "tensorflow/core/platform/env.h"
55 #include "tensorflow/core/platform/logging.h"
56 #include "tensorflow/core/platform/macros.h"
57 #include "tensorflow/core/platform/mutex.h"
58 #include "tensorflow/core/platform/tracing.h"
59 #include "tensorflow/core/public/session_options.h"
60 #include "tensorflow/core/util/device_name_utils.h"
61 
62 namespace tensorflow {
63 
64 // MasterSession wraps ClientGraph in a reference counted object.
65 // This way, MasterSession can clear up the cache mapping Run requests to
66 // compiled graphs while the compiled graph is still being used.
67 //
68 // TODO(zhifengc): Cleanup this class. It's becoming messy.
69 class MasterSession::ReffedClientGraph : public core::RefCounted {
70  public:
71   ReffedClientGraph(const string& handle, const BuildGraphOptions& bopts,
72                     std::unique_ptr<ClientGraph> client_graph,
73                     const SessionOptions& session_opts,
74                     const StatsPublisherFactory& stats_publisher_factory,
75                     bool is_partial, WorkerCacheInterface* worker_cache,
76                     bool should_deregister)
77       : session_handle_(handle),
78         bg_opts_(bopts),
79         client_graph_before_register_(std::move(client_graph)),
80         session_opts_(session_opts),
81         is_partial_(is_partial),
82         callable_opts_(bopts.callable_options),
83         worker_cache_(worker_cache),
84         should_deregister_(should_deregister),
85         collective_graph_key_(
86             client_graph_before_register_->collective_graph_key) {
87     VLOG(1) << "Created ReffedClientGraph with "
88             << client_graph_before_register_->graph.num_node_ids() << " nodes";
89 
90     stats_publisher_ = stats_publisher_factory(handle, bopts, session_opts);
91 
92     // Initialize a name to node map for processing device stats.
93     for (Node* n : client_graph_before_register_->graph.nodes()) {
94       name_to_node_details_.emplace(
95           n->name(),
96           NodeDetails(n->type_string(),
97                       strings::StrCat(
98                           "(", absl::StrJoin(n->requested_inputs(), ", "))));
99     }
100   }
101 
102   ~ReffedClientGraph() override {
103     if (should_deregister_) {
104       DeregisterPartitions();
105     } else {
106       for (Part& part : partitions_) {
107         worker_cache_->ReleaseWorker(part.name, part.worker);
108       }
109     }
110   }
111 
112   const CallableOptions& callable_options() { return callable_opts_; }
113 
114   const BuildGraphOptions& build_graph_options() { return bg_opts_; }
115 
116   int64 collective_graph_key() { return collective_graph_key_; }
117 
118   std::unique_ptr<ProfileHandler> GetProfileHandler(uint64 step,
119                                                     int64_t execution_count,
120                                                     const RunOptions& ropts) {
121     return stats_publisher_->GetProfileHandler(step, execution_count, ropts);
122   }
123 
124   int64 get_and_increment_execution_count() {
125     return execution_count_.fetch_add(1);
126   }
127 
128   // Turn RPC logging on or off, both at the WorkerCache used by this
129   // master process, and at each remote worker in use for the current
130   // partitions.
131   void SetRPCLogging(bool active) {
132     worker_cache_->SetLogging(active);
133     // Logging is a best-effort activity, so we make async calls to turn
134     // it on/off and don't make use of the responses.
135     for (auto& p : partitions_) {
136       LoggingRequest* req = new LoggingRequest;
137       if (active) {
138         req->set_enable_rpc_logging(true);
139       } else {
140         req->set_disable_rpc_logging(true);
141       }
142       LoggingResponse* resp = new LoggingResponse;
143       Ref();
144       p.worker->LoggingAsync(req, resp, [this, req, resp](const Status& s) {
145         delete req;
146         delete resp;
147         // ReffedClientGraph owns p.worker so we need to hold a ref to
148         // ensure that the method doesn't attempt to access p.worker after
149         // ReffedClient graph has deleted it.
150         // TODO(suharshs): Simplify this ownership model.
151         Unref();
152       });
153     }
154   }
155 
156   // Retrieve all RPC logs data accumulated for the current step, both
157   // from the local WorkerCache in use by this master process and from
158   // all the remote workers executing the remote partitions.
159   void RetrieveLogs(int64_t step_id, StepStats* ss) {
160     // Get the local data first, because it sets *ss without merging.
161     worker_cache_->RetrieveLogs(step_id, ss);
162 
163     // Then merge in data from all the remote workers.
164     LoggingRequest req;
165     req.add_fetch_step_id(step_id);
166     int waiting_for = partitions_.size();
167     if (waiting_for > 0) {
168       mutex scoped_mu;
169       BlockingCounter all_done(waiting_for);
170       for (auto& p : partitions_) {
171         LoggingResponse* resp = new LoggingResponse;
172         p.worker->LoggingAsync(
173             &req, resp,
174             [step_id, ss, resp, &scoped_mu, &all_done](const Status& s) {
175               {
176                 mutex_lock l(scoped_mu);
177                 if (s.ok()) {
178                   for (auto& lss : resp->step()) {
179                     if (step_id != lss.step_id()) {
180                       LOG(ERROR) << "Wrong step_id in LoggingResponse";
181                       continue;
182                     }
183                     ss->MergeFrom(lss.step_stats());
184                   }
185                 }
186                 delete resp;
187               }
188               // Must not decrement all_done until out of critical section where
189               // *ss is updated.
190               all_done.DecrementCount();
191             });
192       }
193       all_done.Wait();
194     }
195   }
196 
197   // Local execution methods.
198 
199   // Partitions the graph into subgraphs and registers them on
200   // workers.
201   Status RegisterPartitions(PartitionOptions popts);
202 
203   // Runs one step of all partitions.
204   Status RunPartitions(const MasterEnv* env, int64_t step_id,
205                        int64_t execution_count, PerStepState* pss,
206                        CallOptions* opts, const RunStepRequestWrapper& req,
207                        MutableRunStepResponseWrapper* resp,
208                        CancellationManager* cm, const bool is_last_partial_run);
209   Status RunPartitions(const MasterEnv* env, int64_t step_id,
210                        int64_t execution_count, PerStepState* pss,
211                        CallOptions* call_opts, const RunCallableRequest& req,
212                        RunCallableResponse* resp, CancellationManager* cm);
213 
214   // Calls workers to cleanup states for the step "step_id".  Calls
215   // `done` when all cleanup RPCs have completed.
216   void CleanupPartitionsAsync(int64_t step_id, StatusCallback done);
217 
218   // Post-processing of any runtime statistics gathered during execution.
219   void ProcessStats(int64_t step_id, PerStepState* pss, ProfileHandler* ph,
220                     const RunOptions& options, RunMetadata* resp);
221   void ProcessDeviceStats(ProfileHandler* ph, const DeviceStepStats& ds,
222                           bool is_rpc);
223   // Checks that the requested fetches can be computed from the provided feeds.
224   Status CheckFetches(const RunStepRequestWrapper& req,
225                       const RunState* run_state,
226                       GraphExecutionState* execution_state);
227 
228  private:
229   const string session_handle_;
230   const BuildGraphOptions bg_opts_;
231 
232   // NOTE(mrry): This pointer will be null after `RegisterPartitions()` returns.
233   std::unique_ptr<ClientGraph> client_graph_before_register_ TF_GUARDED_BY(mu_);
234   const SessionOptions session_opts_;
235   const bool is_partial_;
236   const CallableOptions callable_opts_;
237   WorkerCacheInterface* const worker_cache_;  // Not owned.
238 
239   struct NodeDetails {
240     explicit NodeDetails(string type_string, string detail_text)
241         : type_string(std::move(type_string)),
242           detail_text(std::move(detail_text)) {}
243     const string type_string;
244     const string detail_text;
245   };
246   std::unordered_map<string, NodeDetails> name_to_node_details_;
247 
248   const bool should_deregister_;
249   const int64 collective_graph_key_;
250   std::atomic<int64> execution_count_ = {0};
251 
252   // Graph partitioned into per-location subgraphs.
253   struct Part {
254     // Worker name.
255     string name;
256 
257     // Maps feed names to rendezvous keys. Empty most of the time.
258     std::unordered_map<string, string> feed_key;
259 
260     // Maps rendezvous keys to fetch names. Empty most of the time.
261     std::unordered_map<string, string> key_fetch;
262 
263     // The interface to the worker. Owned.
264     WorkerInterface* worker = nullptr;
265 
266     // After registration with the worker, graph_handle identifies
267     // this partition on the worker.
268     string graph_handle;
269 
270     Part() : feed_key(3), key_fetch(3) {}
271   };
272 
273   // partitions_ is immutable after RegisterPartitions() call
274   // finishes.  RunPartitions() can access partitions_ safely without
275   // acquiring locks.
276   std::vector<Part> partitions_;
277 
278   mutable mutex mu_;
279 
280   // Partition initialization and registration only needs to happen
281   // once. `!client_graph_before_register_ && !init_done_.HasBeenNotified()`
282   // indicates the initialization is ongoing.
283   Notification init_done_;
284 
285   // init_result_ remembers the initialization error if any.
286   Status init_result_ TF_GUARDED_BY(mu_);
287 
288   std::unique_ptr<StatsPublisherInterface> stats_publisher_;
289 
290   string DetailText(const NodeDetails& details, const NodeExecStats& stats) {
291     int64_t tot = 0;
292     for (auto& no : stats.output()) {
293       tot += no.tensor_description().allocation_description().requested_bytes();
294     }
295     string bytes;
296     if (tot >= 0.1 * 1048576.0) {
297       bytes = strings::Printf("[%.1fMB] ", tot / 1048576.0);
298     }
299     return strings::StrCat(bytes, stats.node_name(), " = ", details.type_string,
300                            details.detail_text);
301   }
302 
303   // Send/Recv nodes that are the result of client-added
304   // feeds and fetches must be tracked so that the tensors
305   // can be added to the local rendezvous.
306   static void TrackFeedsAndFetches(Part* part, const GraphDef& graph_def,
307                                    const PartitionOptions& popts);
308 
309   // The actual graph partitioning and registration implementation.
310   Status DoBuildPartitions(
311       PartitionOptions popts, ClientGraph* client_graph,
312       std::unordered_map<string, GraphDef>* out_partitions);
313   Status DoRegisterPartitions(
314       const PartitionOptions& popts,
315       std::unordered_map<string, GraphDef> graph_partitions);
316 
317   // Prepares a number of calls to workers. One call per partition.
318   // This is a generic method that handles Run, PartialRun, and RunCallable.
319   template <class FetchListType, class ClientRequestType,
320             class ClientResponseType>
321   Status RunPartitionsHelper(
322       const std::unordered_map<StringPiece, size_t, StringPieceHasher>& feeds,
323       const FetchListType& fetches, const MasterEnv* env, int64_t step_id,
324       int64_t execution_count, PerStepState* pss, CallOptions* call_opts,
325       const ClientRequestType& req, ClientResponseType* resp,
326       CancellationManager* cm, bool is_last_partial_run);
327 
328   // Deregisters the partitions on the workers.  Called in the
329   // destructor and does not wait for the rpc completion.
330   void DeregisterPartitions();
331 
332   TF_DISALLOW_COPY_AND_ASSIGN(ReffedClientGraph);
333 };
334 
335 Status MasterSession::ReffedClientGraph::RegisterPartitions(
336     PartitionOptions popts) {
337   {  // Ensure register once.
338     mu_.lock();
339     if (client_graph_before_register_) {
340       // The `ClientGraph` is no longer needed after partitions are registered.
341       // Since it can account for a large amount of memory, we consume it here,
342       // and it will be freed after concluding with registration.
343 
344       std::unique_ptr<ClientGraph> client_graph;
345       std::swap(client_graph_before_register_, client_graph);
346       mu_.unlock();
347       std::unordered_map<string, GraphDef> graph_defs;
348       popts.flib_def = client_graph->flib_def.get();
349       Status s = DoBuildPartitions(popts, client_graph.get(), &graph_defs);
350       if (s.ok()) {
351         // NOTE(mrry): The pointers in `graph_defs_for_publishing` do not remain
352         // valid after the call to DoRegisterPartitions begins, so
353         // `stats_publisher_` must make a copy if it wants to retain the
354         // GraphDef objects.
355         std::vector<const GraphDef*> graph_defs_for_publishing;
356         graph_defs_for_publishing.reserve(partitions_.size());
357         for (const auto& name_def : graph_defs) {
358           graph_defs_for_publishing.push_back(&name_def.second);
359         }
360         stats_publisher_->PublishGraphProto(graph_defs_for_publishing);
361         s = DoRegisterPartitions(popts, std::move(graph_defs));
362       }
363       mu_.lock();
364       init_result_ = s;
365       init_done_.Notify();
366     } else {
367       mu_.unlock();
368       init_done_.WaitForNotification();
369       mu_.lock();
370     }
371     const Status result = init_result_;
372     mu_.unlock();
373     return result;
374   }
375 }
376 
377 static string SplitByWorker(const Node* node) {
378   string task;
379   string device;
380   CHECK(DeviceNameUtils::SplitDeviceName(node->assigned_device_name(), &task,
381                                          &device))
382       << "node: " << node->name() << " dev: " << node->assigned_device_name();
383   return task;
384 }
385 
386 void MasterSession::ReffedClientGraph::TrackFeedsAndFetches(
387     Part* part, const GraphDef& graph_def, const PartitionOptions& popts) {
388   for (int i = 0; i < graph_def.node_size(); ++i) {
389     const NodeDef& ndef = graph_def.node(i);
390     const bool is_recv = ndef.op() == "_Recv";
391     const bool is_send = ndef.op() == "_Send";
392 
393     if (is_recv || is_send) {
394       // Only send/recv nodes that were added as feeds and fetches
395       // (client-terminated) should be tracked.  Other send/recv nodes
396       // are for transferring data between partitions / memory spaces.
397       bool client_terminated;
398       TF_CHECK_OK(GetNodeAttr(ndef, "client_terminated", &client_terminated));
399       if (client_terminated) {
400         string name;
401         TF_CHECK_OK(GetNodeAttr(ndef, "tensor_name", &name));
402         string send_device;
403         TF_CHECK_OK(GetNodeAttr(ndef, "send_device", &send_device));
404         string recv_device;
405         TF_CHECK_OK(GetNodeAttr(ndef, "recv_device", &recv_device));
406         uint64 send_device_incarnation;
407         TF_CHECK_OK(
408             GetNodeAttr(ndef, "send_device_incarnation",
409                         reinterpret_cast<int64*>(&send_device_incarnation)));
410         const string& key =
411             Rendezvous::CreateKey(send_device, send_device_incarnation,
412                                   recv_device, name, FrameAndIter(0, 0));
413 
414         if (is_recv) {
415           part->feed_key.insert({name, key});
416         } else {
417           part->key_fetch.insert({key, name});
418         }
419       }
420     }
421   }
422 }
423 
424 Status MasterSession::ReffedClientGraph::DoBuildPartitions(
425     PartitionOptions popts, ClientGraph* client_graph,
426     std::unordered_map<string, GraphDef>* out_partitions) {
427   if (popts.need_to_record_start_times) {
428     CostModel cost_model(true);
429     cost_model.InitFromGraph(client_graph->graph);
430     // TODO(yuanbyu): Use the real cost model.
431     // execution_state_->MergeFromGlobal(&cost_model);
432     SlackAnalysis sa(&client_graph->graph, &cost_model);
433     sa.ComputeAsap(&popts.start_times);
434   }
435 
436   // Partition the graph.
437   return Partition(popts, &client_graph->graph, out_partitions);
438 }
439 
440 Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
441     const PartitionOptions& popts,
442     std::unordered_map<string, GraphDef> graph_partitions) {
443   partitions_.reserve(graph_partitions.size());
444   Status s;
445   for (auto& name_def : graph_partitions) {
446     partitions_.emplace_back();
447     Part* part = &partitions_.back();
448     part->name = name_def.first;
449     TrackFeedsAndFetches(part, name_def.second, popts);
450     part->worker = worker_cache_->GetOrCreateWorker(part->name);
451     if (part->worker == nullptr) {
452       s = errors::NotFound("worker ", part->name);
453       break;
454     }
455   }
456   if (!s.ok()) {
457     for (Part& part : partitions_) {
458       worker_cache_->ReleaseWorker(part.name, part.worker);
459       part.worker = nullptr;
460     }
461     return s;
462   }
463   struct Call {
464     RegisterGraphRequest req;
465     RegisterGraphResponse resp;
466     Status status;
467   };
468   const int num = partitions_.size();
469   gtl::InlinedVector<Call, 4> calls(num);
470   BlockingCounter done(num);
471   for (int i = 0; i < num; ++i) {
472     const Part& part = partitions_[i];
473     Call* c = &calls[i];
474     c->req.set_session_handle(session_handle_);
475     c->req.set_create_worker_session_called(!should_deregister_);
476     c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]);
477     StripDefaultAttributes(*OpRegistry::Global(),
478                            c->req.mutable_graph_def()->mutable_node());
479     *c->req.mutable_config_proto() = session_opts_.config;
480     *c->req.mutable_graph_options() = session_opts_.config.graph_options();
481     *c->req.mutable_debug_options() =
482         callable_opts_.run_options().debug_options();
483     c->req.set_collective_graph_key(collective_graph_key_);
484     VLOG(2) << "Register " << c->req.graph_def().DebugString();
485     auto cb = [c, &done](const Status& s) {
486       c->status = s;
487       done.DecrementCount();
488     };
489     part.worker->RegisterGraphAsync(&c->req, &c->resp, cb);
490   }
491   done.Wait();
492   for (int i = 0; i < num; ++i) {
493     Call* c = &calls[i];
494     s.Update(c->status);
495     partitions_[i].graph_handle = c->resp.graph_handle();
496   }
497   return s;
498 }
499 
500 namespace {
501 // Helper class to manage "num" parallel RunGraph calls.
502 class RunManyGraphs {
503  public:
504   explicit RunManyGraphs(int num) : calls_(num), pending_(num) {}
505 
506   ~RunManyGraphs() {}
507 
508   // Returns the index-th call.
509   struct Call {
510     CallOptions opts;
511     const string* worker_name;
512     std::atomic<bool> done{false};
513     std::unique_ptr<MutableRunGraphRequestWrapper> req;
514     std::unique_ptr<MutableRunGraphResponseWrapper> resp;
515   };
516   Call* get(int index) { return &calls_[index]; }
517 
518   // When the index-th call is done, updates the overall status.
519   void WhenDone(int index, const Status& s) {
520     TRACEPRINTF("Partition %d %s", index, s.ToString().c_str());
521     Call* call = get(index);
522     call->done = true;
523     auto resp = call->resp.get();
524     if (resp->status_code() != error::Code::OK) {
525       // resp->status_code will only be non-OK if s.ok().
526       mutex_lock l(mu_);
527       ReportBadStatus(Status(resp->status_code(),
528                              strings::StrCat("From ", *call->worker_name, ":\n",
529                                              resp->status_error_message())));
530     } else if (!s.ok()) {
531       mutex_lock l(mu_);
532       ReportBadStatus(
533           Status(s.code(), strings::StrCat("From ", *call->worker_name, ":\n",
534                                            s.error_message())));
535     }
536     pending_.DecrementCount();
537   }
538 
539   void StartCancel() {
540     mutex_lock l(mu_);
541     ReportBadStatus(errors::Cancelled("RunManyGraphs"));
542   }
543 
544   void Wait() {
545     // Check the error status every 60 seconds in order to print a log message
546     // in the event of a hang.
547     const std::chrono::milliseconds kCheckErrorPeriod(1000 * 60);
548     while (true) {
549       if (pending_.WaitFor(kCheckErrorPeriod)) {
550         return;
551       }
552       if (!status().ok()) {
553         break;
554       }
555     }
556 
557     // The step has failed. Wait for another 60 seconds before diagnosing a
558     // hang.
559     DCHECK(!status().ok());
560     if (pending_.WaitFor(kCheckErrorPeriod)) {
561       return;
562     }
563     LOG(ERROR)
564         << "RunStep still blocked after 60 seconds. Failed with error status: "
565         << status();
566     for (const Call& call : calls_) {
567       if (!call.done) {
568         LOG(ERROR) << "- No response from RunGraph call to worker: "
569                    << *call.worker_name;
570       }
571     }
572     pending_.Wait();
573   }
574 
575   Status status() const {
576     mutex_lock l(mu_);
577     // Concat status objects in this StatusGroup to get the aggregated status,
578     // as each status in status_group_ is already summarized status.
579     return status_group_.as_concatenated_status();
580   }
581 
582  private:
583   gtl::InlinedVector<Call, 4> calls_;
584 
585   BlockingCounter pending_;
586   mutable mutex mu_;
587   StatusGroup status_group_ TF_GUARDED_BY(mu_);
588   bool cancel_issued_ TF_GUARDED_BY(mu_) = false;
589 
590   void ReportBadStatus(const Status& s) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
591     VLOG(1) << "Master received error status " << s;
592     if (!cancel_issued_ && !StatusGroup::IsDerived(s)) {
593       // Only start cancelling other workers upon receiving a non-derived
594       // error
595       cancel_issued_ = true;
596 
597       VLOG(1) << "Master received error report. Cancelling remaining workers.";
598       for (Call& call : calls_) {
599         call.opts.StartCancel();
600       }
601     }
602 
603     status_group_.Update(s);
604   }
605 
606   TF_DISALLOW_COPY_AND_ASSIGN(RunManyGraphs);
607 };
608 
609 Status AddSendFromClientRequest(const RunStepRequestWrapper& client_req,
610                                 MutableRunGraphRequestWrapper* worker_req,
611                                 size_t index, const string& send_key) {
612   return worker_req->AddSendFromRunStepRequest(client_req, index, send_key);
613 }
614 
615 Status AddSendFromClientRequest(const RunCallableRequest& client_req,
616                                 MutableRunGraphRequestWrapper* worker_req,
617                                 size_t index, const string& send_key) {
618   return worker_req->AddSendFromRunCallableRequest(client_req, index, send_key);
619 }
620 
621 // TODO(mrry): Add a full-fledged wrapper that avoids TensorProto copies for
622 // in-process messages.
623 struct RunCallableResponseWrapper {
624   RunCallableResponse* resp;  // Not owned.
625   std::unordered_map<string, TensorProto> fetch_key_to_protos;
626 
627   RunMetadata* mutable_metadata() { return resp->mutable_metadata(); }
628 
629   Status AddTensorFromRunGraphResponse(
630       const string& tensor_name, MutableRunGraphResponseWrapper* worker_resp,
631       size_t index) {
632     return worker_resp->RecvValue(index, &fetch_key_to_protos[tensor_name]);
633   }
634 };
635 }  // namespace
636 
637 template <class FetchListType, class ClientRequestType,
638           class ClientResponseType>
639 Status MasterSession::ReffedClientGraph::RunPartitionsHelper(
640     const std::unordered_map<StringPiece, size_t, StringPieceHasher>& feeds,
641     const FetchListType& fetches, const MasterEnv* env, int64_t step_id,
642     int64_t execution_count, PerStepState* pss, CallOptions* call_opts,
643     const ClientRequestType& req, ClientResponseType* resp,
644     CancellationManager* cm, bool is_last_partial_run) {
645   // Collect execution cost stats on a smoothly decreasing frequency.
646   ExecutorOpts exec_opts;
647   if (pss->report_tensor_allocations_upon_oom) {
648     exec_opts.set_report_tensor_allocations_upon_oom(true);
649   }
650   if (pss->collect_costs) {
651     exec_opts.set_record_costs(true);
652   }
653   if (pss->collect_timeline) {
654     exec_opts.set_record_timeline(true);
655   }
656   if (pss->collect_rpcs) {
657     SetRPCLogging(true);
658   }
659   if (pss->collect_partition_graphs) {
660     exec_opts.set_record_partition_graphs(true);
661   }
662   if (pss->collect_costs || pss->collect_timeline) {
663     pss->step_stats.resize(partitions_.size());
664   }
665 
666   const int num = partitions_.size();
667   RunManyGraphs calls(num);
668 
669   for (int i = 0; i < num; ++i) {
670     const Part& part = partitions_[i];
671     RunManyGraphs::Call* c = calls.get(i);
672     c->worker_name = &part.name;
673     c->req.reset(part.worker->CreateRunGraphRequest());
674     c->resp.reset(part.worker->CreateRunGraphResponse());
675     if (is_partial_) {
676       c->req->set_is_partial(is_partial_);
677       c->req->set_is_last_partial_run(is_last_partial_run);
678     }
679     c->req->set_session_handle(session_handle_);
680     c->req->set_create_worker_session_called(!should_deregister_);
681     c->req->set_graph_handle(part.graph_handle);
682     c->req->set_step_id(step_id);
683     *c->req->mutable_exec_opts() = exec_opts;
684     c->req->set_store_errors_in_response_body(true);
685     c->req->set_request_id(GetUniqueRequestId());
686     // If any feeds are provided, send the feed values together
687     // in the RunGraph request.
688     // In the partial case, we only want to include feeds provided in the req.
689     // In the non-partial case, all feeds in the request are in the part.
690     // We keep these as separate paths for now, to ensure we aren't
691     // inadvertently slowing down the normal run path.
692     if (is_partial_) {
693       for (const auto& name_index : feeds) {
694         const auto iter = part.feed_key.find(string(name_index.first));
695         if (iter == part.feed_key.end()) {
696           // The provided feed must be for a different partition.
697           continue;
698         }
699         const string& key = iter->second;
700         TF_RETURN_IF_ERROR(AddSendFromClientRequest(req, c->req.get(),
701                                                     name_index.second, key));
702       }
703       // TODO(suharshs): Make a map from feed to fetch_key to make this faster.
704       // For now, we just iterate through partitions to find the matching key.
705       for (const string& req_fetch : fetches) {
706         for (const auto& key_fetch : part.key_fetch) {
707           if (key_fetch.second == req_fetch) {
708             c->req->add_recv_key(key_fetch.first);
709             break;
710           }
711         }
712       }
713     } else {
714       for (const auto& feed_key : part.feed_key) {
715         const string& feed = feed_key.first;
716         const string& key = feed_key.second;
717         auto iter = feeds.find(feed);
718         if (iter == feeds.end()) {
719           return errors::Internal("No feed index found for feed: ", feed);
720         }
721         const int64_t feed_index = iter->second;
722         TF_RETURN_IF_ERROR(
723             AddSendFromClientRequest(req, c->req.get(), feed_index, key));
724       }
725       for (const auto& key_fetch : part.key_fetch) {
726         const string& key = key_fetch.first;
727         c->req->add_recv_key(key);
728       }
729     }
730   }
731 
732   // Issues RunGraph calls.
733   for (int i = 0; i < num; ++i) {
734     const Part& part = partitions_[i];
735     RunManyGraphs::Call* call = calls.get(i);
736     TRACEPRINTF("Partition %d %s", i, part.name.c_str());
737     part.worker->RunGraphAsync(
738         &call->opts, call->req.get(), call->resp.get(),
739         std::bind(&RunManyGraphs::WhenDone, &calls, i, std::placeholders::_1));
740   }
741 
742   // Waits for the RunGraph calls.
743   call_opts->SetCancelCallback([&calls]() {
744     LOG(INFO) << "Client requested cancellation for RunStep, cancelling "
745                  "worker operations.";
746     calls.StartCancel();
747   });
748   auto token = cm->get_cancellation_token();
749   const bool success =
750       cm->RegisterCallback(token, [&calls]() { calls.StartCancel(); });
751   if (!success) {
752     calls.StartCancel();
753   }
754   calls.Wait();
755   call_opts->ClearCancelCallback();
756   if (success) {
757     cm->DeregisterCallback(token);
758   } else {
759     return errors::Cancelled("Step was cancelled");
760   }
761   TF_RETURN_IF_ERROR(calls.status());
762 
763   // Collects fetches and metadata.
764   Status status;
765   for (int i = 0; i < num; ++i) {
766     const Part& part = partitions_[i];
767     MutableRunGraphResponseWrapper* run_graph_resp = calls.get(i)->resp.get();
768     for (size_t j = 0; j < run_graph_resp->num_recvs(); ++j) {
769       auto iter = part.key_fetch.find(run_graph_resp->recv_key(j));
770       if (iter == part.key_fetch.end()) {
771         status.Update(errors::Internal("Unexpected fetch key: ",
772                                        run_graph_resp->recv_key(j)));
773         break;
774       }
775       const string& fetch = iter->second;
776       status.Update(
777           resp->AddTensorFromRunGraphResponse(fetch, run_graph_resp, j));
778       if (!status.ok()) {
779         break;
780       }
781     }
782     if (pss->collect_timeline) {
783       pss->step_stats[i].Swap(run_graph_resp->mutable_step_stats());
784     }
785     if (pss->collect_costs) {
786       CostGraphDef* cost_graph = run_graph_resp->mutable_cost_graph();
787       for (int j = 0; j < cost_graph->node_size(); ++j) {
788         resp->mutable_metadata()->mutable_cost_graph()->add_node()->Swap(
789             cost_graph->mutable_node(j));
790       }
791     }
792     if (pss->collect_partition_graphs) {
793       protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
794           resp->mutable_metadata()->mutable_partition_graphs();
795       for (size_t i = 0; i < run_graph_resp->num_partition_graphs(); i++) {
796         partition_graph_defs->Add()->Swap(
797             run_graph_resp->mutable_partition_graph(i));
798       }
799     }
800   }
801   return status;
802 }
803 
804 Status MasterSession::ReffedClientGraph::RunPartitions(
805     const MasterEnv* env, int64_t step_id, int64_t execution_count,
806     PerStepState* pss, CallOptions* call_opts, const RunStepRequestWrapper& req,
807     MutableRunStepResponseWrapper* resp, CancellationManager* cm,
808     const bool is_last_partial_run) {
809   VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
810           << execution_count;
811   // Maps the names of fed tensors to their index in `req`.
812   std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3);
813   for (size_t i = 0; i < req.num_feeds(); ++i) {
814     if (!feeds.insert({req.feed_name(i), i}).second) {
815       return errors::InvalidArgument("Duplicated feeds: ", req.feed_name(i));
816     }
817   }
818 
819   std::vector<string> fetches;
820   fetches.reserve(req.num_fetches());
821   for (size_t i = 0; i < req.num_fetches(); ++i) {
822     fetches.push_back(req.fetch_name(i));
823   }
824 
825   return RunPartitionsHelper(feeds, fetches, env, step_id, execution_count, pss,
826                              call_opts, req, resp, cm, is_last_partial_run);
827 }
828 
829 Status MasterSession::ReffedClientGraph::RunPartitions(
830     const MasterEnv* env, int64_t step_id, int64_t execution_count,
831     PerStepState* pss, CallOptions* call_opts, const RunCallableRequest& req,
832     RunCallableResponse* resp, CancellationManager* cm) {
833   VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
834           << execution_count;
835   // Maps the names of fed tensors to their index in `req`.
836   std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3);
837   for (size_t i = 0, end = callable_opts_.feed_size(); i < end; ++i) {
838     if (!feeds.insert({callable_opts_.feed(i), i}).second) {
839       // MakeCallable will fail if there are two feeds with the same name.
840       return errors::Internal("Duplicated feeds in callable: ",
841                               callable_opts_.feed(i));
842     }
843   }
844 
845   // Create a wrapped response object to collect the fetched values and
846   // rearrange them for the RunCallableResponse.
847   RunCallableResponseWrapper wrapped_resp;
848   wrapped_resp.resp = resp;
849 
850   TF_RETURN_IF_ERROR(RunPartitionsHelper(
851       feeds, callable_opts_.fetch(), env, step_id, execution_count, pss,
852       call_opts, req, &wrapped_resp, cm, false /* is_last_partial_run */));
853 
854   // Collects fetches.
855   for (const string& fetch : callable_opts_.fetch()) {
856     TensorProto* fetch_proto = resp->mutable_fetch()->Add();
857     auto iter = wrapped_resp.fetch_key_to_protos.find(fetch);
858     if (iter == wrapped_resp.fetch_key_to_protos.end()) {
859       return errors::Internal("Worker did not return a value for fetch: ",
860                               fetch);
861     }
862     fetch_proto->Swap(&iter->second);
863   }
864   return Status::OK();
865 }
866 
867 namespace {
868 
869 class CleanupBroadcastHelper {
870  public:
871   CleanupBroadcastHelper(int64_t step_id, int num_calls, StatusCallback done)
872       : resps_(num_calls), num_pending_(num_calls), done_(std::move(done)) {
873     req_.set_step_id(step_id);
874   }
875 
876   // Returns a non-owned pointer to a request buffer for all calls.
877   CleanupGraphRequest* request() { return &req_; }
878 
879   // Returns a non-owned pointer to a response buffer for the ith call.
880   CleanupGraphResponse* response(int i) { return &resps_[i]; }
881 
882   // Called when the ith response is received.
883   void call_done(int i, const Status& s) {
884     bool run_callback = false;
885     Status status_copy;
886     {
887       mutex_lock l(mu_);
888       status_.Update(s);
889       if (--num_pending_ == 0) {
890         run_callback = true;
891         status_copy = status_;
892       }
893     }
894     if (run_callback) {
895       done_(status_copy);
896       // This is the last call, so delete the helper object.
897       delete this;
898     }
899   }
900 
901  private:
902   // A single request shared between all workers.
903   CleanupGraphRequest req_;
904   // One response buffer for each worker.
905   gtl::InlinedVector<CleanupGraphResponse, 4> resps_;
906 
907   mutex mu_;
908   // Number of requests remaining to be collected.
909   int num_pending_ TF_GUARDED_BY(mu_);
910   // Aggregate status of the operation.
911   Status status_ TF_GUARDED_BY(mu_);
912   // Callback to be called when all operations complete.
913   StatusCallback done_;
914 
915   TF_DISALLOW_COPY_AND_ASSIGN(CleanupBroadcastHelper);
916 };
917 
918 }  // namespace
919 
920 void MasterSession::ReffedClientGraph::CleanupPartitionsAsync(
921     int64_t step_id, StatusCallback done) {
922   const int num = partitions_.size();
923   // Helper object will be deleted when the final call completes.
924   CleanupBroadcastHelper* helper =
925       new CleanupBroadcastHelper(step_id, num, std::move(done));
926   for (int i = 0; i < num; ++i) {
927     const Part& part = partitions_[i];
928     part.worker->CleanupGraphAsync(
929         helper->request(), helper->response(i),
930         [helper, i](const Status& s) { helper->call_done(i, s); });
931   }
932 }
933 
934 void MasterSession::ReffedClientGraph::ProcessStats(int64_t step_id,
935                                                     PerStepState* pss,
936                                                     ProfileHandler* ph,
937                                                     const RunOptions& options,
938                                                     RunMetadata* resp) {
939   if (!pss->collect_costs && !pss->collect_timeline) return;
940 
941   // Out-of-band logging data is collected now, during post-processing.
942   if (pss->collect_timeline) {
943     SetRPCLogging(false);
944     RetrieveLogs(step_id, &pss->rpc_stats);
945   }
946   for (size_t i = 0; i < partitions_.size(); ++i) {
947     const StepStats& ss = pss->step_stats[i];
948     if (ph) {
949       for (const auto& ds : ss.dev_stats()) {
950         ProcessDeviceStats(ph, ds, false /*is_rpc*/);
951       }
952     }
953   }
954   if (ph) {
955     for (const auto& ds : pss->rpc_stats.dev_stats()) {
956       ProcessDeviceStats(ph, ds, true /*is_rpc*/);
957     }
958     ph->StepDone(pss->start_micros, pss->end_micros,
959                  Microseconds(0) /*cleanup_time*/, 0 /*total_runops*/,
960                  Status::OK());
961   }
962   // Assemble all stats for this timeline into a merged StepStats.
963   if (pss->collect_timeline) {
964     StepStats step_stats_proto;
965     step_stats_proto.Swap(&pss->rpc_stats);
966     for (size_t i = 0; i < partitions_.size(); ++i) {
967       step_stats_proto.MergeFrom(pss->step_stats[i]);
968       pss->step_stats[i].Clear();
969     }
970     pss->step_stats.clear();
971     // Copy the stats back, but only for on-demand profiling to avoid slowing
972     // down calls that trigger the automatic profiling.
973     if (options.trace_level() == RunOptions::FULL_TRACE) {
974       resp->mutable_step_stats()->Swap(&step_stats_proto);
975     } else {
976       // If FULL_TRACE, it can be fetched from Session API, no need for
977       // duplicated publishing.
978       stats_publisher_->PublishStatsProto(step_stats_proto);
979     }
980   }
981 }
982 
983 void MasterSession::ReffedClientGraph::ProcessDeviceStats(
984     ProfileHandler* ph, const DeviceStepStats& ds, bool is_rpc) {
985   const string& dev_name = ds.device();
986   VLOG(1) << "Device " << dev_name << " reports stats for "
987           << ds.node_stats_size() << " nodes";
988   for (const auto& ns : ds.node_stats()) {
989     if (is_rpc) {
990       // We don't have access to a good Node pointer, so we rely on
991       // sufficient data being present in the NodeExecStats.
992       ph->RecordOneOp(dev_name, ns, true /*is_copy*/, "", ns.node_name(),
993                       ns.timeline_label());
994     } else {
995       auto iter = name_to_node_details_.find(ns.node_name());
996       const bool found_node_in_graph = iter != name_to_node_details_.end();
997       if (!found_node_in_graph && ns.timeline_label().empty()) {
998         // The counter incrementing is not thread-safe. But we don't really
999         // care.
1000         // TODO(zhengxq): we should implement a LOG_FIRST_N and LOG_EVERY_N for
1001         // more general usage.
1002         static int log_counter = 0;
1003         if (log_counter < 10) {
1004           log_counter++;
1005           LOG(WARNING) << "Failed to find node " << ns.node_name()
1006                        << " for dev " << dev_name;
1007         }
1008         continue;
1009       }
1010       const string& optype =
1011           found_node_in_graph ? iter->second.type_string : ns.node_name();
1012       string details;
1013       if (!ns.timeline_label().empty()) {
1014         details = ns.timeline_label();
1015       } else if (found_node_in_graph) {
1016         details = DetailText(iter->second, ns);
1017       } else {
1018         // Leave details string empty
1019       }
1020       ph->RecordOneOp(dev_name, ns, false /*is_copy*/, ns.node_name(), optype,
1021                       details);
1022     }
1023   }
1024 }
1025 
1026 // TODO(suharshs): Merge with CheckFetches in DirectSession.
1027 // TODO(suharsh,mrry): Build a map from fetch target to set of feeds it depends
1028 // on once at setup time to prevent us from computing the dependencies
1029 // every time.
1030 Status MasterSession::ReffedClientGraph::CheckFetches(
1031     const RunStepRequestWrapper& req, const RunState* run_state,
1032     GraphExecutionState* execution_state) {
1033   // Build the set of pending feeds that we haven't seen.
1034   std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
1035   for (const auto& input : run_state->pending_inputs) {
1036     // Skip if already fed.
1037     if (input.second) continue;
1038     TensorId id(ParseTensorName(input.first));
1039     const Node* n = execution_state->get_node_by_name(string(id.first));
1040     if (n == nullptr) {
1041       return errors::NotFound("Feed ", input.first, ": not found");
1042     }
1043     pending_feeds.insert(id);
1044   }
1045   for (size_t i = 0; i < req.num_feeds(); ++i) {
1046     const TensorId id(ParseTensorName(req.feed_name(i)));
1047     pending_feeds.erase(id);
1048   }
1049 
1050   // Initialize the stack with the fetch nodes.
1051   std::vector<const Node*> stack;
1052   for (size_t i = 0; i < req.num_fetches(); ++i) {
1053     const string& fetch = req.fetch_name(i);
1054     const TensorId id(ParseTensorName(fetch));
1055     const Node* n = execution_state->get_node_by_name(string(id.first));
1056     if (n == nullptr) {
1057       return errors::NotFound("Fetch ", fetch, ": not found");
1058     }
1059     stack.push_back(n);
1060   }
1061 
1062   // Any tensor needed for fetches can't be in pending_feeds.
1063   // We need to use the original full graph from execution state.
1064   const Graph* graph = execution_state->full_graph();
1065   std::vector<bool> visited(graph->num_node_ids(), false);
1066   while (!stack.empty()) {
1067     const Node* n = stack.back();
1068     stack.pop_back();
1069 
1070     for (const Edge* in_edge : n->in_edges()) {
1071       const Node* in_node = in_edge->src();
1072       if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) {
1073         return errors::InvalidArgument("Fetch ", in_node->name(), ":",
1074                                        in_edge->src_output(),
1075                                        " can't be computed from the feeds"
1076                                        " that have been fed so far.");
1077       }
1078       if (!visited[in_node->id()]) {
1079         visited[in_node->id()] = true;
1080         stack.push_back(in_node);
1081       }
1082     }
1083   }
1084   return Status::OK();
1085 }
1086 
1087 // Asynchronously deregisters subgraphs on the workers, without waiting for the
1088 // result.
1089 void MasterSession::ReffedClientGraph::DeregisterPartitions() {
1090   struct Call {
1091     DeregisterGraphRequest req;
1092     DeregisterGraphResponse resp;
1093   };
1094   for (Part& part : partitions_) {
1095     // The graph handle may be empty if we failed during partition registration.
1096     if (!part.graph_handle.empty()) {
1097       Call* c = new Call;
1098       c->req.set_session_handle(session_handle_);
1099       c->req.set_create_worker_session_called(!should_deregister_);
1100       c->req.set_graph_handle(part.graph_handle);
1101       // NOTE(mrry): We must capture `worker_cache_` since `this`
1102       // could be deleted before the callback is called.
1103       WorkerCacheInterface* worker_cache = worker_cache_;
1104       const string name = part.name;
1105       WorkerInterface* w = part.worker;
1106       CHECK_NOTNULL(w);
1107       auto cb = [worker_cache, c, name, w](const Status& s) {
1108         if (!s.ok()) {
1109           // This error is potentially benign, so we don't log at the
1110           // error level.
1111           LOG(INFO) << "DeregisterGraph error: " << s;
1112         }
1113         delete c;
1114         worker_cache->ReleaseWorker(name, w);
1115       };
1116       w->DeregisterGraphAsync(&c->req, &c->resp, cb);
1117     }
1118   }
1119 }
1120 
1121 namespace {
1122 void CopyAndSortStrings(size_t size,
1123                         const std::function<string(size_t)>& input_accessor,
1124                         protobuf::RepeatedPtrField<string>* output) {
1125   std::vector<string> temp;
1126   temp.reserve(size);
1127   for (size_t i = 0; i < size; ++i) {
1128     output->Add(input_accessor(i));
1129   }
1130   std::sort(output->begin(), output->end());
1131 }
1132 }  // namespace
1133 
1134 void BuildBuildGraphOptions(const RunStepRequestWrapper& req,
1135                             const ConfigProto& config,
1136                             BuildGraphOptions* opts) {
1137   CallableOptions* callable_opts = &opts->callable_options;
1138   CopyAndSortStrings(
1139       req.num_feeds(), [&req](size_t i) { return req.feed_name(i); },
1140       callable_opts->mutable_feed());
1141   CopyAndSortStrings(
1142       req.num_fetches(), [&req](size_t i) { return req.fetch_name(i); },
1143       callable_opts->mutable_fetch());
1144   CopyAndSortStrings(
1145       req.num_targets(), [&req](size_t i) { return req.target_name(i); },
1146       callable_opts->mutable_target());
1147 
1148   if (!req.options().debug_options().debug_tensor_watch_opts().empty()) {
1149     *callable_opts->mutable_run_options()->mutable_debug_options() =
1150         req.options().debug_options();
1151   }
1152 
1153   opts->collective_graph_key =
1154       req.options().experimental().collective_graph_key();
1155   if (config.experimental().collective_deterministic_sequential_execution()) {
1156     opts->collective_order = GraphCollectiveOrder::kEdges;
1157   } else if (config.experimental().collective_nccl()) {
1158     opts->collective_order = GraphCollectiveOrder::kAttrs;
1159   }
1160 }
1161 
1162 void BuildBuildGraphOptions(const PartialRunSetupRequest& req,
1163                             BuildGraphOptions* opts) {
1164   CallableOptions* callable_opts = &opts->callable_options;
1165   CopyAndSortStrings(
1166       req.feed_size(), [&req](size_t i) { return req.feed(i); },
1167       callable_opts->mutable_feed());
1168   CopyAndSortStrings(
1169       req.fetch_size(), [&req](size_t i) { return req.fetch(i); },
1170       callable_opts->mutable_fetch());
1171   CopyAndSortStrings(
1172       req.target_size(), [&req](size_t i) { return req.target(i); },
1173       callable_opts->mutable_target());
1174 
1175   // TODO(cais): Add TFDBG support to partial runs.
1176 }
1177 
1178 uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) {
1179   uint64 h = 0x2b992ddfa23249d6ull;
1180   for (const string& name : opts.callable_options.feed()) {
1181     h = Hash64(name.c_str(), name.size(), h);
1182   }
1183   for (const string& name : opts.callable_options.target()) {
1184     h = Hash64(name.c_str(), name.size(), h);
1185   }
1186   for (const string& name : opts.callable_options.fetch()) {
1187     h = Hash64(name.c_str(), name.size(), h);
1188   }
1189 
1190   const DebugOptions& debug_options =
1191       opts.callable_options.run_options().debug_options();
1192   if (!debug_options.debug_tensor_watch_opts().empty()) {
1193     const string watch_summary =
1194         SummarizeDebugTensorWatches(debug_options.debug_tensor_watch_opts());
1195     h = Hash64(watch_summary.c_str(), watch_summary.size(), h);
1196   }
1197 
1198   return h;
1199 }
1200 
1201 string BuildGraphOptionsString(const BuildGraphOptions& opts) {
1202   string buf;
1203   for (const string& name : opts.callable_options.feed()) {
1204     strings::StrAppend(&buf, " FdE: ", name);
1205   }
1206   strings::StrAppend(&buf, "\n");
1207   for (const string& name : opts.callable_options.target()) {
1208     strings::StrAppend(&buf, " TN: ", name);
1209   }
1210   strings::StrAppend(&buf, "\n");
1211   for (const string& name : opts.callable_options.fetch()) {
1212     strings::StrAppend(&buf, " FeE: ", name);
1213   }
1214   if (opts.collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey) {
1215     strings::StrAppend(&buf, "\nGK: ", opts.collective_graph_key);
1216   }
1217   strings::StrAppend(&buf, "\n");
1218   return buf;
1219 }
1220 
1221 MasterSession::MasterSession(
1222     const SessionOptions& opt, const MasterEnv* env,
1223     std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
1224     std::unique_ptr<WorkerCacheInterface> worker_cache,
1225     std::unique_ptr<DeviceSet> device_set,
1226     std::vector<string> filtered_worker_list,
1227     StatsPublisherFactory stats_publisher_factory)
1228     : session_opts_(opt),
1229       env_(env),
1230       handle_(strings::FpToString(random::New64())),
1231       remote_devs_(std::move(remote_devs)),
1232       worker_cache_(std::move(worker_cache)),
1233       devices_(std::move(device_set)),
1234       filtered_worker_list_(std::move(filtered_worker_list)),
1235       stats_publisher_factory_(std::move(stats_publisher_factory)),
1236       graph_version_(0),
1237       run_graphs_(5),
1238       partial_run_graphs_(5) {
1239   UpdateLastAccessTime();
1240   CHECK(devices_) << "device_set was null!";
1241 
1242   VLOG(1) << "Session " << handle_ << " #local " << env->local_devices.size()
1243           << " #remote " << remote_devs_->size();
1244   VLOG(1) << "Start master session " << handle_
1245           << " with config: " << session_opts_.config.ShortDebugString();
1246 }
1247 
1248 MasterSession::~MasterSession() {
1249   for (const auto& iter : run_graphs_) iter.second->Unref();
1250   for (const auto& iter : partial_run_graphs_) iter.second->Unref();
1251 }
1252 
1253 void MasterSession::UpdateLastAccessTime() {
1254   last_access_time_usec_.store(Env::Default()->NowMicros());
1255 }
1256 
1257 Status MasterSession::Create(GraphDef&& graph_def,
1258                              const WorkerCacheFactoryOptions& options) {
1259   if (session_opts_.config.use_per_session_threads() ||
1260       session_opts_.config.session_inter_op_thread_pool_size() > 0) {
1261     return errors::InvalidArgument(
1262         "Distributed session does not support session thread pool options.");
1263   }
1264   if (session_opts_.config.graph_options().place_pruned_graph()) {
1265     // TODO(b/29900832): Fix this or remove the option.
1266     LOG(WARNING) << "Distributed session does not support the "
1267                     "place_pruned_graph option.";
1268     session_opts_.config.mutable_graph_options()->set_place_pruned_graph(false);
1269   }
1270 
1271   GraphExecutionStateOptions execution_options;
1272   execution_options.device_set = devices_.get();
1273   execution_options.session_options = &session_opts_;
1274   {
1275     mutex_lock l(mu_);
1276     TF_RETURN_IF_ERROR(GraphExecutionState::MakeForBaseGraph(
1277         std::move(graph_def), execution_options, &execution_state_));
1278   }
1279   should_delete_worker_sessions_ = true;
1280   return CreateWorkerSessions(options);
1281 }
1282 
1283 Status MasterSession::CreateWorkerSessions(
1284     const WorkerCacheFactoryOptions& options) {
1285   const std::vector<string> worker_names = filtered_worker_list_;
1286   WorkerCacheInterface* worker_cache = get_worker_cache();
1287 
1288   struct WorkerGroup {
1289     // The worker name. (Not owned.)
1290     const string* name;
1291 
1292     // The worker referenced by name. (Not owned.)
1293     WorkerInterface* worker = nullptr;
1294 
1295     // Request and response used for a given worker.
1296     CreateWorkerSessionRequest request;
1297     CreateWorkerSessionResponse response;
1298     Status status = Status::OK();
1299   };
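  // One CreateWorkerSession RPC is issued per worker below; `done` blocks
  // until every response (or error) has been received.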
1300   BlockingCounter done(worker_names.size());
1301   std::vector<WorkerGroup> workers(worker_names.size());
1302 
1303   // Release the workers.
1304   auto cleanup = gtl::MakeCleanup([&workers, worker_cache] {
1305     for (auto&& worker_group : workers) {
1306       if (worker_group.worker != nullptr) {
1307         worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker);
1308       }
1309     }
1310   });
1311 
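  // Extract the master's task name and client-device incarnation; both are
  // sent to every worker in its CreateWorkerSessionRequest below.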
1312   string task_name;
1313   string local_device_name;
1314   DeviceNameUtils::SplitDeviceName(devices_->client_device()->name(),
1315                                    &task_name, &local_device_name);
1316   const int64_t client_device_incarnation =
1317       devices_->client_device()->attributes().incarnation();
1318 
1319   Status status = Status::OK();
1320   // Create the worker handles and fill in their CreateWorkerSession requests.
1321   for (size_t i = 0; i < worker_names.size(); ++i) {
1322     workers[i].name = &worker_names[i];
1323     workers[i].worker = worker_cache->GetOrCreateWorker(worker_names[i]);
1324     workers[i].request.set_session_handle(handle_);
1325     workers[i].request.set_master_task(task_name);
1326     workers[i].request.set_master_incarnation(client_device_incarnation);
1327     if (session_opts_.config.share_cluster_devices_in_session() ||
1328         session_opts_.config.experimental()
1329             .share_cluster_devices_in_session()) {
1330       for (const auto& remote_dev : devices_->devices()) {
1331         *workers[i].request.add_cluster_device_attributes() =
1332             remote_dev->attributes();
1333       }
1334 
1335       if (!session_opts_.config.share_cluster_devices_in_session() &&
1336           session_opts_.config.experimental()
1337               .share_cluster_devices_in_session()) {
1338         LOG(WARNING)
1339             << "ConfigProto.Experimental.share_cluster_devices_in_session has "
1340                "been promoted to a non-experimental API. Please use "
1341                "ConfigProto.share_cluster_devices_in_session instead. The "
1342                "experimental option will be removed in the future.";
1343       }
1344     }
1345 
1346     DeviceNameUtils::ParsedName name;
1347     if (!DeviceNameUtils::ParseFullName(worker_names[i], &name)) {
1348       status = errors::Internal("Could not parse name ", worker_names[i]);
1349       LOG(WARNING) << status;
1350       return status;
1351     }
1352     if (!name.has_job || !name.has_task) {
1353       status = errors::Internal("Incomplete worker name ", worker_names[i]);
1354       LOG(WARNING) << status;
1355       return status;
1356     }
1357 
1358     if (options.cluster_def) {
1359       *workers[i].request.mutable_server_def()->mutable_cluster() =
1360           *options.cluster_def;
1361       workers[i].request.mutable_server_def()->set_protocol(*options.protocol);
1362       workers[i].request.mutable_server_def()->set_job_name(name.job);
1363       workers[i].request.mutable_server_def()->set_task_index(name.task);
1364       // Session state is always isolated when ClusterSpec propagation
1365       // is in use.
1366       workers[i].request.set_isolate_session_state(true);
1367     } else {
1368       // NOTE(mrry): Do not set any component of the ServerDef,
1369       // because the worker will use its local configuration.
1370       workers[i].request.set_isolate_session_state(
1371           session_opts_.config.isolate_session_state());
1372     }
1373     if (session_opts_.config.experimental()
1374             .share_session_state_in_clusterspec_propagation()) {
1375       // In a dynamic cluster, the ClusterSpec info is usually propagated by
1376       // master sessions. However, in data-parallel training with multiple
1377       // masters ("between-graph replication"), isolation must be disabled so
1378       // that different worker sessions can update the same variables in PS
1379       // tasks.
1380       workers[i].request.set_isolate_session_state(false);
1381     }
1382   }
1383 
1384   for (size_t i = 0; i < worker_names.size(); ++i) {
1385     auto cb = [i, &workers, &done](const Status& s) {
1386       workers[i].status = s;
1387       done.DecrementCount();
1388     };
1389     workers[i].worker->CreateWorkerSessionAsync(&workers[i].request,
1390                                                 &workers[i].response, cb);
1391   }
1392 
1393   done.Wait();
1394   for (size_t i = 0; i < workers.size(); ++i) {
1395     status.Update(workers[i].status);
1396   }
1397   return status;
1398 }
1399 
1400 Status MasterSession::DeleteWorkerSessions() {
1401   WorkerCacheInterface* worker_cache = get_worker_cache();
1402   const std::vector<string>& worker_names = filtered_worker_list_;
1403 
1404   struct WorkerGroup {
1405     // The worker name. (Not owned.)
1406     const string* name;
1407 
1408     // The worker referenced by name. (Not owned.)
1409     WorkerInterface* worker = nullptr;
1410 
1411     CallOptions call_opts;
1412 
1413     // Request and response used for a given worker.
1414     DeleteWorkerSessionRequest request;
1415     DeleteWorkerSessionResponse response;
1416     Status status = Status::OK();
1417   };
1418   BlockingCounter done(worker_names.size());
1419   std::vector<WorkerGroup> workers(worker_names.size());
1420 
1421   // Release the workers.
1422   auto cleanup = gtl::MakeCleanup([&workers, worker_cache] {
1423     for (auto&& worker_group : workers) {
1424       if (worker_group.worker != nullptr) {
1425         worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker);
1426       }
1427     }
1428   });
1429 
1430   Status status = Status::OK();
1431   // Look up the worker handles and fill in their DeleteWorkerSession requests.
1432   for (size_t i = 0; i < worker_names.size(); ++i) {
1433     workers[i].name = &worker_names[i];
1434     workers[i].worker = worker_cache->GetOrCreateWorker(worker_names[i]);
1435     workers[i].request.set_session_handle(handle_);
1436     // Since the worker may have gone away, set a timeout to avoid blocking the
1437     // session-close operation.
1438     workers[i].call_opts.SetTimeout(10000);
1439   }
1440 
1441   for (size_t i = 0; i < worker_names.size(); ++i) {
1442     auto cb = [i, &workers, &done](const Status& s) {
1443       workers[i].status = s;
1444       done.DecrementCount();
1445     };
1446     workers[i].worker->DeleteWorkerSessionAsync(
1447         &workers[i].call_opts, &workers[i].request, &workers[i].response, cb);
1448   }
1449 
1450   done.Wait();
1451   for (size_t i = 0; i < workers.size(); ++i) {
1452     status.Update(workers[i].status);
1453   }
1454   return status;
1455 }
1456 
1457 Status MasterSession::ListDevices(ListDevicesResponse* resp) const {
1458   if (worker_cache_) {
1459     // This is a ClusterSpec-propagated session, and thus env_->local_devices
1460     // are invalid.
1461 
1462     // Mark the "client_device" as the sole local device.
1463     const Device* client_device = devices_->client_device();
1464     for (const Device* dev : devices_->devices()) {
1465       if (dev != client_device) {
1466         *(resp->add_remote_device()) = dev->attributes();
1467       }
1468     }
1469     *(resp->add_local_device()) = client_device->attributes();
1470   } else {
1471     for (Device* dev : env_->local_devices) {
1472       *(resp->add_local_device()) = dev->attributes();
1473     }
1474     for (auto&& dev : *remote_devs_) {
1475       *(resp->add_local_device()) = dev->attributes();
1476     }
1477   }
1478   return Status::OK();
1479 }
1480 
1481 Status MasterSession::Extend(const ExtendSessionRequest* req,
1482                              ExtendSessionResponse* resp) {
1483   UpdateLastAccessTime();
1484   std::unique_ptr<GraphExecutionState> extended_execution_state;
1485   {
1486     mutex_lock l(mu_);
1487     if (closed_) {
1488       return errors::FailedPrecondition("Session is closed.");
1489     }
1490 
1491     if (graph_version_ != req->current_graph_version()) {
1492       return errors::Aborted("Current version is ", graph_version_,
1493                              " but caller expected ",
1494                              req->current_graph_version(), ".");
1495     }
1496 
1497     CHECK(execution_state_);
1498     TF_RETURN_IF_ERROR(
1499         execution_state_->Extend(req->graph_def(), &extended_execution_state));
1500 
1501     CHECK(extended_execution_state);
1502     // The old execution state will be released outside the lock.
1503     execution_state_.swap(extended_execution_state);
1504     ++graph_version_;
1505     resp->set_new_graph_version(graph_version_);
1506   }
1507   return Status::OK();
1508 }
1509 
1510 WorkerCacheInterface* MasterSession::get_worker_cache() const {
1511   if (worker_cache_) {
1512     return worker_cache_.get();
1513   }
1514   return env_->worker_cache;
1515 }
1516 
1517 Status MasterSession::StartStep(const BuildGraphOptions& opts, bool is_partial,
1518                                 ReffedClientGraph** out_rcg, int64* out_count) {
1519   const uint64 hash = HashBuildGraphOptions(opts);
1520   {
1521     mutex_lock l(mu_);
1522     // TODO(suharshs): We cache partial run graphs and run graphs separately
1523     // because there is preprocessing that needs to be run only for partial
1524     // run calls.
1525     RCGMap* m = is_partial ? &partial_run_graphs_ : &run_graphs_;
1526     auto iter = m->find(hash);
1527     if (iter == m->end()) {
1528       // We have not seen this subgraph before. Build the subgraph and
1529       // cache it.
1530       VLOG(1) << "Unseen hash " << hash << " for "
1531               << BuildGraphOptionsString(opts) << " is_partial = " << is_partial
1532               << "\n";
1533       std::unique_ptr<ClientGraph> client_graph;
1534       TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
1535       WorkerCacheInterface* worker_cache = get_worker_cache();
1536       auto entry = new ReffedClientGraph(
1537           handle_, opts, std::move(client_graph), session_opts_,
1538           stats_publisher_factory_, is_partial, worker_cache,
1539           !should_delete_worker_sessions_);
1540       iter = m->insert({hash, entry}).first;
1541       VLOG(1) << "Preparing to execute new graph";
1542     }
1543     *out_rcg = iter->second;
1544     (*out_rcg)->Ref();
1545     *out_count = (*out_rcg)->get_and_increment_execution_count();
1546   }
1547   return Status::OK();
1548 }
1549 
1550 void MasterSession::ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
1551                                    RCGMap* rcg_map) {
1552   VLOG(1) << "Discarding all reffed graphs";
1553   for (auto p : *rcg_map) {
1554     ReffedClientGraph* rcg = p.second;
1555     if (to_unref) {
1556       to_unref->push_back(rcg);
1557     } else {
1558       rcg->Unref();
1559     }
1560   }
1561   rcg_map->clear();
1562 }
1563 
1564 uint64 MasterSession::NewStepId(int64_t graph_key) {
1565   if (graph_key == BuildGraphOptions::kNoCollectiveGraphKey) {
1566     // StepId must leave the most-significant 7 bits empty for future use.
1567     return random::New64() & (((1uLL << 56) - 1) | (1uLL << 56));
1568   } else {
1569     uint64 step_id = env_->collective_executor_mgr->NextStepId(graph_key);
1570     int32_t retry_count = 0;
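    // If NextStepId() yields kInvalidId, refresh the step-id sequence for this
    // graph_key and try again, backing off linearly (capped at 60 seconds)
    // whenever the refresh itself fails.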
1571     while (static_cast<int64>(step_id) == CollectiveExecutor::kInvalidId) {
1572       Notification note;
1573       Status status;
1574       env_->collective_executor_mgr->RefreshStepIdSequenceAsync(
1575           graph_key, [&status, &note](const Status& s) {
1576             status = s;
1577             note.Notify();
1578           });
1579       note.WaitForNotification();
1580       if (!status.ok()) {
1581         LOG(ERROR) << "Bad status from "
1582                       "collective_executor_mgr->RefreshStepIdSequence: "
1583                    << status << ".  Retrying.";
1584         int64_t delay_micros = std::min(60000000LL, 1000000LL * ++retry_count);
1585         Env::Default()->SleepForMicroseconds(delay_micros);
1586       } else {
1587         step_id = env_->collective_executor_mgr->NextStepId(graph_key);
1588       }
1589     }
1590     return step_id;
1591   }
1592 }
1593 
1594 Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req,
1595                                       PartialRunSetupResponse* resp) {
1596   std::vector<string> inputs, outputs, targets;
1597   for (const auto& feed : req->feed()) {
1598     inputs.push_back(feed);
1599   }
1600   for (const auto& fetch : req->fetch()) {
1601     outputs.push_back(fetch);
1602   }
1603   for (const auto& target : req->target()) {
1604     targets.push_back(target);
1605   }
1606 
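  // Each partial run gets a unique string handle drawn from a monotonically
  // increasing counter; the client passes it back on subsequent RunStep calls.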
1607   string handle = std::to_string(partial_run_handle_counter_.fetch_add(1));
1608 
1609   ReffedClientGraph* rcg = nullptr;
1610 
1611   // Prepare.
1612   BuildGraphOptions opts;
1613   BuildBuildGraphOptions(*req, &opts);
1614   int64_t count = 0;
1615   TF_RETURN_IF_ERROR(StartStep(opts, true, &rcg, &count));
1616 
1617   rcg->Ref();
1618   RunState* run_state =
1619       new RunState(inputs, outputs, rcg,
1620                    NewStepId(BuildGraphOptions::kNoCollectiveGraphKey), count);
1621   {
1622     mutex_lock l(mu_);
1623     partial_runs_.emplace(
1624         std::make_pair(handle, std::unique_ptr<RunState>(run_state)));
1625   }
1626 
1627   TF_RETURN_IF_ERROR(BuildAndRegisterPartitions(rcg));
1628 
1629   resp->set_partial_run_handle(handle);
1630   return Status::OK();
1631 }
1632 
1633 Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req,
1634                           MutableRunStepResponseWrapper* resp) {
1635   UpdateLastAccessTime();
1636   {
1637     mutex_lock l(mu_);
1638     if (closed_) {
1639       return errors::FailedPrecondition("Session is closed.");
1640     }
1641     ++num_running_;
1642     // Note: all code paths must eventually call MarkRunCompletion()
1643     // in order to appropriately decrement the num_running_ counter.
1644   }
1645   Status status;
1646   if (!req.partial_run_handle().empty()) {
1647     status = DoPartialRun(opts, req, resp);
1648   } else {
1649     status = DoRunWithLocalExecution(opts, req, resp);
1650   }
1651   return status;
1652 }
1653 
1654 // Decrements num_running_ and broadcasts if num_running_ is zero.
1655 void MasterSession::MarkRunCompletion() {
1656   mutex_lock l(mu_);
1657   --num_running_;
1658   if (num_running_ == 0) {
1659     num_running_is_zero_.notify_all();
1660   }
1661 }
1662 
1663 Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
1664   // Registers subgraphs if we haven't done so already.
1665   PartitionOptions popts;
1666   popts.node_to_loc = SplitByWorker;
1667   // The closures popts.{new_name,get_incarnation} are called synchronously in
1668   // RegisterPartitions() below, so they do not need a Ref()/Unref() pair
1669   // to keep "this" alive during the closure.
1670   popts.new_name = [this](const string& prefix) {
1671     mutex_lock l(mu_);
1672     return strings::StrCat(prefix, "_S", next_node_id_++);
1673   };
1674   popts.get_incarnation = [this](const string& name) -> int64 {
1675     Device* d = devices_->FindDeviceByName(name);
1676     if (d == nullptr) {
1677       return PartitionOptions::kIllegalIncarnation;
1678     } else {
1679       return d->attributes().incarnation();
1680     }
1681   };
1682   popts.control_flow_added = false;
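  // Optionally cast DT_FLOAT tensors to DT_BFLOAT16 on cross-partition
  // Send/Recv edges, per the session's GraphOptions, to reduce the amount of
  // data sent between workers.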
1683   const bool enable_bfloat16_sendrecv =
1684       session_opts_.config.graph_options().enable_bfloat16_sendrecv();
1685   popts.should_cast = [enable_bfloat16_sendrecv](const Edge* e) {
1686     if (e->IsControlEdge()) {
1687       return DT_FLOAT;
1688     }
1689     DataType dtype = BaseType(e->src()->output_type(e->src_output()));
1690     if (enable_bfloat16_sendrecv && dtype == DT_FLOAT) {
1691       return DT_BFLOAT16;
1692     } else {
1693       return dtype;
1694     }
1695   };
1696   if (session_opts_.config.graph_options().enable_recv_scheduling()) {
1697     popts.scheduling_for_recvs = true;
1698     popts.need_to_record_start_times = true;
1699   }
1700 
1701   TF_RETURN_IF_ERROR(rcg->RegisterPartitions(std::move(popts)));
1702 
1703   return Status::OK();
1704 }
1705 
1706 Status MasterSession::DoPartialRun(CallOptions* opts,
1707                                    const RunStepRequestWrapper& req,
1708                                    MutableRunStepResponseWrapper* resp) {
1709   auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
1710   const string& prun_handle = req.partial_run_handle();
1711   RunState* run_state = nullptr;
1712   {
1713     mutex_lock l(mu_);
1714     auto it = partial_runs_.find(prun_handle);
1715     if (it == partial_runs_.end()) {
1716       return errors::InvalidArgument(
1717           "Must run PartialRunSetup before performing partial runs");
1718     }
1719     run_state = it->second.get();
1720   }
1721   // CollectiveOps are not supported in partial runs.
1722   if (req.options().experimental().collective_graph_key() !=
1723       BuildGraphOptions::kNoCollectiveGraphKey) {
1724     return errors::InvalidArgument(
1725         "PartialRun does not support Collective ops.  collective_graph_key "
1726         "must be kNoCollectiveGraphKey.");
1727   }
1728 
1729   // If this is the first partial run, initialize the PerStepState.
1730   if (!run_state->step_started) {
1731     run_state->step_started = true;
1732     PerStepState pss;
1733 
1734     const auto count = run_state->count;
1735     pss.collect_timeline =
1736         req.options().trace_level() == RunOptions::FULL_TRACE;
1737     pss.collect_rpcs = req.options().trace_level() == RunOptions::FULL_TRACE;
1738     pss.report_tensor_allocations_upon_oom =
1739         req.options().report_tensor_allocations_upon_oom();
1740 
1741     // Build the cost model every 'build_cost_model_every' steps, after
1742     // skipping an initial 'build_cost_model_after' steps; both values come
1743     // from the session's GraphOptions.
1744     const int64_t build_cost_model_after =
1745         session_opts_.config.graph_options().build_cost_model_after();
1746     const int64_t build_cost_model_every =
1747         session_opts_.config.graph_options().build_cost_model();
1748     pss.collect_costs =
1749         build_cost_model_every > 0 &&
1750         ((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
1751     pss.collect_partition_graphs = req.options().output_partition_graphs();
1752 
1753     std::unique_ptr<ProfileHandler> ph = run_state->rcg->GetProfileHandler(
1754         run_state->step_id, count, req.options());
1755     if (ph) {
1756       pss.collect_timeline = true;
1757       pss.collect_rpcs = ph->should_collect_rpcs();
1758     }
1759 
1760     run_state->pss = std::move(pss);
1761     run_state->ph = std::move(ph);
1762   }
1763 
1764   // Make sure that this is a new set of feeds that are still pending.
1765   for (size_t i = 0; i < req.num_feeds(); ++i) {
1766     const string& feed = req.feed_name(i);
1767     auto it = run_state->pending_inputs.find(feed);
1768     if (it == run_state->pending_inputs.end()) {
1769       return errors::InvalidArgument(
1770           "The feed ", feed, " was not specified in partial_run_setup.");
1771     } else if (it->second) {
1772       return errors::InvalidArgument("The feed ", feed,
1773                                      " has already been fed.");
1774     }
1775   }
1776   // Check that this is a new set of fetches that are still pending.
1777   for (size_t i = 0; i < req.num_fetches(); ++i) {
1778     const string& fetch = req.fetch_name(i);
1779     auto it = run_state->pending_outputs.find(fetch);
1780     if (it == run_state->pending_outputs.end()) {
1781       return errors::InvalidArgument(
1782           "The fetch ", fetch, " was not specified in partial_run_setup.");
1783     } else if (it->second) {
1784       return errors::InvalidArgument("The fetch ", fetch,
1785                                      " has already been fetched.");
1786     }
1787   }
1788 
1789   // Ensure that the requested fetches can be computed from the provided feeds.
1790   {
1791     mutex_lock l(mu_);
1792     TF_RETURN_IF_ERROR(
1793         run_state->rcg->CheckFetches(req, run_state, execution_state_.get()));
1794   }
1795 
1796   // Determine if this partial run satisfies all the pending inputs and outputs.
1797   for (size_t i = 0; i < req.num_feeds(); ++i) {
1798     auto it = run_state->pending_inputs.find(req.feed_name(i));
1799     it->second = true;
1800   }
1801   for (size_t i = 0; i < req.num_fetches(); ++i) {
1802     auto it = run_state->pending_outputs.find(req.fetch_name(i));
1803     it->second = true;
1804   }
1805   bool is_last_partial_run = run_state->PendingDone();
1806 
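  // Run the registered partitions for this step. is_last_partial_run tells the
  // workers whether every pending feed and fetch has now been supplied.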
1807   Status s = run_state->rcg->RunPartitions(
1808       env_, run_state->step_id, run_state->count, &run_state->pss, opts, req,
1809       resp, &cancellation_manager_, is_last_partial_run);
1810 
1811   // Delete the run state if there is an error or all fetches are done.
1812   if (!s.ok() || is_last_partial_run) {
1813     ReffedClientGraph* rcg = run_state->rcg;
1814     run_state->pss.end_micros = Env::Default()->NowMicros();
1815     // Schedule post-processing and cleanup to be done asynchronously.
1816     Ref();
1817     rcg->Ref();
1818     rcg->ProcessStats(run_state->step_id, &run_state->pss, run_state->ph.get(),
1819                       req.options(), resp->mutable_metadata());
1820     cleanup.release();  // MarkRunCompletion called in done closure.
1821     rcg->CleanupPartitionsAsync(
1822         run_state->step_id, [this, rcg, prun_handle](const Status& s) {
1823           if (!s.ok()) {
1824             LOG(ERROR) << "Cleanup partition error: " << s;
1825           }
1826           rcg->Unref();
1827           MarkRunCompletion();
1828           Unref();
1829         });
1830     mutex_lock l(mu_);
1831     partial_runs_.erase(prun_handle);
1832   }
1833   return s;
1834 }
1835 
1836 Status MasterSession::CreateDebuggerState(
1837     const DebugOptions& debug_options, const RunStepRequestWrapper& req,
1838     int64_t rcg_execution_count,
1839     std::unique_ptr<DebuggerStateInterface>* debugger_state) {
1840   TF_RETURN_IF_ERROR(
1841       DebuggerStateRegistry::CreateState(debug_options, debugger_state));
1842 
1843   std::vector<string> input_names;
1844   for (size_t i = 0; i < req.num_feeds(); ++i) {
1845     input_names.push_back(req.feed_name(i));
1846   }
1847   std::vector<string> output_names;
1848   for (size_t i = 0; i < req.num_fetches(); ++i) {
1849     output_names.push_back(req.fetch_name(i));
1850   }
1851   std::vector<string> target_names;
1852   for (size_t i = 0; i < req.num_targets(); ++i) {
1853     target_names.push_back(req.target_name(i));
1854   }
1855 
1856   // TODO(cais): We currently use -1 as a dummy value for session run count.
1857   // While this counter value is straightforward to define and obtain for
1858   // DirectSessions, it is less so for non-direct Sessions. Devise a better
1859   // way to get its value when the need arises.
1860   TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
1861       debug_options.global_step(), rcg_execution_count, rcg_execution_count,
1862       input_names, output_names, target_names));
1863 
1864   return Status::OK();
1865 }
1866 
1867 void MasterSession::FillPerStepState(MasterSession::ReffedClientGraph* rcg,
1868                                      const RunOptions& run_options,
1869                                      uint64 step_id, int64_t count,
1870                                      PerStepState* out_pss,
1871                                      std::unique_ptr<ProfileHandler>* out_ph) {
1872   out_pss->collect_timeline =
1873       run_options.trace_level() == RunOptions::FULL_TRACE;
1874   out_pss->collect_rpcs = run_options.trace_level() == RunOptions::FULL_TRACE;
1875   out_pss->report_tensor_allocations_upon_oom =
1876       run_options.report_tensor_allocations_upon_oom();
1877   // Build the cost model every 'build_cost_model_every' steps, after
1878   // skipping an initial 'build_cost_model_after' steps.
1879   const int64_t build_cost_model_after =
1880       session_opts_.config.graph_options().build_cost_model_after();
1881   const int64_t build_cost_model_every =
1882       session_opts_.config.graph_options().build_cost_model();
1883   out_pss->collect_costs =
1884       build_cost_model_every > 0 &&
1885       ((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
1886   out_pss->collect_partition_graphs = run_options.output_partition_graphs();
1887 
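  // If a profile handler is active for this step, always collect the timeline
  // and let the handler decide whether RPC traces are needed as well.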
1888   *out_ph = rcg->GetProfileHandler(step_id, count, run_options);
1889   if (*out_ph) {
1890     out_pss->collect_timeline = true;
1891     out_pss->collect_rpcs = (*out_ph)->should_collect_rpcs();
1892   }
1893 }
1894 
1895 Status MasterSession::PostRunCleanup(MasterSession::ReffedClientGraph* rcg,
1896                                      uint64 step_id,
1897                                      const RunOptions& run_options,
1898                                      PerStepState* pss,
1899                                      const std::unique_ptr<ProfileHandler>& ph,
1900                                      const Status& run_status,
1901                                      RunMetadata* out_run_metadata) {
1902   Status s = run_status;
1903   if (s.ok()) {
1904     pss->end_micros = Env::Default()->NowMicros();
1905     if (rcg->collective_graph_key() !=
1906         BuildGraphOptions::kNoCollectiveGraphKey) {
1907       env_->collective_executor_mgr->RetireStepId(rcg->collective_graph_key(),
1908                                                   step_id);
1909     }
1910     // Schedule post-processing and cleanup to be done asynchronously.
1911     rcg->ProcessStats(step_id, pss, ph.get(), run_options, out_run_metadata);
1912   } else if (errors::IsCancelled(s)) {
1913     mutex_lock l(mu_);
1914     if (closed_) {
1915       if (garbage_collected_) {
1916         s = errors::Cancelled(
1917             "Step was cancelled because the session was garbage collected due "
1918             "to inactivity.");
1919       } else {
1920         s = errors::Cancelled(
1921             "Step was cancelled by an explicit call to `Session::Close()`.");
1922       }
1923     }
1924   }
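  // Hold references on both the session and the graph while the asynchronous
  // partition cleanup is in flight; the callback below releases them.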
1925   Ref();
1926   rcg->Ref();
1927   rcg->CleanupPartitionsAsync(step_id, [this, rcg](const Status& s) {
1928     if (!s.ok()) {
1929       LOG(ERROR) << "Cleanup partition error: " << s;
1930     }
1931     rcg->Unref();
1932     MarkRunCompletion();
1933     Unref();
1934   });
1935   return s;
1936 }
1937 
1938 Status MasterSession::DoRunWithLocalExecution(
1939     CallOptions* opts, const RunStepRequestWrapper& req,
1940     MutableRunStepResponseWrapper* resp) {
1941   VLOG(2) << "DoRunWithLocalExecution req: " << req.DebugString();
1942   PerStepState pss;
1943   pss.start_micros = Env::Default()->NowMicros();
1944   auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
1945 
1946   // Prepare.
1947   BuildGraphOptions bgopts;
1948   BuildBuildGraphOptions(req, session_opts_.config, &bgopts);
1949   ReffedClientGraph* rcg = nullptr;
1950   int64_t count;
1951   TF_RETURN_IF_ERROR(StartStep(bgopts, false, &rcg, &count));
1952 
1953   // Unref "rcg" when out of scope.
1954   core::ScopedUnref unref(rcg);
1955 
1956   std::unique_ptr<DebuggerStateInterface> debugger_state;
1957   const DebugOptions& debug_options = req.options().debug_options();
1958 
1959   if (!debug_options.debug_tensor_watch_opts().empty()) {
1960     TF_RETURN_IF_ERROR(
1961         CreateDebuggerState(debug_options, req, count, &debugger_state));
1962   }
1963   TF_RETURN_IF_ERROR(BuildAndRegisterPartitions(rcg));
1964 
1965   // Some of the most-significant bits of the step_id are reserved for
1966   // future use (see NewStepId()).
1967   uint64 step_id = NewStepId(rcg->collective_graph_key());
1968   TRACEPRINTF("stepid %llu", step_id);
1969 
1970   std::unique_ptr<ProfileHandler> ph;
1971   FillPerStepState(rcg, req.options(), step_id, count, &pss, &ph);
1972 
1973   if (pss.collect_partition_graphs &&
1974       session_opts_.config.experimental().disable_output_partition_graphs()) {
1975     return errors::InvalidArgument(
1976         "RunOptions.output_partition_graphs() is not supported when "
1977         "disable_output_partition_graphs is true.");
1978   }
1979 
1980   Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp,
1981                                 &cancellation_manager_, false);
1982 
1983   cleanup.release();  // MarkRunCompletion called in PostRunCleanup().
1984   return PostRunCleanup(rcg, step_id, req.options(), &pss, ph, s,
1985                         resp->mutable_metadata());
1986 }
1987 
1988 Status MasterSession::MakeCallable(const MakeCallableRequest& req,
1989                                    MakeCallableResponse* resp) {
1990   UpdateLastAccessTime();
1991 
1992   BuildGraphOptions opts;
1993   opts.callable_options = req.options();
1994   opts.use_function_convention = false;
1995 
1996   ReffedClientGraph* callable;
1997 
1998   {
1999     mutex_lock l(mu_);
2000     if (closed_) {
2001       return errors::FailedPrecondition("Session is closed.");
2002     }
2003     std::unique_ptr<ClientGraph> client_graph;
2004     TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
2005     callable = new ReffedClientGraph(handle_, opts, std::move(client_graph),
2006                                      session_opts_, stats_publisher_factory_,
2007                                      false /* is_partial */, get_worker_cache(),
2008                                      !should_delete_worker_sessions_);
2009   }
2010 
2011   Status s = BuildAndRegisterPartitions(callable);
2012   if (!s.ok()) {
2013     callable->Unref();
2014     return s;
2015   }
2016 
2017   uint64 handle;
2018   {
2019     mutex_lock l(mu_);
2020     handle = next_callable_handle_++;
2021     callables_[handle] = callable;
2022   }
2023 
2024   resp->set_handle(handle);
2025   return Status::OK();
2026 }
2027 
2028 Status MasterSession::DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg,
2029                                     const RunCallableRequest& req,
2030                                     RunCallableResponse* resp) {
2031   VLOG(2) << "DoRunCallable req: " << req.DebugString();
2032   PerStepState pss;
2033   pss.start_micros = Env::Default()->NowMicros();
2034   auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
2035 
2036   // Prepare.
2037   int64_t count = rcg->get_and_increment_execution_count();
2038 
2039   const uint64 step_id = NewStepId(rcg->collective_graph_key());
2040   TRACEPRINTF("stepid %llu", step_id);
2041 
2042   const RunOptions& run_options = rcg->callable_options().run_options();
2043 
2044   if (run_options.timeout_in_ms() != 0) {
2045     opts->SetTimeout(run_options.timeout_in_ms());
2046   }
2047 
2048   std::unique_ptr<ProfileHandler> ph;
2049   FillPerStepState(rcg, run_options, step_id, count, &pss, &ph);
2050   Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp,
2051                                 &cancellation_manager_);
2052   cleanup.release();  // MarkRunCompletion called in PostRunCleanup().
2053   return PostRunCleanup(rcg, step_id, run_options, &pss, ph, s,
2054                         resp->mutable_metadata());
2055 }
2056 
2057 Status MasterSession::RunCallable(CallOptions* opts,
2058                                   const RunCallableRequest& req,
2059                                   RunCallableResponse* resp) {
2060   UpdateLastAccessTime();
2061   ReffedClientGraph* callable;
2062   {
2063     mutex_lock l(mu_);
2064     if (closed_) {
2065       return errors::FailedPrecondition("Session is closed.");
2066     }
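    // A handle is valid only if it was issued by MakeCallable() on this
    // session and has not yet been released via ReleaseCallable().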
2067     int64_t handle = req.handle();
2068     if (handle >= next_callable_handle_) {
2069       return errors::InvalidArgument("No such callable handle: ", handle);
2070     }
2071     auto iter = callables_.find(req.handle());
2072     if (iter == callables_.end()) {
2073       return errors::InvalidArgument(
2074           "Attempted to run callable after handle was released: ", handle);
2075     }
2076     callable = iter->second;
2077     callable->Ref();
2078     ++num_running_;
2079   }
2080   core::ScopedUnref unref_callable(callable);
2081   return DoRunCallable(opts, callable, req, resp);
2082 }
2083 
2084 Status MasterSession::ReleaseCallable(const ReleaseCallableRequest& req,
2085                                       ReleaseCallableResponse* resp) {
2086   UpdateLastAccessTime();
2087   ReffedClientGraph* to_unref = nullptr;
2088   {
2089     mutex_lock l(mu_);
2090     auto iter = callables_.find(req.handle());
2091     if (iter != callables_.end()) {
2092       to_unref = iter->second;
2093       callables_.erase(iter);
2094     }
2095   }
2096   if (to_unref != nullptr) {
2097     to_unref->Unref();
2098   }
2099   return Status::OK();
2100 }
2101 
2102 Status MasterSession::Close() {
2103   {
2104     mutex_lock l(mu_);
2105     closed_ = true;  // All subsequent calls to Run() or Extend() will fail.
2106   }
2107   cancellation_manager_.StartCancel();
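  // Wait for all in-flight steps to finish before discarding the cached
  // graphs and, if this session created the worker sessions, deleting them.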
2108   std::vector<ReffedClientGraph*> to_unref;
2109   {
2110     mutex_lock l(mu_);
2111     while (num_running_ != 0) {
2112       num_running_is_zero_.wait(l);
2113     }
2114     ClearRunsTable(&to_unref, &run_graphs_);
2115     ClearRunsTable(&to_unref, &partial_run_graphs_);
2116     ClearRunsTable(&to_unref, &callables_);
2117   }
2118   for (ReffedClientGraph* rcg : to_unref) rcg->Unref();
2119   if (should_delete_worker_sessions_) {
2120     Status s = DeleteWorkerSessions();
2121     if (!s.ok()) {
2122       LOG(WARNING) << s;
2123     }
2124   }
2125   return Status::OK();
2126 }
2127 
2128 void MasterSession::GarbageCollect() {
2129   {
2130     mutex_lock l(mu_);
2131     closed_ = true;
2132     garbage_collected_ = true;
2133   }
2134   cancellation_manager_.StartCancel();
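  // Drop a reference to this session (presumably the one held by the owning
  // Master); destruction occurs once all other holders release theirs.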
2135   Unref();
2136 }
2137 
2138 MasterSession::RunState::RunState(const std::vector<string>& input_names,
2139                                   const std::vector<string>& output_names,
2140                                   ReffedClientGraph* rcg, const uint64 step_id,
2141                                   const int64_t count)
2142     : rcg(rcg), step_id(step_id), count(count) {
2143   // Initially all the feeds and fetches are pending.
2144   for (auto& name : input_names) {
2145     pending_inputs[name] = false;
2146   }
2147   for (auto& name : output_names) {
2148     pending_outputs[name] = false;
2149   }
2150 }
2151 
2152 MasterSession::RunState::~RunState() {
2153   if (rcg) rcg->Unref();
2154 }
2155 
2156 bool MasterSession::RunState::PendingDone() const {
2157   for (const auto& it : pending_inputs) {
2158     if (!it.second) return false;
2159   }
2160   for (const auto& it : pending_outputs) {
2161     if (!it.second) return false;
2162   }
2163   return true;
2164 }
2165 
2166 }  // end namespace tensorflow
2167