1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/master_session.h"
17 
18 #include <memory>
19 #include <unordered_map>
20 #include <unordered_set>
21 #include <vector>
22 
23 #include "tensorflow/core/common_runtime/process_util.h"
24 #include "tensorflow/core/common_runtime/profile_handler.h"
25 #include "tensorflow/core/common_runtime/stats_publisher_interface.h"
26 #include "tensorflow/core/debug/debug_graph_utils.h"
27 #include "tensorflow/core/distributed_runtime/scheduler.h"
28 #include "tensorflow/core/distributed_runtime/worker_cache.h"
29 #include "tensorflow/core/distributed_runtime/worker_interface.h"
30 #include "tensorflow/core/framework/allocation_description.pb.h"
31 #include "tensorflow/core/framework/collective.h"
32 #include "tensorflow/core/framework/cost_graph.pb.h"
33 #include "tensorflow/core/framework/node_def.pb.h"
34 #include "tensorflow/core/framework/node_def_util.h"
35 #include "tensorflow/core/framework/tensor.h"
36 #include "tensorflow/core/framework/tensor_description.pb.h"
37 #include "tensorflow/core/graph/graph_partition.h"
38 #include "tensorflow/core/graph/tensor_id.h"
39 #include "tensorflow/core/lib/core/blocking_counter.h"
40 #include "tensorflow/core/lib/core/notification.h"
41 #include "tensorflow/core/lib/core/refcount.h"
42 #include "tensorflow/core/lib/core/status.h"
43 #include "tensorflow/core/lib/gtl/cleanup.h"
44 #include "tensorflow/core/lib/gtl/inlined_vector.h"
45 #include "tensorflow/core/lib/gtl/map_util.h"
46 #include "tensorflow/core/lib/random/random.h"
47 #include "tensorflow/core/lib/strings/numbers.h"
48 #include "tensorflow/core/lib/strings/str_util.h"
49 #include "tensorflow/core/lib/strings/strcat.h"
50 #include "tensorflow/core/lib/strings/stringprintf.h"
51 #include "tensorflow/core/platform/env.h"
52 #include "tensorflow/core/platform/logging.h"
53 #include "tensorflow/core/platform/macros.h"
54 #include "tensorflow/core/platform/mutex.h"
55 #include "tensorflow/core/platform/tracing.h"
56 #include "tensorflow/core/public/session_options.h"
57 
58 namespace tensorflow {
59 
60 // MasterSession wraps ClientGraph in a reference counted object.
61 // This way, MasterSession can clear up the cache mapping Run requests to
62 // compiled graphs while the compiled graph is still being used.
63 //
64 // TODO(zhifengc): Cleanup this class. It's becoming messy.
65 class MasterSession::ReffedClientGraph : public core::RefCounted {
66  public:
67   ReffedClientGraph(const string& handle, const BuildGraphOptions& bopts,
68                     std::unique_ptr<ClientGraph> client_graph,
69                     const SessionOptions& session_opts,
70                     const StatsPublisherFactory& stats_publisher_factory,
71                     bool is_partial, WorkerCacheInterface* worker_cache,
72                     bool should_deregister)
73       : session_handle_(handle),
74         bg_opts_(bopts),
75         client_graph_before_register_(std::move(client_graph)),
76         session_opts_(session_opts),
77         is_partial_(is_partial),
78         callable_opts_(bopts.callable_options),
79         worker_cache_(worker_cache),
80         should_deregister_(should_deregister),
81         collective_graph_key_(
82             client_graph_before_register_->collective_graph_key) {
83     VLOG(1) << "Created ReffedClientGraph for node with "
84             << client_graph_before_register_->graph.num_node_ids();
85 
86     stats_publisher_ = stats_publisher_factory(handle, bopts, session_opts);
87 
88     // Initialize a name to node map for processing device stats.
89     for (Node* n : client_graph_before_register_->graph.nodes()) {
90       name_to_node_details_.emplace(
91           n->name(),
92           NodeDetails(n->type_string(),
93                       strings::StrCat(
94                           "(", str_util::Join(n->requested_inputs(), ", "))));
95     }
96   }
97 
98   ~ReffedClientGraph() override {
99     if (should_deregister_) {
100       DeregisterPartitions();
101     } else {
102       for (Part& part : partitions_) {
103         worker_cache_->ReleaseWorker(part.name, part.worker);
104       }
105     }
106   }
107 
108   const CallableOptions& callable_options() { return callable_opts_; }
109 
110   const BuildGraphOptions& build_graph_options() { return bg_opts_; }
111 
112   int64 collective_graph_key() { return collective_graph_key_; }
113 
114   std::unique_ptr<ProfileHandler> GetProfileHandler(uint64 step,
115                                                     int64 execution_count,
116                                                     const RunOptions& ropts) {
117     return stats_publisher_->GetProfileHandler(step, execution_count, ropts);
118   }
119 
120   int64 get_and_increment_execution_count() {
121     return execution_count_.fetch_add(1);
122   }
123 
124   // Turn RPC logging on or off, both at the WorkerCache used by this
125   // master process, and at each remote worker in use for the current
126   // partitions.
127   void SetRPCLogging(bool active) {
128     worker_cache_->SetLogging(active);
129     // Logging is a best-effort activity, so we make async calls to turn
130     // it on/off and don't make use of the responses.
131     for (auto& p : partitions_) {
132       LoggingRequest* req = new LoggingRequest;
133       if (active) {
134         req->set_enable_rpc_logging(true);
135       } else {
136         req->set_disable_rpc_logging(true);
137       }
138       LoggingResponse* resp = new LoggingResponse;
139       Ref();
140       p.worker->LoggingAsync(req, resp, [this, req, resp](const Status& s) {
141         delete req;
142         delete resp;
143         // ReffedClientGraph owns p.worker so we need to hold a ref to
144         // ensure that the method doesn't attempt to access p.worker after
145         // ReffedClientGraph has deleted it.
146         // TODO(suharshs): Simplify this ownership model.
147         Unref();
148       });
149     }
150   }
151 
152   // Retrieve all RPC logs data accumulated for the current step, both
153   // from the local WorkerCache in use by this master process and from
154   // all the remote workers executing the remote partitions.
155   void RetrieveLogs(int64 step_id, StepStats* ss) {
156     // Get the local data first, because it sets *ss without merging.
157     worker_cache_->RetrieveLogs(step_id, ss);
158 
159     // Then merge in data from all the remote workers.
160     LoggingRequest req;
161     req.add_fetch_step_id(step_id);
162     int waiting_for = partitions_.size();
163     if (waiting_for > 0) {
164       mutex scoped_mu;
165       BlockingCounter all_done(waiting_for);
166       for (auto& p : partitions_) {
167         LoggingResponse* resp = new LoggingResponse;
168         p.worker->LoggingAsync(
169             &req, resp,
170             [step_id, ss, resp, &scoped_mu, &all_done](const Status& s) {
171               {
172                 mutex_lock l(scoped_mu);
173                 if (s.ok()) {
174                   for (auto& lss : resp->step()) {
175                     if (step_id != lss.step_id()) {
176                       LOG(ERROR) << "Wrong step_id in LoggingResponse";
177                       continue;
178                     }
179                     ss->MergeFrom(lss.step_stats());
180                   }
181                 }
182                 delete resp;
183               }
184               // Must not decrement all_done until out of critical section where
185               // *ss is updated.
186               all_done.DecrementCount();
187             });
188       }
189       all_done.Wait();
190     }
191   }
192 
193   // Local execution methods.
194 
195   // Partitions the graph into subgraphs and registers them on
196   // workers.
197   Status RegisterPartitions(PartitionOptions popts);
198 
199   // Runs one step of all partitions.
200   Status RunPartitions(const MasterEnv* env, int64 step_id,
201                        int64 execution_count, PerStepState* pss,
202                        CallOptions* opts, const RunStepRequestWrapper& req,
203                        MutableRunStepResponseWrapper* resp,
204                        CancellationManager* cm, const bool is_last_partial_run);
205   Status RunPartitions(const MasterEnv* env, int64 step_id,
206                        int64 execution_count, PerStepState* pss,
207                        CallOptions* call_opts, const RunCallableRequest& req,
208                        RunCallableResponse* resp, CancellationManager* cm);
209 
210   // Calls workers to cleanup states for the step "step_id".  Calls
211   // `done` when all cleanup RPCs have completed.
212   void CleanupPartitionsAsync(int64 step_id, StatusCallback done);
213 
214   // Post-processing of any runtime statistics gathered during execution.
215   void ProcessStats(int64 step_id, PerStepState* pss, ProfileHandler* ph,
216                     const RunOptions& options, RunMetadata* resp);
217   void ProcessDeviceStats(ProfileHandler* ph, const DeviceStepStats& ds,
218                           bool is_rpc);
219   // Checks that the requested fetches can be computed from the provided feeds.
220   Status CheckFetches(const RunStepRequestWrapper& req,
221                       const RunState* run_state,
222                       GraphExecutionState* execution_state);
223 
224  private:
225   const string session_handle_;
226   const BuildGraphOptions bg_opts_;
227 
228   // NOTE(mrry): This pointer will be null after `RegisterPartitions()` returns.
229   std::unique_ptr<ClientGraph> client_graph_before_register_ GUARDED_BY(mu_);
230   const SessionOptions session_opts_;
231   const bool is_partial_;
232   const CallableOptions callable_opts_;
233   WorkerCacheInterface* const worker_cache_;  // Not owned.
234 
235   struct NodeDetails {
236     explicit NodeDetails(string type_string, string detail_text)
237         : type_string(std::move(type_string)),
238           detail_text(std::move(detail_text)) {}
239     const string type_string;
240     const string detail_text;
241   };
242   std::unordered_map<string, NodeDetails> name_to_node_details_;
243 
244   const bool should_deregister_;
245   const int64 collective_graph_key_;
246   std::atomic<int64> execution_count_ = {0};
247 
248   // Graph partitioned into per-location subgraphs.
249   struct Part {
250     // Worker name.
251     string name;
252 
253     // Maps feed names to rendezvous keys. Empty most of the time.
254     std::unordered_map<string, string> feed_key;
255 
256     // Maps rendezvous keys to fetch names. Empty most of the time.
257     std::unordered_map<string, string> key_fetch;
258 
259     // The interface to the worker. Owned.
260     WorkerInterface* worker = nullptr;
261 
262     // After registration with the worker, graph_handle identifies
263     // this partition on the worker.
264     string graph_handle;
265 
266     Part() : feed_key(3), key_fetch(3) {}
267   };
268 
269   // partitions_ is immutable after RegisterPartitions() call
270   // finishes.  RunPartitions() can access partitions_ safely without
271   // acquiring locks.
272   std::vector<Part> partitions_;
273 
274   mutable mutex mu_;
275 
276   // Partition initialization and registration only needs to happen
277   // once. `!client_graph_before_register_ && !init_done_.HasBeenNotified()`
278   // indicates the initialization is ongoing.
279   Notification init_done_;
280 
281   // init_result_ remembers the initialization error if any.
282   Status init_result_ GUARDED_BY(mu_);
283 
284   std::unique_ptr<StatsPublisherInterface> stats_publisher_;
285 
286   string DetailText(const NodeDetails& details, const NodeExecStats& stats) {
287     int64 tot = 0;
288     for (auto& no : stats.output()) {
289       tot += no.tensor_description().allocation_description().requested_bytes();
290     }
291     string bytes;
292     if (tot >= 0.1 * 1048576.0) {
293       bytes = strings::Printf("[%.1fMB] ", tot / 1048576.0);
294     }
295     return strings::StrCat(bytes, stats.node_name(), " = ", details.type_string,
296                            details.detail_text);
297   }
298 
299   // Send/Recv nodes that are the result of client-added
300   // feeds and fetches must be tracked so that the tensors
301   // can be added to the local rendezvous.
302   static void TrackFeedsAndFetches(Part* part, const GraphDef& graph_def,
303                                    const PartitionOptions& popts);
304 
305   // The actual graph partitioning and registration implementation.
306   Status DoBuildPartitions(
307       PartitionOptions popts, ClientGraph* client_graph,
308       std::unordered_map<string, GraphDef>* out_partitions);
309   Status DoRegisterPartitions(
310       const PartitionOptions& popts,
311       std::unordered_map<string, GraphDef> graph_partitions);
312 
313   // Prepares a number of calls to workers. One call per partition.
314   // This is a generic method that handles Run, PartialRun, and RunCallable.
315   template <class FetchListType, class ClientRequestType,
316             class ClientResponseType>
317   Status RunPartitionsHelper(
318       const std::unordered_map<StringPiece, size_t, StringPieceHasher>& feeds,
319       const FetchListType& fetches, const MasterEnv* env, int64 step_id,
320       int64 execution_count, PerStepState* pss, CallOptions* call_opts,
321       const ClientRequestType& req, ClientResponseType* resp,
322       CancellationManager* cm, bool is_last_partial_run);
323 
324   // Deregisters the partitions on the workers.  Called in the
325   // destructor and does not wait for the rpc completion.
326   void DeregisterPartitions();
327 
328   TF_DISALLOW_COPY_AND_ASSIGN(ReffedClientGraph);
329 };
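// Annotation (not part of the original TensorFlow source): a minimal sketch of
// how the surrounding MasterSession code is expected to drive a
// ReffedClientGraph; the variable names below are hypothetical.
//
//   ReffedClientGraph* rcg = new ReffedClientGraph(
//       handle, bopts, std::move(client_graph), session_opts,
//       stats_publisher_factory, /*is_partial=*/false, worker_cache,
//       /*should_deregister=*/true);
//   Status s = rcg->RegisterPartitions(std::move(popts));
//   // ... any number of RunPartitions() calls reuse the same registration ...
//   rcg->Unref();  // The final Unref() triggers DeregisterPartitions().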
330 
331 Status MasterSession::ReffedClientGraph::RegisterPartitions(
332     PartitionOptions popts) {
333   {  // Ensure register once.
334     mu_.lock();
335     if (client_graph_before_register_) {
336       // The `ClientGraph` is no longer needed after partitions are registered.
337       // Since it can account for a large amount of memory, we consume it here,
338       // and it will be freed after concluding with registration.
339 
340       std::unique_ptr<ClientGraph> client_graph;
341       std::swap(client_graph_before_register_, client_graph);
342       mu_.unlock();
343       std::unordered_map<string, GraphDef> graph_defs;
344       popts.flib_def = client_graph->flib_def.get();
345       Status s = DoBuildPartitions(popts, client_graph.get(), &graph_defs);
346       if (s.ok()) {
347         // NOTE(mrry): The pointers in `graph_defs_for_publishing` do not remain
348         // valid after the call to DoRegisterPartitions begins, so
349         // `stats_publisher_` must make a copy if it wants to retain the
350         // GraphDef objects.
351         std::vector<const GraphDef*> graph_defs_for_publishing;
352         graph_defs_for_publishing.reserve(partitions_.size());
353         for (const auto& name_def : graph_defs) {
354           graph_defs_for_publishing.push_back(&name_def.second);
355         }
356         stats_publisher_->PublishGraphProto(graph_defs_for_publishing);
357         s = DoRegisterPartitions(popts, std::move(graph_defs));
358       }
359       mu_.lock();
360       init_result_ = s;
361       init_done_.Notify();
362     } else {
363       mu_.unlock();
364       init_done_.WaitForNotification();
365       mu_.lock();
366     }
367     const Status result = init_result_;
368     mu_.unlock();
369     return result;
370   }
371 }
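// Annotation (not in the original source): RegisterPartitions() implements a
// "first caller initializes, later callers wait" idiom. The first caller finds
// client_graph_before_register_ non-null, takes ownership of it, partitions
// and registers outside the lock, then notifies init_done_; every later caller
// waits on the notification and returns the shared init_result_. A minimal
// sketch of the same idiom, with hypothetical names:
//
//   mu.lock();
//   if (pending_resource) {            // first caller does the expensive work
//     auto owned = std::move(pending_resource);
//     mu.unlock();
//     Status s = Initialize(owned.get());
//     mu.lock();
//     result = s;
//     done.Notify();
//   } else {                           // later callers just wait
//     mu.unlock();
//     done.WaitForNotification();
//     mu.lock();
//   }
//   const Status out = result;
//   mu.unlock();
//   return out;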
372 
373 static string SplitByWorker(const Node* node) {
374   string task;
375   string device;
376   CHECK(DeviceNameUtils::SplitDeviceName(node->assigned_device_name(), &task,
377                                          &device))
378       << "node: " << node->name() << " dev: " << node->assigned_device_name();
379   return task;
380 }
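// Annotation (not in the original source): SplitByWorker keys partitions by
// the task portion of a node's assigned device. For example, a node placed on
// "/job:worker/replica:0/task:1/device:GPU:0" yields the task string
// "/job:worker/replica:0/task:1", which is the partition name later used to
// look up the worker in the WorkerCacheInterface.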
381 
382 void MasterSession::ReffedClientGraph::TrackFeedsAndFetches(
383     Part* part, const GraphDef& graph_def, const PartitionOptions& popts) {
384   for (int i = 0; i < graph_def.node_size(); ++i) {
385     const NodeDef& ndef = graph_def.node(i);
386     const bool is_recv = ndef.op() == "_Recv";
387     const bool is_send = ndef.op() == "_Send";
388 
389     if (is_recv || is_send) {
390       // Only send/recv nodes that were added as feeds and fetches
391       // (client-terminated) should be tracked.  Other send/recv nodes
392       // are for transferring data between partitions / memory spaces.
393       bool client_terminated;
394       TF_CHECK_OK(GetNodeAttr(ndef, "client_terminated", &client_terminated));
395       if (client_terminated) {
396         string name;
397         TF_CHECK_OK(GetNodeAttr(ndef, "tensor_name", &name));
398         string send_device;
399         TF_CHECK_OK(GetNodeAttr(ndef, "send_device", &send_device));
400         string recv_device;
401         TF_CHECK_OK(GetNodeAttr(ndef, "recv_device", &recv_device));
402         uint64 send_device_incarnation;
403         TF_CHECK_OK(
404             GetNodeAttr(ndef, "send_device_incarnation",
405                         reinterpret_cast<int64*>(&send_device_incarnation)));
406         const string& key =
407             Rendezvous::CreateKey(send_device, send_device_incarnation,
408                                   recv_device, name, FrameAndIter(0, 0));
409 
410         if (is_recv) {
411           part->feed_key.insert({name, key});
412         } else {
413           part->key_fetch.insert({key, name});
414         }
415       }
416     }
417   }
418 }
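// Annotation (not in the original source): for a client-terminated _Recv that
// feeds tensor "x:0", the rendezvous key built above has, approximately, the
// ';'-separated form
//
//   <send_device>;<send_device_incarnation>;<recv_device>;x:0;0:0
//
// and is stored in part->feed_key["x:0"]. Fetches store the reverse mapping in
// part->key_fetch, which is how RunPartitionsHelper later routes client feeds
// into, and fetch values out of, each worker's per-step rendezvous.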
419 
420 Status MasterSession::ReffedClientGraph::DoBuildPartitions(
421     PartitionOptions popts, ClientGraph* client_graph,
422     std::unordered_map<string, GraphDef>* out_partitions) {
423   if (popts.need_to_record_start_times) {
424     CostModel cost_model(true);
425     cost_model.InitFromGraph(client_graph->graph);
426     // TODO(yuanbyu): Use the real cost model.
427     // execution_state_->MergeFromGlobal(&cost_model);
428     SlackAnalysis sa(&client_graph->graph, &cost_model);
429     sa.ComputeAsap(&popts.start_times);
430   }
431 
432   // Partition the graph.
433   return Partition(popts, &client_graph->graph, out_partitions);
434 }
435 
436 Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
437     const PartitionOptions& popts,
438     std::unordered_map<string, GraphDef> graph_partitions) {
439   partitions_.reserve(graph_partitions.size());
440   Status s;
441   for (auto& name_def : graph_partitions) {
442     partitions_.emplace_back();
443     Part* part = &partitions_.back();
444     part->name = name_def.first;
445     TrackFeedsAndFetches(part, name_def.second, popts);
446     part->worker = worker_cache_->CreateWorker(part->name);
447     if (part->worker == nullptr) {
448       s = errors::NotFound("worker ", part->name);
449       break;
450     }
451   }
452   if (!s.ok()) {
453     for (Part& part : partitions_) {
454       worker_cache_->ReleaseWorker(part.name, part.worker);
455       part.worker = nullptr;
456     }
457     return s;
458   }
459   struct Call {
460     RegisterGraphRequest req;
461     RegisterGraphResponse resp;
462     Status status;
463   };
464   const int num = partitions_.size();
465   gtl::InlinedVector<Call, 4> calls(num);
466   BlockingCounter done(num);
467   for (int i = 0; i < num; ++i) {
468     const Part& part = partitions_[i];
469     Call* c = &calls[i];
470     c->req.set_session_handle(session_handle_);
471     c->req.set_create_worker_session_called(!should_deregister_);
472     c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]);
473     *c->req.mutable_graph_options() = session_opts_.config.graph_options();
474     *c->req.mutable_debug_options() =
475         callable_opts_.run_options().debug_options();
476     c->req.set_collective_graph_key(collective_graph_key_);
477     VLOG(2) << "Register " << c->req.graph_def().DebugString();
478     auto cb = [c, &done](const Status& s) {
479       c->status = s;
480       done.DecrementCount();
481     };
482     part.worker->RegisterGraphAsync(&c->req, &c->resp, cb);
483   }
484   done.Wait();
485   for (int i = 0; i < num; ++i) {
486     Call* c = &calls[i];
487     s.Update(c->status);
488     partitions_[i].graph_handle = c->resp.graph_handle();
489   }
490   return s;
491 }
492 
493 // Helper class to manage "num" parallel RunGraph calls.
494 class RunManyGraphs {
495  public:
496   explicit RunManyGraphs(int num) : calls_(num), pending_(num) {}
497 
498   ~RunManyGraphs() {}
499 
500   // Returns the index-th call.
501   struct Call {
502     CallOptions opts;
503     std::unique_ptr<MutableRunGraphRequestWrapper> req;
504     std::unique_ptr<MutableRunGraphResponseWrapper> resp;
505   };
506   Call* get(int index) { return &calls_[index]; }
507 
508   // When the index-th call is done, updates the overall status.
509   void WhenDone(int index, const Status& s) {
510     TRACEPRINTF("Partition %d %s", index, s.ToString().c_str());
511     auto resp = get(index)->resp.get();
512     if (resp->status_code() != error::Code::OK) {
513       // resp->status_code will only be non-OK if s.ok().
514       mutex_lock l(mu_);
515       ReportBadStatus(
516           Status(resp->status_code(), resp->status_error_message()));
517     } else if (!s.ok()) {
518       mutex_lock l(mu_);
519       ReportBadStatus(s);
520     }
521     pending_.DecrementCount();
522   }
523 
524   void StartCancel() {
525     mutex_lock l(mu_);
526     ReportBadStatus(errors::Cancelled("RunManyGraphs"));
527   }
528 
529   void Wait() { pending_.Wait(); }
530 
531   Status status() const {
532     mutex_lock l(mu_);
533     return status_group_.as_status();
534   }
535 
536  private:
537   gtl::InlinedVector<Call, 4> calls_;
538 
539   BlockingCounter pending_;
540   mutable mutex mu_;
541   StatusGroup status_group_ GUARDED_BY(mu_);
542 
543   void ReportBadStatus(const Status& s) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
544     // Start cancellation if we aren't already in an error state.
545     if (status_group_.ok()) {
546       for (Call& call : calls_) {
547         call.opts.StartCancel();
548       }
549     }
550 
551     status_group_.Update(s);
552   }
553 
554   TF_DISALLOW_COPY_AND_ASSIGN(RunManyGraphs);
555 };
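// Annotation (not in the original source): sketch of the intended use of
// RunManyGraphs, mirroring the real call sites in RunPartitionsHelper below.
//
//   RunManyGraphs calls(num_partitions);
//   for (int i = 0; i < num_partitions; ++i) {
//     // ... fill in calls.get(i)->req ...
//     workers[i]->RunGraphAsync(
//         &calls.get(i)->opts, calls.get(i)->req.get(), calls.get(i)->resp.get(),
//         [&calls, i](const Status& s) { calls.WhenDone(i, s); });
//   }
//   calls.Wait();               // returns once every WhenDone() has fired
//   Status s = calls.status();  // aggregated; an error cancels the other calls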
556 
557 namespace {
558 Status AddSendFromClientRequest(const RunStepRequestWrapper& client_req,
559                                 MutableRunGraphRequestWrapper* worker_req,
560                                 size_t index, const string& send_key) {
561   return worker_req->AddSendFromRunStepRequest(client_req, index, send_key);
562 }
563 
564 Status AddSendFromClientRequest(const RunCallableRequest& client_req,
565                                 MutableRunGraphRequestWrapper* worker_req,
566                                 size_t index, const string& send_key) {
567   return worker_req->AddSendFromRunCallableRequest(client_req, index, send_key);
568 }
569 
570 // TODO(mrry): Add a full-fledged wrapper that avoids TensorProto copies for
571 // in-process messages.
572 struct RunCallableResponseWrapper {
573   RunCallableResponse* resp;  // Not owned.
574   std::unordered_map<string, TensorProto> fetch_key_to_protos;
575 
576   RunMetadata* mutable_metadata() { return resp->mutable_metadata(); }
577 
578   Status AddTensorFromRunGraphResponse(
579       const string& tensor_name, MutableRunGraphResponseWrapper* worker_resp,
580       size_t index) {
581     // TODO(b/74355905): Add a specialized implementation that avoids
582     // copying the tensor into the RunCallableResponse when at least
583     // two of the {client, master, worker} are in the same process.
584     return worker_resp->RecvValue(index, &fetch_key_to_protos[tensor_name]);
585   }
586 };
587 }  // namespace
588 
589 template <class FetchListType, class ClientRequestType,
590           class ClientResponseType>
591 Status MasterSession::ReffedClientGraph::RunPartitionsHelper(
592     const std::unordered_map<StringPiece, size_t, StringPieceHasher>& feeds,
593     const FetchListType& fetches, const MasterEnv* env, int64 step_id,
594     int64 execution_count, PerStepState* pss, CallOptions* call_opts,
595     const ClientRequestType& req, ClientResponseType* resp,
596     CancellationManager* cm, bool is_last_partial_run) {
597   // Collect execution cost stats on a smoothly decreasing frequency.
598   ExecutorOpts exec_opts;
599   if (pss->report_tensor_allocations_upon_oom) {
600     exec_opts.set_report_tensor_allocations_upon_oom(true);
601   }
602   if (pss->collect_costs) {
603     exec_opts.set_record_costs(true);
604   }
605   if (pss->collect_timeline) {
606     exec_opts.set_record_timeline(true);
607   }
608   if (pss->collect_rpcs) {
609     SetRPCLogging(true);
610   }
611   if (pss->collect_partition_graphs) {
612     exec_opts.set_record_partition_graphs(true);
613   }
614   if (pss->collect_costs || pss->collect_timeline) {
615     pss->step_stats.resize(partitions_.size());
616   }
617 
618   const int num = partitions_.size();
619   RunManyGraphs calls(num);
620 
621   for (int i = 0; i < num; ++i) {
622     const Part& part = partitions_[i];
623     RunManyGraphs::Call* c = calls.get(i);
624     c->req.reset(part.worker->CreateRunGraphRequest());
625     c->resp.reset(part.worker->CreateRunGraphResponse());
626     if (is_partial_) {
627       c->req->set_is_partial(is_partial_);
628       c->req->set_is_last_partial_run(is_last_partial_run);
629     }
630     c->req->set_session_handle(session_handle_);
631     c->req->set_create_worker_session_called(!should_deregister_);
632     c->req->set_graph_handle(part.graph_handle);
633     c->req->set_step_id(step_id);
634     *c->req->mutable_exec_opts() = exec_opts;
635     c->req->set_store_errors_in_response_body(true);
636     // If any feeds are provided, send the feed values together
637     // in the RunGraph request.
638     // In the partial case, we only want to include feeds provided in the req.
639     // In the non-partial case, all feeds in the request are in the part.
640     // We keep these as separate paths for now, to ensure we aren't
641     // inadvertently slowing down the normal run path.
642     if (is_partial_) {
643       for (const auto& name_index : feeds) {
644         const auto iter = part.feed_key.find(string(name_index.first));
645         if (iter == part.feed_key.end()) {
646           // The provided feed must be for a different partition.
647           continue;
648         }
649         const string& key = iter->second;
650         TF_RETURN_IF_ERROR(AddSendFromClientRequest(req, c->req.get(),
651                                                     name_index.second, key));
652       }
653       // TODO(suharshs): Make a map from feed to fetch_key to make this faster.
654       // For now, we just iterate through partitions to find the matching key.
655       for (const string& req_fetch : fetches) {
656         for (const auto& key_fetch : part.key_fetch) {
657           if (key_fetch.second == req_fetch) {
658             c->req->add_recv_key(key_fetch.first);
659             break;
660           }
661         }
662       }
663     } else {
664       for (const auto& feed_key : part.feed_key) {
665         const string& feed = feed_key.first;
666         const string& key = feed_key.second;
667         auto iter = feeds.find(feed);
668         if (iter == feeds.end()) {
669           return errors::Internal("No feed index found for feed: ", feed);
670         }
671         const int64 feed_index = iter->second;
672         TF_RETURN_IF_ERROR(
673             AddSendFromClientRequest(req, c->req.get(), feed_index, key));
674       }
675       for (const auto& key_fetch : part.key_fetch) {
676         const string& key = key_fetch.first;
677         c->req->add_recv_key(key);
678       }
679     }
680   }
681 
682   // Issues RunGraph calls.
683   for (int i = 0; i < num; ++i) {
684     const Part& part = partitions_[i];
685     RunManyGraphs::Call* call = calls.get(i);
686     TRACEPRINTF("Partition %d %s", i, part.name.c_str());
687     part.worker->RunGraphAsync(
688         &call->opts, call->req.get(), call->resp.get(),
689         std::bind(&RunManyGraphs::WhenDone, &calls, i, std::placeholders::_1));
690   }
691 
692   // Waits for the RunGraph calls.
693   call_opts->SetCancelCallback([&calls]() { calls.StartCancel(); });
694   auto token = cm->get_cancellation_token();
695   const bool success =
696       cm->RegisterCallback(token, [&calls]() { calls.StartCancel(); });
697   if (!success) {
698     calls.StartCancel();
699   }
700   calls.Wait();
701   call_opts->ClearCancelCallback();
702   if (success) {
703     cm->DeregisterCallback(token);
704   } else {
705     return errors::Cancelled("Step was cancelled");
706   }
707   TF_RETURN_IF_ERROR(calls.status());
708 
709   // Collects fetches and metadata.
710   Status status;
711   for (int i = 0; i < num; ++i) {
712     const Part& part = partitions_[i];
713     MutableRunGraphResponseWrapper* run_graph_resp = calls.get(i)->resp.get();
714     for (size_t j = 0; j < run_graph_resp->num_recvs(); ++j) {
715       auto iter = part.key_fetch.find(run_graph_resp->recv_key(j));
716       if (iter == part.key_fetch.end()) {
717         status.Update(errors::Internal("Unexpected fetch key: ",
718                                        run_graph_resp->recv_key(j)));
719         break;
720       }
721       const string& fetch = iter->second;
722       status.Update(
723           resp->AddTensorFromRunGraphResponse(fetch, run_graph_resp, j));
724       if (!status.ok()) {
725         break;
726       }
727     }
728     if (pss->collect_timeline) {
729       pss->step_stats[i].Swap(run_graph_resp->mutable_step_stats());
730     }
731     if (pss->collect_costs) {
732       CostGraphDef* cost_graph = run_graph_resp->mutable_cost_graph();
733       for (int j = 0; j < cost_graph->node_size(); ++j) {
734         resp->mutable_metadata()->mutable_cost_graph()->add_node()->Swap(
735             cost_graph->mutable_node(j));
736       }
737     }
738     if (pss->collect_partition_graphs) {
739       protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
740           resp->mutable_metadata()->mutable_partition_graphs();
741       for (size_t i = 0; i < run_graph_resp->num_partition_graphs(); i++) {
742         partition_graph_defs->Add()->Swap(
743             run_graph_resp->mutable_partition_graph(i));
744       }
745     }
746   }
747   return status;
748 }
749 
750 Status MasterSession::ReffedClientGraph::RunPartitions(
751     const MasterEnv* env, int64 step_id, int64 execution_count,
752     PerStepState* pss, CallOptions* call_opts, const RunStepRequestWrapper& req,
753     MutableRunStepResponseWrapper* resp, CancellationManager* cm,
754     const bool is_last_partial_run) {
755   VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
756           << execution_count;
757   // Maps the names of fed tensors to their index in `req`.
758   std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3);
759   for (size_t i = 0; i < req.num_feeds(); ++i) {
760     if (!feeds.insert({req.feed_name(i), i}).second) {
761       return errors::InvalidArgument("Duplicated feeds: ", req.feed_name(i));
762     }
763   }
764 
765   std::vector<string> fetches;
766   fetches.reserve(req.num_fetches());
767   for (size_t i = 0; i < req.num_fetches(); ++i) {
768     fetches.push_back(req.fetch_name(i));
769   }
770 
771   return RunPartitionsHelper(feeds, fetches, env, step_id, execution_count, pss,
772                              call_opts, req, resp, cm, is_last_partial_run);
773 }
774 
775 Status MasterSession::ReffedClientGraph::RunPartitions(
776     const MasterEnv* env, int64 step_id, int64 execution_count,
777     PerStepState* pss, CallOptions* call_opts, const RunCallableRequest& req,
778     RunCallableResponse* resp, CancellationManager* cm) {
779   VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
780           << execution_count;
781   // Maps the names of fed tensors to their index in `req`.
782   std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3);
783   for (size_t i = 0; i < callable_opts_.feed_size(); ++i) {
784     if (!feeds.insert({callable_opts_.feed(i), i}).second) {
785       // MakeCallable will fail if there are two feeds with the same name.
786       return errors::Internal("Duplicated feeds in callable: ",
787                               callable_opts_.feed(i));
788     }
789   }
790 
791   // Create a wrapped response object to collect the fetched values and
792   // rearrange them for the RunCallableResponse.
793   RunCallableResponseWrapper wrapped_resp;
794   wrapped_resp.resp = resp;
795 
796   TF_RETURN_IF_ERROR(RunPartitionsHelper(
797       feeds, callable_opts_.fetch(), env, step_id, execution_count, pss,
798       call_opts, req, &wrapped_resp, cm, false /* is_last_partial_run */));
799 
800   // Collects fetches.
801   // TODO(b/74355905): Add a specialized implementation that avoids
802   // copying the tensor into the RunCallableResponse when at least
803   // two of the {client, master, worker} are in the same process.
804   for (const string& fetch : callable_opts_.fetch()) {
805     TensorProto* fetch_proto = resp->mutable_fetch()->Add();
806     auto iter = wrapped_resp.fetch_key_to_protos.find(fetch);
807     if (iter == wrapped_resp.fetch_key_to_protos.end()) {
808       return errors::Internal("Worker did not return a value for fetch: ",
809                               fetch);
810     }
811     fetch_proto->Swap(&iter->second);
812   }
813   return Status::OK();
814 }
815 
816 namespace {
817 
818 class CleanupBroadcastHelper {
819  public:
820   CleanupBroadcastHelper(int64 step_id, int num_calls, StatusCallback done)
821       : resps_(num_calls), num_pending_(num_calls), done_(std::move(done)) {
822     req_.set_step_id(step_id);
823   }
824 
825   // Returns a non-owned pointer to a request buffer for all calls.
826   CleanupGraphRequest* request() { return &req_; }
827 
828   // Returns a non-owned pointer to a response buffer for the ith call.
829   CleanupGraphResponse* response(int i) { return &resps_[i]; }
830 
831   // Called when the ith response is received.
832   void call_done(int i, const Status& s) {
833     bool run_callback = false;
834     Status status_copy;
835     {
836       mutex_lock l(mu_);
837       status_.Update(s);
838       if (--num_pending_ == 0) {
839         run_callback = true;
840         status_copy = status_;
841       }
842     }
843     if (run_callback) {
844       done_(status_copy);
845       // This is the last call, so delete the helper object.
846       delete this;
847     }
848   }
849 
850  private:
851   // A single request shared between all workers.
852   CleanupGraphRequest req_;
853   // One response buffer for each worker.
854   gtl::InlinedVector<CleanupGraphResponse, 4> resps_;
855 
856   mutex mu_;
857   // Number of requests remaining to be collected.
858   int num_pending_ GUARDED_BY(mu_);
859   // Aggregate status of the operation.
860   Status status_ GUARDED_BY(mu_);
861   // Callback to be called when all operations complete.
862   StatusCallback done_;
863 
864   TF_DISALLOW_COPY_AND_ASSIGN(CleanupBroadcastHelper);
865 };
866 
867 }  // namespace
868 
869 void MasterSession::ReffedClientGraph::CleanupPartitionsAsync(
870     int64 step_id, StatusCallback done) {
871   const int num = partitions_.size();
872   // Helper object will be deleted when the final call completes.
873   CleanupBroadcastHelper* helper =
874       new CleanupBroadcastHelper(step_id, num, std::move(done));
875   for (int i = 0; i < num; ++i) {
876     const Part& part = partitions_[i];
877     part.worker->CleanupGraphAsync(
878         helper->request(), helper->response(i),
879         [helper, i](const Status& s) { helper->call_done(i, s); });
880   }
881 }
882 
883 void MasterSession::ReffedClientGraph::ProcessStats(int64 step_id,
884                                                     PerStepState* pss,
885                                                     ProfileHandler* ph,
886                                                     const RunOptions& options,
887                                                     RunMetadata* resp) {
888   if (!pss->collect_costs && !pss->collect_timeline) return;
889 
890   // Out-of-band logging data is collected now, during post-processing.
891   if (pss->collect_timeline) {
892     SetRPCLogging(false);
893     RetrieveLogs(step_id, &pss->rpc_stats);
894   }
895   for (size_t i = 0; i < partitions_.size(); ++i) {
896     const StepStats& ss = pss->step_stats[i];
897     if (ph) {
898       for (const auto& ds : ss.dev_stats()) {
899         ProcessDeviceStats(ph, ds, false /*is_rpc*/);
900       }
901     }
902   }
903   if (ph) {
904     for (const auto& ds : pss->rpc_stats.dev_stats()) {
905       ProcessDeviceStats(ph, ds, true /*is_rpc*/);
906     }
907     ph->StepDone(pss->start_micros, pss->end_micros,
908                  Microseconds(0) /*cleanup_time*/, 0 /*total_runops*/,
909                  Status::OK());
910   }
911   // Assemble all stats for this timeline into a merged StepStats.
912   if (pss->collect_timeline) {
913     StepStats step_stats_proto;
914     step_stats_proto.Swap(&pss->rpc_stats);
915     for (size_t i = 0; i < partitions_.size(); ++i) {
916       step_stats_proto.MergeFrom(pss->step_stats[i]);
917       pss->step_stats[i].Clear();
918     }
919     pss->step_stats.clear();
920     // Copy the stats back, but only for on-demand profiling to avoid slowing
921     // down calls that trigger the automatic profiling.
922     if (options.trace_level() == RunOptions::FULL_TRACE) {
923       resp->mutable_step_stats()->Swap(&step_stats_proto);
924     } else {
925       // In the FULL_TRACE case the stats are returned in the response above,
926       // so we only publish them here to avoid duplication.
927       stats_publisher_->PublishStatsProto(step_stats_proto);
928     }
929   }
930 }
931 
932 void MasterSession::ReffedClientGraph::ProcessDeviceStats(
933     ProfileHandler* ph, const DeviceStepStats& ds, bool is_rpc) {
934   const string& dev_name = ds.device();
935   VLOG(1) << "Device " << dev_name << " reports stats for "
936           << ds.node_stats_size() << " nodes";
937   for (const auto& ns : ds.node_stats()) {
938     if (is_rpc) {
939       // We don't have access to a good Node pointer, so we rely on
940       // sufficient data being present in the NodeExecStats.
941       ph->RecordOneOp(dev_name, ns, true /*is_copy*/, "", ns.node_name(),
942                       ns.timeline_label());
943     } else {
944       auto iter = name_to_node_details_.find(ns.node_name());
945       const bool found_node_in_graph = iter != name_to_node_details_.end();
946       if (!found_node_in_graph && ns.timeline_label().empty()) {
947         // The counter incrementing is not thread-safe. But we don't really
948         // care.
949         // TODO(zhengxq): we should implement a LOG_FIRST_N and LOG_EVERY_N for
950         // more general usage.
951         static int log_counter = 0;
952         if (log_counter < 10) {
953           log_counter++;
954           LOG(WARNING) << "Failed to find node " << ns.node_name()
955                        << " for dev " << dev_name;
956         }
957         continue;
958       }
959       const string& optype =
960           found_node_in_graph ? iter->second.type_string : ns.node_name();
961       string details;
962       if (!ns.timeline_label().empty()) {
963         details = ns.timeline_label();
964       } else if (found_node_in_graph) {
965         details = DetailText(iter->second, ns);
966       } else {
967         // Leave details string empty
968       }
969       ph->RecordOneOp(dev_name, ns, false /*is_copy*/, ns.node_name(), optype,
970                       details);
971     }
972   }
973 }
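// Annotation (not in the original source): for a non-RPC node found in the
// graph, the detail string recorded above (via DetailText) looks roughly like
//
//   [12.5MB] my_matmul = MatMul(input_a, input_b
//
// i.e. an optional size prefix (emitted only when the node's outputs exceed
// roughly 0.1MB), the node name, its op type, and the requested inputs as
// captured in NodeDetails (note that the original code emits only the opening
// parenthesis).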
974 
975 // TODO(suharshs): Merge with CheckFetches in DirectSession.
976 // TODO(suharsh,mrry): Build a map from fetch target to set of feeds it depends
977 // on once at setup time to prevent us from computing the dependencies
978 // every time.
979 Status MasterSession::ReffedClientGraph::CheckFetches(
980     const RunStepRequestWrapper& req, const RunState* run_state,
981     GraphExecutionState* execution_state) {
982   // Build the set of pending feeds that we haven't seen.
983   std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
984   for (const auto& input : run_state->pending_inputs) {
985     // Skip if already fed.
986     if (input.second) continue;
987     TensorId id(ParseTensorName(input.first));
988     const Node* n = execution_state->get_node_by_name(string(id.first));
989     if (n == nullptr) {
990       return errors::NotFound("Feed ", input.first, ": not found");
991     }
992     pending_feeds.insert(id);
993   }
994   for (size_t i = 0; i < req.num_feeds(); ++i) {
995     const TensorId id(ParseTensorName(req.feed_name(i)));
996     pending_feeds.erase(id);
997   }
998 
999   // Initialize the stack with the fetch nodes.
1000   std::vector<const Node*> stack;
1001   for (size_t i = 0; i < req.num_fetches(); ++i) {
1002     const string& fetch = req.fetch_name(i);
1003     const TensorId id(ParseTensorName(fetch));
1004     const Node* n = execution_state->get_node_by_name(string(id.first));
1005     if (n == nullptr) {
1006       return errors::NotFound("Fetch ", fetch, ": not found");
1007     }
1008     stack.push_back(n);
1009   }
1010 
1011   // Any tensor needed for fetches can't be in pending_feeds.
1012   // We need to use the original full graph from execution state.
1013   const Graph* graph = execution_state->full_graph();
1014   std::vector<bool> visited(graph->num_node_ids(), false);
1015   while (!stack.empty()) {
1016     const Node* n = stack.back();
1017     stack.pop_back();
1018 
1019     for (const Edge* in_edge : n->in_edges()) {
1020       const Node* in_node = in_edge->src();
1021       if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) {
1022         return errors::InvalidArgument("Fetch ", in_node->name(), ":",
1023                                        in_edge->src_output(),
1024                                        " can't be computed from the feeds"
1025                                        " that have been fed so far.");
1026       }
1027       if (!visited[in_node->id()]) {
1028         visited[in_node->id()] = true;
1029         stack.push_back(in_node);
1030       }
1031     }
1032   }
1033   return Status::OK();
1034 }
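// Annotation (not in the original source): a worked example of the check
// above. Suppose a partial run was set up with feeds {"a:0", "b:0"} and fetch
// {"c:0"} where c = MatMul(a, b). A RunStep that has fed only "a:0" but asks
// for "c:0" fails CheckFetches, because the backward traversal from c reaches
// b while "b:0" is still in pending_feeds. Once "b:0" has been fed (in this or
// an earlier step of the partial run), the fetch is allowed.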
1035 
1036 // Asynchronously deregisters subgraphs on the workers, without waiting for the
1037 // result.
1038 void MasterSession::ReffedClientGraph::DeregisterPartitions() {
1039   struct Call {
1040     DeregisterGraphRequest req;
1041     DeregisterGraphResponse resp;
1042   };
1043   for (Part& part : partitions_) {
1044     // The graph handle may be empty if we failed during partition registration.
1045     if (!part.graph_handle.empty()) {
1046       Call* c = new Call;
1047       c->req.set_session_handle(session_handle_);
1048       c->req.set_create_worker_session_called(!should_deregister_);
1049       c->req.set_graph_handle(part.graph_handle);
1050       // NOTE(mrry): We must capture `worker_cache_` since `this`
1051       // could be deleted before the callback is called.
1052       WorkerCacheInterface* worker_cache = worker_cache_;
1053       const string name = part.name;
1054       WorkerInterface* w = part.worker;
1055       CHECK_NOTNULL(w);
1056       auto cb = [worker_cache, c, name, w](const Status& s) {
1057         if (!s.ok()) {
1058           // This error is potentially benign, so we don't log at the
1059           // error level.
1060           LOG(INFO) << "DeregisterGraph error: " << s;
1061         }
1062         delete c;
1063         worker_cache->ReleaseWorker(name, w);
1064       };
1065       w->DeregisterGraphAsync(&c->req, &c->resp, cb);
1066     }
1067   }
1068 }
1069 
1070 namespace {
1071 void CopyAndSortStrings(size_t size,
1072                         const std::function<string(size_t)>& input_accessor,
1073                         protobuf::RepeatedPtrField<string>* output) {
1074   std::vector<string> temp;
1075   temp.reserve(size);
1076   for (size_t i = 0; i < size; ++i) {
1077     output->Add(input_accessor(i));
1078   }
1079   std::sort(output->begin(), output->end());
1080 }
1081 }  // namespace
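// Annotation (not in the original source): CopyAndSortStrings copies the
// `size` strings produced by `input_accessor` into `output` and sorts them,
// giving the callable options a canonical ordering. A hypothetical usage:
//
//   const std::vector<string> names = {"y:0", "x:0"};
//   CopyAndSortStrings(names.size(),
//                      [&names](size_t i) { return names[i]; },
//                      callable_opts->mutable_feed());
//   // callable_opts->feed() now contains ["x:0", "y:0"].
//
// The canonical ordering is what lets HashBuildGraphOptions (below) produce a
// stable cache key regardless of the order in which the client listed feeds,
// fetches, and targets.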
1082 
1083 void BuildBuildGraphOptions(const RunStepRequestWrapper& req,
1084                             const ConfigProto& config,
1085                             BuildGraphOptions* opts) {
1086   CallableOptions* callable_opts = &opts->callable_options;
1087   CopyAndSortStrings(
1088       req.num_feeds(), [&req](size_t i) { return req.feed_name(i); },
1089       callable_opts->mutable_feed());
1090   CopyAndSortStrings(
1091       req.num_fetches(), [&req](size_t i) { return req.fetch_name(i); },
1092       callable_opts->mutable_fetch());
1093   CopyAndSortStrings(
1094       req.num_targets(), [&req](size_t i) { return req.target_name(i); },
1095       callable_opts->mutable_target());
1096 
1097   if (!req.options().debug_options().debug_tensor_watch_opts().empty()) {
1098     *callable_opts->mutable_run_options()->mutable_debug_options() =
1099         req.options().debug_options();
1100   }
1101 
1102   opts->collective_graph_key =
1103       req.options().experimental().collective_graph_key();
1104   if (config.experimental().collective_deterministic_sequential_execution()) {
1105     opts->collective_order = GraphCollectiveOrder::kEdges;
1106   } else if (config.experimental().collective_nccl()) {
1107     opts->collective_order = GraphCollectiveOrder::kAttrs;
1108   }
1109 }
1110 
1111 void BuildBuildGraphOptions(const PartialRunSetupRequest& req,
1112                             BuildGraphOptions* opts) {
1113   CallableOptions* callable_opts = &opts->callable_options;
1114   CopyAndSortStrings(
1115       req.feed_size(), [&req](size_t i) { return req.feed(i); },
1116       callable_opts->mutable_feed());
1117   CopyAndSortStrings(
1118       req.fetch_size(), [&req](size_t i) { return req.fetch(i); },
1119       callable_opts->mutable_fetch());
1120   CopyAndSortStrings(
1121       req.target_size(), [&req](size_t i) { return req.target(i); },
1122       callable_opts->mutable_target());
1123 
1124   // TODO(cais): Add TFDBG support to partial runs.
1125 }
1126 
1127 uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) {
1128   uint64 h = 0x2b992ddfa23249d6ull;
1129   for (const string& name : opts.callable_options.feed()) {
1130     h = Hash64(name.c_str(), name.size(), h);
1131   }
1132   for (const string& name : opts.callable_options.target()) {
1133     h = Hash64(name.c_str(), name.size(), h);
1134   }
1135   for (const string& name : opts.callable_options.fetch()) {
1136     h = Hash64(name.c_str(), name.size(), h);
1137   }
1138 
1139   const DebugOptions& debug_options =
1140       opts.callable_options.run_options().debug_options();
1141   if (!debug_options.debug_tensor_watch_opts().empty()) {
1142     const string watch_summary =
1143         SummarizeDebugTensorWatches(debug_options.debug_tensor_watch_opts());
1144     h = Hash64(watch_summary.c_str(), watch_summary.size(), h);
1145   }
1146 
1147   return h;
1148 }
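// Annotation (not in the original source): because the feed, target, and fetch
// lists were sorted by BuildBuildGraphOptions, two requests naming the same
// tensors in a different order hash to the same value and can therefore share
// a cached ReffedClientGraph. A minimal sketch of the chained Hash64 above,
// with hypothetical inputs:
//
//   uint64 h = 0x2b992ddfa23249d6ull;
//   for (const char* name : {"x:0", "y:0"}) {
//     h = Hash64(name, strlen(name), h);  // order-sensitive chaining
//   }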
1149 
1150 string BuildGraphOptionsString(const BuildGraphOptions& opts) {
1151   string buf;
1152   for (const string& name : opts.callable_options.feed()) {
1153     strings::StrAppend(&buf, " FdE: ", name);
1154   }
1155   strings::StrAppend(&buf, "\n");
1156   for (const string& name : opts.callable_options.target()) {
1157     strings::StrAppend(&buf, " TN: ", name);
1158   }
1159   strings::StrAppend(&buf, "\n");
1160   for (const string& name : opts.callable_options.fetch()) {
1161     strings::StrAppend(&buf, " FeE: ", name);
1162   }
1163   if (opts.collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey) {
1164     strings::StrAppend(&buf, "\nGK: ", opts.collective_graph_key);
1165   }
1166   strings::StrAppend(&buf, "\n");
1167   return buf;
1168 }
1169 
1170 MasterSession::MasterSession(
1171     const SessionOptions& opt, const MasterEnv* env,
1172     std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
1173     std::unique_ptr<WorkerCacheInterface> worker_cache,
1174     std::unique_ptr<DeviceSet> device_set,
1175     std::vector<string> filtered_worker_list,
1176     StatsPublisherFactory stats_publisher_factory)
1177     : session_opts_(opt),
1178       env_(env),
1179       handle_(strings::FpToString(random::New64())),
1180       remote_devs_(std::move(remote_devs)),
1181       worker_cache_(std::move(worker_cache)),
1182       devices_(std::move(device_set)),
1183       filtered_worker_list_(std::move(filtered_worker_list)),
1184       stats_publisher_factory_(std::move(stats_publisher_factory)),
1185       graph_version_(0),
1186       run_graphs_(5),
1187       partial_run_graphs_(5) {
1188   UpdateLastAccessTime();
1189   CHECK(devices_) << "device_set was null!";
1190 
1191   VLOG(1) << "Session " << handle_ << " #local " << env->local_devices.size()
1192           << " #remote " << remote_devs_->size();
1193 
1194   LOG(INFO) << "Start master session " << handle_
1195             << " with config: " << session_opts_.config.ShortDebugString();
1196 }
1197 
1198 MasterSession::~MasterSession() {
1199   for (const auto& iter : run_graphs_) iter.second->Unref();
1200   for (const auto& iter : partial_run_graphs_) iter.second->Unref();
1201 }
1202 
1203 void MasterSession::UpdateLastAccessTime() {
1204   last_access_time_usec_.store(Env::Default()->NowMicros());
1205 }
1206 
1207 Status MasterSession::Create(GraphDef* graph_def,
1208                              const WorkerCacheFactoryOptions& options) {
1209   if (session_opts_.config.use_per_session_threads() ||
1210       session_opts_.config.session_inter_op_thread_pool_size() > 0) {
1211     return errors::InvalidArgument(
1212         "Distributed session does not support session thread pool options.");
1213   }
1214   if (session_opts_.config.graph_options().place_pruned_graph()) {
1215     // TODO(b/29900832): Fix this or remove the option.
1216     LOG(WARNING) << "Distributed session does not support the "
1217                     "place_pruned_graph option.";
1218     session_opts_.config.mutable_graph_options()->set_place_pruned_graph(false);
1219   }
1220 
1221   GraphExecutionStateOptions execution_options;
1222   execution_options.device_set = devices_.get();
1223   execution_options.session_options = &session_opts_;
1224   {
1225     mutex_lock l(mu_);
1226     TF_RETURN_IF_ERROR(GraphExecutionState::MakeForBaseGraph(
1227         graph_def, execution_options, &execution_state_));
1228   }
1229   should_delete_worker_sessions_ = true;
1230   return CreateWorkerSessions(options);
1231 }
1232 
1233 Status MasterSession::CreateWorkerSessions(
1234     const WorkerCacheFactoryOptions& options) {
1235   const std::vector<string> worker_names = filtered_worker_list_;
1236   WorkerCacheInterface* worker_cache = get_worker_cache();
1237 
1238   struct WorkerGroup {
1239     // The worker name. (Not owned.)
1240     const string* name;
1241 
1242     // The worker referenced by name. (Not owned.)
1243     WorkerInterface* worker = nullptr;
1244 
1245     // Request and responses used for a given worker.
    CreateWorkerSessionRequest request;
    CreateWorkerSessionResponse response;
    Status status = Status::OK();
  };
  BlockingCounter done(worker_names.size());
  std::vector<WorkerGroup> workers(worker_names.size());

  // Release the workers.
  auto cleanup = gtl::MakeCleanup([&workers, worker_cache] {
    for (auto&& worker_group : workers) {
      if (worker_group.worker != nullptr) {
        worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker);
      }
    }
  });

  Status status = Status::OK();
  // Create all the workers & kick off the computations.
  for (size_t i = 0; i < worker_names.size(); ++i) {
    workers[i].name = &worker_names[i];
    workers[i].worker = worker_cache->CreateWorker(worker_names[i]);
    workers[i].request.set_session_handle(handle_);

    DeviceNameUtils::ParsedName name;
    if (!DeviceNameUtils::ParseFullName(worker_names[i], &name)) {
      status = errors::Internal("Could not parse name ", worker_names[i]);
      LOG(WARNING) << status;
      return status;
    }
    if (!name.has_job || !name.has_task) {
      status = errors::Internal("Incomplete worker name ", worker_names[i]);
      LOG(WARNING) << status;
      return status;
    }

    if (options.cluster_def) {
      *workers[i].request.mutable_server_def()->mutable_cluster() =
          *options.cluster_def;
      workers[i].request.mutable_server_def()->set_protocol(*options.protocol);
      workers[i].request.mutable_server_def()->set_job_name(name.job);
      workers[i].request.mutable_server_def()->set_task_index(name.task);
      // Session state is always isolated when ClusterSpec propagation
      // is in use.
      workers[i].request.set_isolate_session_state(true);
    } else {
      // NOTE(mrry): Do not set any component of the ServerDef,
      // because the worker will use its local configuration.
      workers[i].request.set_isolate_session_state(
          session_opts_.config.isolate_session_state());
    }
  }

  for (size_t i = 0; i < worker_names.size(); ++i) {
    auto cb = [i, &workers, &done](const Status& s) {
      workers[i].status = s;
      done.DecrementCount();
    };
    workers[i].worker->CreateWorkerSessionAsync(&workers[i].request,
                                                &workers[i].response, cb);
  }

  done.Wait();
  for (size_t i = 0; i < workers.size(); ++i) {
    status.Update(workers[i].status);
  }
  return status;
}

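// Sends a DeleteWorkerSessionRequest to every worker in parallel. Each call
// carries a timeout so that a worker that has gone away cannot block the
// session-close path indefinitely.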
Status MasterSession::DeleteWorkerSessions() {
  WorkerCacheInterface* worker_cache = get_worker_cache();
  const std::vector<string>& worker_names = filtered_worker_list_;

  struct WorkerGroup {
    // The worker name. (Not owned.)
    const string* name;

    // The worker referenced by name. (Not owned.)
    WorkerInterface* worker = nullptr;

    CallOptions call_opts;

    // Request and response used for a given worker.
    DeleteWorkerSessionRequest request;
    DeleteWorkerSessionResponse response;
    Status status = Status::OK();
  };
  BlockingCounter done(worker_names.size());
  std::vector<WorkerGroup> workers(worker_names.size());

  // Release the workers.
  auto cleanup = gtl::MakeCleanup([&workers, worker_cache] {
    for (auto&& worker_group : workers) {
      if (worker_group.worker != nullptr) {
        worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker);
      }
    }
  });

  Status status = Status::OK();
  // Create all the workers & set up the session-deletion requests.
  for (size_t i = 0; i < worker_names.size(); ++i) {
    workers[i].name = &worker_names[i];
    workers[i].worker = worker_cache->CreateWorker(worker_names[i]);
    workers[i].request.set_session_handle(handle_);
    // Since the worker may have gone away, set a timeout to avoid blocking the
    // session-close operation.
    workers[i].call_opts.SetTimeout(10000);
  }

  for (size_t i = 0; i < worker_names.size(); ++i) {
    auto cb = [i, &workers, &done](const Status& s) {
      workers[i].status = s;
      done.DecrementCount();
    };
    workers[i].worker->DeleteWorkerSessionAsync(
        &workers[i].call_opts, &workers[i].request, &workers[i].response, cb);
  }

  done.Wait();
  for (size_t i = 0; i < workers.size(); ++i) {
    status.Update(workers[i].status);
  }
  return status;
}

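// Reports the devices visible to this session. For a ClusterSpec-propagated
// session only the client device is reported as local and all other devices
// as remote; otherwise the master's own local and remote device lists are
// used.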
Status MasterSession::ListDevices(ListDevicesResponse* resp) const {
  if (worker_cache_) {
    // This is a ClusterSpec-propagated session, and thus env_->local_devices
    // are invalid.

    // Mark the "client_device" as the sole local device.
    const Device* client_device = devices_->client_device();
    for (const Device* dev : devices_->devices()) {
      if (dev != client_device) {
        *(resp->add_remote_device()) = dev->attributes();
      }
    }
    *(resp->add_local_device()) = client_device->attributes();
  } else {
    for (Device* dev : env_->local_devices) {
      *(resp->add_local_device()) = dev->attributes();
    }
    for (auto&& dev : *remote_devs_) {
      *(resp->add_local_device()) = dev->attributes();
    }
  }
  return Status::OK();
}

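// Extends the session's graph with the nodes in `req->graph_def()`. The
// caller must supply the graph version it last observed; on a mismatch the
// call is aborted, otherwise the execution state is swapped under the lock
// and the new version is returned in `resp`.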
Status MasterSession::Extend(const ExtendSessionRequest* req,
                             ExtendSessionResponse* resp) {
  UpdateLastAccessTime();
  std::unique_ptr<GraphExecutionState> extended_execution_state;
  {
    mutex_lock l(mu_);
    if (closed_) {
      return errors::FailedPrecondition("Session is closed.");
    }

    if (graph_version_ != req->current_graph_version()) {
      return errors::Aborted("Current version is ", graph_version_,
                             " but caller expected ",
                             req->current_graph_version(), ".");
    }

    CHECK(execution_state_);
    TF_RETURN_IF_ERROR(
        execution_state_->Extend(req->graph_def(), &extended_execution_state));

    CHECK(extended_execution_state);
    // The old execution state will be released outside the lock.
    execution_state_.swap(extended_execution_state);
    ++graph_version_;
    resp->set_new_graph_version(graph_version_);
  }
  return Status::OK();
}

WorkerCacheInterface* MasterSession::get_worker_cache() const {
  if (worker_cache_) {
    return worker_cache_.get();
  }
  return env_->worker_cache;
}

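// Looks up (or builds and caches) the ReffedClientGraph corresponding to the
// given BuildGraphOptions. Run graphs and partial-run graphs are kept in
// separate tables, keyed by a hash of the options. The returned graph is
// Ref()'d for the caller, and its execution count is returned in *out_count.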
Status MasterSession::StartStep(const BuildGraphOptions& opts, bool is_partial,
                                ReffedClientGraph** out_rcg, int64* out_count) {
  const uint64 hash = HashBuildGraphOptions(opts);
  {
    mutex_lock l(mu_);
    // TODO(suharshs): We cache partial run graphs and run graphs separately
    // because there is preprocessing that needs to only be run for partial
    // run calls.
    RCGMap* m = is_partial ? &partial_run_graphs_ : &run_graphs_;
    auto iter = m->find(hash);
    if (iter == m->end()) {
      // We have not seen this subgraph before. Build the subgraph and
      // cache it.
      VLOG(1) << "Unseen hash " << hash << " for "
              << BuildGraphOptionsString(opts) << " is_partial = " << is_partial
              << "\n";
      std::unique_ptr<ClientGraph> client_graph;
      TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
      WorkerCacheInterface* worker_cache = get_worker_cache();
      auto entry = new ReffedClientGraph(
          handle_, opts, std::move(client_graph), session_opts_,
          stats_publisher_factory_, is_partial, worker_cache,
          !should_delete_worker_sessions_);
      iter = m->insert({hash, entry}).first;
      VLOG(1) << "Preparing to execute new graph";
    }
    *out_rcg = iter->second;
    (*out_rcg)->Ref();
    *out_count = (*out_rcg)->get_and_increment_execution_count();
  }
  return Status::OK();
}

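// Empties `rcg_map`. If `to_unref` is non-null, the cached graphs are handed
// back to the caller so they can be Unref()'d outside the session lock;
// otherwise they are Unref()'d immediately.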
void MasterSession::ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
                                   RCGMap* rcg_map) {
  VLOG(1) << "Discarding all reffed graphs";
  for (auto p : *rcg_map) {
    ReffedClientGraph* rcg = p.second;
    if (to_unref) {
      to_unref->push_back(rcg);
    } else {
      rcg->Unref();
    }
  }
  rcg_map->clear();
}

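// Produces the step id for the next step. Steps that do not use collective
// ops draw a random id masked to the low 57 bits; steps that do use
// collectives take the next id from the collective executor manager's
// per-graph-key sequence, refreshing the sequence (and backing off on
// errors) until a valid id is obtained.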
uint64 MasterSession::NewStepId(int64 graph_key) {
  if (graph_key == BuildGraphOptions::kNoCollectiveGraphKey) {
    // StepId must leave the most-significant 7 bits empty for future use.
    return random::New64() & (((1uLL << 56) - 1) | (1uLL << 56));
  } else {
    uint64 step_id = env_->collective_executor_mgr->NextStepId(graph_key);
    int32 retry_count = 0;
    while (step_id == CollectiveExecutor::kInvalidId) {
      Notification note;
      Status status;
      env_->collective_executor_mgr->RefreshStepIdSequenceAsync(
          graph_key, [&status, &note](const Status& s) {
            status = s;
            note.Notify();
          });
      note.WaitForNotification();
      if (!status.ok()) {
        LOG(ERROR) << "Bad status from "
                      "collective_executor_mgr->RefreshStepIdSequence: "
                   << status << ".  Retrying.";
        int64 delay_micros = std::min(60000000LL, 1000000LL * ++retry_count);
        Env::Default()->SleepForMicroseconds(delay_micros);
      } else {
        step_id = env_->collective_executor_mgr->NextStepId(graph_key);
      }
    }
    return step_id;
  }
}

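// Sets up state for a sequence of partial runs: records the declared feeds,
// fetches and targets in a new RunState, registers the partitions for the
// corresponding graph, and returns a fresh partial-run handle to the client.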
Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req,
                                      PartialRunSetupResponse* resp) {
  std::vector<string> inputs, outputs, targets;
  for (const auto& feed : req->feed()) {
    inputs.push_back(feed);
  }
  for (const auto& fetch : req->fetch()) {
    outputs.push_back(fetch);
  }
  for (const auto& target : req->target()) {
    targets.push_back(target);
  }

  string handle = std::to_string(partial_run_handle_counter_.fetch_add(1));

  ReffedClientGraph* rcg = nullptr;

  // Prepare.
  BuildGraphOptions opts;
  BuildBuildGraphOptions(*req, &opts);
  int64 count = 0;
  TF_RETURN_IF_ERROR(StartStep(opts, true, &rcg, &count));

  rcg->Ref();
  RunState* run_state =
      new RunState(inputs, outputs, rcg,
                   NewStepId(BuildGraphOptions::kNoCollectiveGraphKey), count);
  {
    mutex_lock l(mu_);
    partial_runs_.emplace(
        std::make_pair(handle, std::unique_ptr<RunState>(run_state)));
  }

  TF_RETURN_IF_ERROR(BuildAndRegisterPartitions(rcg));

  resp->set_partial_run_handle(handle);
  return Status::OK();
}

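// Entry point for a single RunStep call. Dispatches to DoPartialRun() when
// the request carries a partial-run handle, and to DoRunWithLocalExecution()
// otherwise; both paths are responsible for calling MarkRunCompletion().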
Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req,
                          MutableRunStepResponseWrapper* resp) {
  UpdateLastAccessTime();
  {
    mutex_lock l(mu_);
    if (closed_) {
      return errors::FailedPrecondition("Session is closed.");
    }
    ++num_running_;
    // Note: all code paths must eventually call MarkRunCompletion()
    // in order to appropriately decrement the num_running_ counter.
  }
  Status status;
  if (!req.partial_run_handle().empty()) {
    status = DoPartialRun(opts, req, resp);
  } else {
    status = DoRunWithLocalExecution(opts, req, resp);
  }
  return status;
}

// Decrements num_running_ and broadcasts if num_running_ is zero.
void MasterSession::MarkRunCompletion() {
  mutex_lock l(mu_);
  --num_running_;
  if (num_running_ == 0) {
    num_running_is_zero_.notify_all();
  }
}

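// Configures the graph partitioner from the session's GraphOptions (e.g.
// bfloat16 casts on Send/Recv edges, recv scheduling) and asks the
// ReffedClientGraph to partition the client graph by worker and register the
// resulting subgraphs on the corresponding workers.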
Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
  // Register subgraphs if we haven't done so already.
  PartitionOptions popts;
  popts.node_to_loc = SplitByWorker;
  // The closures popts.{new_name,get_incarnation} are called synchronously in
  // RegisterPartitions() below, so do not need a Ref()/Unref() pair to keep
  // "this" alive during the closure.
  popts.new_name = [this](const string& prefix) {
    mutex_lock l(mu_);
    return strings::StrCat(prefix, "_S", next_node_id_++);
  };
  popts.get_incarnation = [this](const string& name) -> int64 {
    Device* d = devices_->FindDeviceByName(name);
    if (d == nullptr) {
      return PartitionOptions::kIllegalIncarnation;
    } else {
      return d->attributes().incarnation();
    }
  };
  popts.control_flow_added = false;
  const bool enable_bfloat16_sendrecv =
      session_opts_.config.graph_options().enable_bfloat16_sendrecv();
  popts.should_cast = [enable_bfloat16_sendrecv](const Edge* e) {
    if (e->IsControlEdge()) {
      return DT_FLOAT;
    }
    DataType dtype = BaseType(e->src()->output_type(e->src_output()));
    if (enable_bfloat16_sendrecv && dtype == DT_FLOAT) {
      return DT_BFLOAT16;
    } else {
      return dtype;
    }
  };
  if (session_opts_.config.graph_options().enable_recv_scheduling()) {
    popts.scheduling_for_recvs = true;
    popts.need_to_record_start_times = true;
  }

  TF_RETURN_IF_ERROR(rcg->RegisterPartitions(std::move(popts)));

  return Status::OK();
}

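// Executes one step of a partial run. The first call for a handle initializes
// the per-step state; every call validates that its feeds and fetches were
// declared in PartialRunSetup() and have not been used before, runs the
// registered partitions, and tears the RunState down once all pending feeds
// and fetches have been satisfied (or an error occurs).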
Status MasterSession::DoPartialRun(CallOptions* opts,
                                   const RunStepRequestWrapper& req,
                                   MutableRunStepResponseWrapper* resp) {
  auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
  const string& prun_handle = req.partial_run_handle();
  RunState* run_state = nullptr;
  {
    mutex_lock l(mu_);
    auto it = partial_runs_.find(prun_handle);
    if (it == partial_runs_.end()) {
      return errors::InvalidArgument(
          "Must run PartialRunSetup before performing partial runs");
    }
    run_state = it->second.get();
  }
  // CollectiveOps are not supported in partial runs.
  if (req.options().experimental().collective_graph_key() !=
      BuildGraphOptions::kNoCollectiveGraphKey) {
    return errors::InvalidArgument(
        "PartialRun does not support Collective ops.  collective_graph_key "
        "must be kNoCollectiveGraphKey.");
  }

  // If this is the first partial run, initialize the PerStepState.
  if (!run_state->step_started) {
    run_state->step_started = true;
    PerStepState pss;

    const auto count = run_state->count;
    pss.collect_timeline =
        req.options().trace_level() == RunOptions::FULL_TRACE;
    pss.collect_rpcs = req.options().trace_level() == RunOptions::FULL_TRACE;
    pss.report_tensor_allocations_upon_oom =
        req.options().report_tensor_allocations_upon_oom();

    // Build the cost model every 'build_cost_model_every' steps after skipping
    // an initial 'build_cost_model_after' steps.
    const int64 build_cost_model_after =
        session_opts_.config.graph_options().build_cost_model_after();
    const int64 build_cost_model_every =
        session_opts_.config.graph_options().build_cost_model();
    pss.collect_costs =
        build_cost_model_every > 0 &&
        ((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
    pss.collect_partition_graphs = req.options().output_partition_graphs();

    std::unique_ptr<ProfileHandler> ph = run_state->rcg->GetProfileHandler(
        run_state->step_id, count, req.options());
    if (ph) {
      pss.collect_timeline = true;
      pss.collect_rpcs = ph->should_collect_rpcs();
    }

    run_state->pss = std::move(pss);
    run_state->ph = std::move(ph);
  }

  // Make sure that this is a new set of feeds that are still pending.
  for (size_t i = 0; i < req.num_feeds(); ++i) {
    const string& feed = req.feed_name(i);
    auto it = run_state->pending_inputs.find(feed);
    if (it == run_state->pending_inputs.end()) {
      return errors::InvalidArgument(
          "The feed ", feed, " was not specified in partial_run_setup.");
    } else if (it->second) {
      return errors::InvalidArgument("The feed ", feed,
                                     " has already been fed.");
    }
  }
  // Check that this is a new set of fetches that are still pending.
  for (size_t i = 0; i < req.num_fetches(); ++i) {
    const string& fetch = req.fetch_name(i);
    auto it = run_state->pending_outputs.find(fetch);
    if (it == run_state->pending_outputs.end()) {
      return errors::InvalidArgument(
          "The fetch ", fetch, " was not specified in partial_run_setup.");
    } else if (it->second) {
      return errors::InvalidArgument("The fetch ", fetch,
                                     " has already been fetched.");
    }
  }

  // Ensure that the requested fetches can be computed from the provided feeds.
  {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(
        run_state->rcg->CheckFetches(req, run_state, execution_state_.get()));
  }

  // Determine if this partial run satisfies all the pending inputs and outputs.
  for (size_t i = 0; i < req.num_feeds(); ++i) {
    auto it = run_state->pending_inputs.find(req.feed_name(i));
    it->second = true;
  }
  for (size_t i = 0; i < req.num_fetches(); ++i) {
    auto it = run_state->pending_outputs.find(req.fetch_name(i));
    it->second = true;
  }
  bool is_last_partial_run = run_state->PendingDone();

  Status s = run_state->rcg->RunPartitions(
      env_, run_state->step_id, run_state->count, &run_state->pss, opts, req,
      resp, &cancellation_manager_, is_last_partial_run);

  // Delete the run state if there is an error or all fetches are done.
  if (!s.ok() || is_last_partial_run) {
    ReffedClientGraph* rcg = run_state->rcg;
    run_state->pss.end_micros = Env::Default()->NowMicros();
    // Schedule post-processing and cleanup to be done asynchronously.
    Ref();
    rcg->Ref();
    rcg->ProcessStats(run_state->step_id, &run_state->pss, run_state->ph.get(),
                      req.options(), resp->mutable_metadata());
    cleanup.release();  // MarkRunCompletion called in done closure.
    rcg->CleanupPartitionsAsync(
        run_state->step_id, [this, rcg, prun_handle](const Status& s) {
          if (!s.ok()) {
            LOG(ERROR) << "Cleanup partition error: " << s;
          }
          rcg->Unref();
          MarkRunCompletion();
          Unref();
        });
    mutex_lock l(mu_);
    partial_runs_.erase(prun_handle);
  }
  return s;
}

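// Creates a DebuggerState from the request's DebugOptions and publishes the
// debug metadata (global step, feed/fetch/target names) for this run. The
// ReffedClientGraph execution count is used as the run count; see the TODO in
// the body.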
Status MasterSession::CreateDebuggerState(
    const DebugOptions& debug_options, const RunStepRequestWrapper& req,
    int64 rcg_execution_count,
    std::unique_ptr<DebuggerStateInterface>* debugger_state) {
  TF_RETURN_IF_ERROR(
      DebuggerStateRegistry::CreateState(debug_options, debugger_state));

  std::vector<string> input_names;
  for (size_t i = 0; i < req.num_feeds(); ++i) {
    input_names.push_back(req.feed_name(i));
  }
  std::vector<string> output_names;
  for (size_t i = 0; i < req.num_fetches(); ++i) {
    output_names.push_back(req.fetch_name(i));
  }
  std::vector<string> target_names;
  for (size_t i = 0; i < req.num_targets(); ++i) {
    target_names.push_back(req.target_name(i));
  }

  // TODO(cais): We currently use -1 as a dummy value for session run count.
  // While this counter value is straightforward to define and obtain for
  // DirectSessions, it is less so for non-direct Sessions. Devise a better
  // way to get its value when the need arises.
  TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
      debug_options.global_step(), rcg_execution_count, rcg_execution_count,
      input_names, output_names, target_names));

  return Status::OK();
}

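// Populates the PerStepState for a step from the RunOptions: whether to
// collect a timeline, RPC traces, cost models, and partition graphs. Also
// creates the step's ProfileHandler, if any, via the ReffedClientGraph.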
void MasterSession::FillPerStepState(MasterSession::ReffedClientGraph* rcg,
                                     const RunOptions& run_options,
                                     uint64 step_id, int64 count,
                                     PerStepState* out_pss,
                                     std::unique_ptr<ProfileHandler>* out_ph) {
  out_pss->collect_timeline =
      run_options.trace_level() == RunOptions::FULL_TRACE;
  out_pss->collect_rpcs = run_options.trace_level() == RunOptions::FULL_TRACE;
  out_pss->report_tensor_allocations_upon_oom =
      run_options.report_tensor_allocations_upon_oom();
  // Build the cost model every 'build_cost_model_every' steps after skipping
  // an initial 'build_cost_model_after' steps.
  const int64 build_cost_model_after =
      session_opts_.config.graph_options().build_cost_model_after();
  const int64 build_cost_model_every =
      session_opts_.config.graph_options().build_cost_model();
  out_pss->collect_costs =
      build_cost_model_every > 0 &&
      ((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
  out_pss->collect_partition_graphs = run_options.output_partition_graphs();

  *out_ph = rcg->GetProfileHandler(step_id, count, run_options);
  if (*out_ph) {
    out_pss->collect_timeline = true;
    out_pss->collect_rpcs = (*out_ph)->should_collect_rpcs();
  }
}

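// Post-processing that runs after every step: records step stats on success
// (retiring the collective step id if one was used), rewrites a generic
// cancellation error into a more descriptive one when the session was closed
// or garbage-collected, and asynchronously cleans up the per-step partitions.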
Status MasterSession::PostRunCleanup(MasterSession::ReffedClientGraph* rcg,
                                     uint64 step_id,
                                     const RunOptions& run_options,
                                     PerStepState* pss,
                                     const std::unique_ptr<ProfileHandler>& ph,
                                     const Status& run_status,
                                     RunMetadata* out_run_metadata) {
  Status s = run_status;
  if (s.ok()) {
    pss->end_micros = Env::Default()->NowMicros();
    if (rcg->collective_graph_key() !=
        BuildGraphOptions::kNoCollectiveGraphKey) {
      env_->collective_executor_mgr->RetireStepId(rcg->collective_graph_key(),
                                                  step_id);
    }
    // Schedule post-processing and cleanup to be done asynchronously.
    rcg->ProcessStats(step_id, pss, ph.get(), run_options, out_run_metadata);
  } else if (errors::IsCancelled(s)) {
    mutex_lock l(mu_);
    if (closed_) {
      if (garbage_collected_) {
        s = errors::Cancelled(
            "Step was cancelled because the session was garbage collected due "
            "to inactivity.");
      } else {
        s = errors::Cancelled(
            "Step was cancelled by an explicit call to `Session::Close()`.");
      }
    }
  }
  Ref();
  rcg->Ref();
  rcg->CleanupPartitionsAsync(step_id, [this, rcg](const Status& s) {
    if (!s.ok()) {
      LOG(ERROR) << "Cleanup partition error: " << s;
    }
    rcg->Unref();
    MarkRunCompletion();
    Unref();
  });
  return s;
}

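// Executes a full (non-partial) RunStep: builds or reuses the cached
// ReffedClientGraph for the request, optionally creates debugger state,
// registers the partitions, allocates a step id, runs the partitions on the
// workers, and defers stats processing and cleanup to PostRunCleanup().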
Status MasterSession::DoRunWithLocalExecution(
    CallOptions* opts, const RunStepRequestWrapper& req,
    MutableRunStepResponseWrapper* resp) {
  VLOG(2) << "DoRunWithLocalExecution req: " << req.DebugString();
  PerStepState pss;
  pss.start_micros = Env::Default()->NowMicros();
  auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });

  // Prepare.
  BuildGraphOptions bgopts;
  BuildBuildGraphOptions(req, session_opts_.config, &bgopts);
  ReffedClientGraph* rcg = nullptr;
  int64 count;
  TF_RETURN_IF_ERROR(StartStep(bgopts, false, &rcg, &count));

  // Unref "rcg" when out of scope.
  core::ScopedUnref unref(rcg);

  std::unique_ptr<DebuggerStateInterface> debugger_state;
  const DebugOptions& debug_options = req.options().debug_options();

  if (!debug_options.debug_tensor_watch_opts().empty()) {
    TF_RETURN_IF_ERROR(
        CreateDebuggerState(debug_options, req, count, &debugger_state));
  }
  TF_RETURN_IF_ERROR(BuildAndRegisterPartitions(rcg));

  // NewStepId() reserves the most-significant bits of the step_id for
  // future use.
  uint64 step_id = NewStepId(rcg->collective_graph_key());
  TRACEPRINTF("stepid %llu", step_id);

  std::unique_ptr<ProfileHandler> ph;
  FillPerStepState(rcg, req.options(), step_id, count, &pss, &ph);

  Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp,
                                &cancellation_manager_, false);

  cleanup.release();  // MarkRunCompletion called in PostRunCleanup().
  return PostRunCleanup(rcg, step_id, req.options(), &pss, ph, s,
                        resp->mutable_metadata());
}

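// Creates a callable: builds a ReffedClientGraph from the CallableOptions,
// registers its partitions on the workers, and stores it under a new handle
// that the client can pass to RunCallable()/ReleaseCallable().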
Status MasterSession::MakeCallable(const MakeCallableRequest& req,
                                   MakeCallableResponse* resp) {
  UpdateLastAccessTime();

  BuildGraphOptions opts;
  opts.callable_options = req.options();
  opts.use_function_convention = false;

  ReffedClientGraph* callable;

  {
    mutex_lock l(mu_);
    if (closed_) {
      return errors::FailedPrecondition("Session is closed.");
    }
    std::unique_ptr<ClientGraph> client_graph;
    TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
    callable = new ReffedClientGraph(handle_, opts, std::move(client_graph),
                                     session_opts_, stats_publisher_factory_,
                                     false /* is_partial */, get_worker_cache(),
                                     !should_delete_worker_sessions_);
  }

  Status s = BuildAndRegisterPartitions(callable);
  if (!s.ok()) {
    callable->Unref();
    return s;
  }

  uint64 handle;
  {
    mutex_lock l(mu_);
    handle = next_callable_handle_++;
    callables_[handle] = callable;
  }

  resp->set_handle(handle);
  return Status::OK();
}

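// Executes one step of a previously created callable: allocates a step id,
// applies the callable's stored RunOptions (including an optional timeout),
// runs the registered partitions, and defers stats processing and cleanup to
// PostRunCleanup().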
Status MasterSession::DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg,
                                    const RunCallableRequest& req,
                                    RunCallableResponse* resp) {
  VLOG(2) << "DoRunCallable req: " << req.DebugString();
  PerStepState pss;
  pss.start_micros = Env::Default()->NowMicros();
  auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });

  // Prepare.
  int64 count = rcg->get_and_increment_execution_count();

  const uint64 step_id = NewStepId(rcg->collective_graph_key());
  TRACEPRINTF("stepid %llu", step_id);

  const RunOptions& run_options = rcg->callable_options().run_options();

  if (run_options.timeout_in_ms() != 0) {
    opts->SetTimeout(run_options.timeout_in_ms());
  }

  std::unique_ptr<ProfileHandler> ph;
  FillPerStepState(rcg, run_options, step_id, count, &pss, &ph);
  Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp,
                                &cancellation_manager_);
  cleanup.release();  // MarkRunCompletion called in PostRunCleanup().
  return PostRunCleanup(rcg, step_id, run_options, &pss, ph, s,
                        resp->mutable_metadata());
}

Status MasterSession::RunCallable(CallOptions* opts,
                                  const RunCallableRequest& req,
                                  RunCallableResponse* resp) {
  UpdateLastAccessTime();
  ReffedClientGraph* callable;
  {
    mutex_lock l(mu_);
    if (closed_) {
      return errors::FailedPrecondition("Session is closed.");
    }
    int64 handle = req.handle();
    if (handle >= next_callable_handle_) {
      return errors::InvalidArgument("No such callable handle: ", handle);
    }
    auto iter = callables_.find(req.handle());
    if (iter == callables_.end()) {
      return errors::InvalidArgument(
          "Attempted to run callable after handle was released: ", handle);
    }
    callable = iter->second;
    callable->Ref();
    ++num_running_;
  }
  core::ScopedUnref unref_callable(callable);
  return DoRunCallable(opts, callable, req, resp);
}

Status MasterSession::ReleaseCallable(const ReleaseCallableRequest& req,
                                      ReleaseCallableResponse* resp) {
  UpdateLastAccessTime();
  ReffedClientGraph* to_unref = nullptr;
  {
    mutex_lock l(mu_);
    auto iter = callables_.find(req.handle());
    if (iter != callables_.end()) {
      to_unref = iter->second;
      callables_.erase(iter);
    }
  }
  if (to_unref != nullptr) {
    to_unref->Unref();
  }
  return Status::OK();
}

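// Closes the session: marks it closed so subsequent Run()/Extend() calls
// fail, cancels in-flight steps, waits for all running steps to finish,
// drops every cached graph and callable, and (if this session created them)
// deletes the per-worker sessions.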
Status MasterSession::Close() {
  {
    mutex_lock l(mu_);
    closed_ = true;  // All subsequent calls to Run() or Extend() will fail.
  }
  cancellation_manager_.StartCancel();
  std::vector<ReffedClientGraph*> to_unref;
  {
    mutex_lock l(mu_);
    while (num_running_ != 0) {
      num_running_is_zero_.wait(l);
    }
    ClearRunsTable(&to_unref, &run_graphs_);
    ClearRunsTable(&to_unref, &partial_run_graphs_);
    ClearRunsTable(&to_unref, &callables_);
  }
  for (ReffedClientGraph* rcg : to_unref) rcg->Unref();
  if (should_delete_worker_sessions_) {
    Status s = DeleteWorkerSessions();
    if (!s.ok()) {
      LOG(WARNING) << s;
    }
  }
  return Status::OK();
}

void MasterSession::GarbageCollect() {
  {
    mutex_lock l(mu_);
    closed_ = true;
    garbage_collected_ = true;
  }
  cancellation_manager_.StartCancel();
  Unref();
}

MasterSession::RunState::RunState(const std::vector<string>& input_names,
                                  const std::vector<string>& output_names,
                                  ReffedClientGraph* rcg, const uint64 step_id,
                                  const int64 count)
    : rcg(rcg), step_id(step_id), count(count) {
  // Initially all the feeds and fetches are pending.
  for (auto& name : input_names) {
    pending_inputs[name] = false;
  }
  for (auto& name : output_names) {
    pending_outputs[name] = false;
  }
}

MasterSession::RunState::~RunState() {
  if (rcg) rcg->Unref();
}

bool MasterSession::RunState::PendingDone() const {
  for (const auto& it : pending_inputs) {
    if (!it.second) return false;
  }
  for (const auto& it : pending_outputs) {
    if (!it.second) return false;
  }
  return true;
}

}  // end namespace tensorflow