• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/costs/virtual_scheduler.h"
17 
18 #include "tensorflow/core/framework/allocation_description.pb.h"
19 #include "tensorflow/core/framework/attr_value.pb.h"
20 #include "tensorflow/core/framework/node_def.pb.h"
21 #include "tensorflow/core/framework/tensor.pb.h"
22 #include "tensorflow/core/framework/tensor_description.pb.h"
23 #include "tensorflow/core/framework/tensor_shape.pb.h"
24 #include "tensorflow/core/grappler/clusters/utils.h"
25 #include "tensorflow/core/grappler/costs/utils.h"
26 #include "tensorflow/core/grappler/op_types.h"
27 #include "tensorflow/core/grappler/utils.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/lib/strings/numbers.h"
30 #include "tensorflow/core/lib/strings/str_util.h"
31 #include "tensorflow/core/lib/strings/stringprintf.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/core/util/device_name_utils.h"
34 
35 namespace tensorflow {
36 namespace grappler {
37 
38 namespace {
39 
40 // Key to the cached _Recv ops map, and its hash and predicate structures.
41 struct RecvNodeDescriptor {
42   const NodeDef* node;
43   const int port_num;
44   const string device;
45 
RecvNodeDescriptortensorflow::grappler::__anon1c764dd60111::RecvNodeDescriptor46   RecvNodeDescriptor(const NodeDef* node_, const int port_num_,
47                      const string& device_)
48       : node(node_), port_num(port_num_), device(device_) {}
49 };
50 
51 struct RecvNodeDescriptorHash {
operator ()tensorflow::grappler::__anon1c764dd60111::RecvNodeDescriptorHash52   std::size_t operator()(const RecvNodeDescriptor& recv_node) const {
53     return std::hash<const NodeDef*>()(recv_node.node) ^
54            std::hash<int>()(recv_node.port_num) ^
55            std::hash<string>()(recv_node.device);
56   }
57 };
58 
59 struct RecvNodeDescriptorEqual {
operator ()tensorflow::grappler::__anon1c764dd60111::RecvNodeDescriptorEqual60   bool operator()(const RecvNodeDescriptor& a,
61                   const RecvNodeDescriptor& b) const {
62     return a.node == b.node && a.port_num == b.port_num && a.device == b.device;
63   }
64 };
65 }  // namespace
66 
67 // ReadyNodeManager
GetCurrNode()68 const NodeDef* LIFOManager::GetCurrNode() {
69   CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node";
70   if (curr_pos_ == nodes_.end()) {
71     curr_pos_ = --(nodes_.rbegin().base());  // Last one in the list.
72   }
73   // Once curr_pos_ is set to a valid entry in the list, we keep using the
74   // cached curr_pos_ until RemoveCurrNode() is called. AddNode() will not
75   // change the GetCurrNode() return value.
76   return *curr_pos_;
77 }
78 
RemoveCurrNode()79 void LIFOManager::RemoveCurrNode() {
80   // Make sure we have curr_pos_ ready to be removed.
81   GetCurrNode();
82   // Note curr_pos_ may not be pointing the last element if some nodes are
83   // added.
84   nodes_.erase(curr_pos_);
85 
86   curr_pos_ = nodes_.end();  // Reset curr_pos_.
87 }
88 
FirstReadyManager()89 FirstReadyManager::FirstReadyManager() : ReadyNodeManager() {
90   std::make_heap(nodes_.begin(), nodes_.end());
91 }
92 
Init(const std::unordered_map<const NodeDef *,NodeState> * node_state)93 void FirstReadyManager::Init(
94     const std::unordered_map<const NodeDef*, NodeState>* node_state) {
95   // Reset the node state since different instances of the scheduler can reuse
96   // the same node_manager.
97   node_state_ = node_state;
98   nodes_.clear();
99   waiting_queue_.clear();
100   greater_ = [this](const NodeDef* a, const NodeDef* b) -> bool {
101     if (node_state_->at(a).time_ready == node_state_->at(b).time_ready) {
102       // Use Node name as tie-breaker for deterministic node scheduling.
103       return a->name().compare(b->name()) > 0;
104     } else {
105       // Note: we need a node with minimum time_ready, not
106       // maximum; hence, using a > b for comparison function.
107       return node_state_->at(a).time_ready > node_state_->at(b).time_ready;
108     }
109   };
110 }
111 
GetCurrNode()112 const NodeDef* FirstReadyManager::GetCurrNode() {
113   if (nodes_.empty()) {
114     // Nothing in the node_; probably, the very first call. Move
115     // waiting_queue_ to node_.
116     DrainWaitingQueue();
117     CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node";
118   }
119   return nodes_.front();
120 }
121 
RemoveCurrNode()122 void FirstReadyManager::RemoveCurrNode() {
123   if (nodes_.empty()) {
124     // Make sure that there is a node to be removed at the front of nodes_.
125     GetCurrNode();
126   }
127   std::pop_heap(nodes_.begin(), nodes_.end(), greater_);
128   nodes_.pop_back();
129   DrainWaitingQueue();
130 }
131 
Empty() const132 bool FirstReadyManager::Empty() const {
133   return nodes_.empty() && waiting_queue_.empty();
134 }
135 
DrainWaitingQueue()136 void FirstReadyManager::DrainWaitingQueue() {
137   for (const auto* node : waiting_queue_) {
138     // push_heap in AddNode() and pop_heap in RemoveCurrNode() guarantees that
139     // the first element is the node with minimum time_ready.
140     nodes_.push_back(node);
141     std::push_heap(nodes_.begin(), nodes_.end(), greater_);
142   }
143   waiting_queue_.clear();
144 }
145 
CompositeNodeManager()146 CompositeNodeManager::CompositeNodeManager()
147     : ReadyNodeManager(), send_manager_(), recv_manager_() {}
148 
Init(const std::unordered_map<const NodeDef *,NodeState> * node_state)149 void CompositeNodeManager::Init(
150     const std::unordered_map<const NodeDef*, NodeState>* node_state) {
151   node_state_ = node_state;
152   send_manager_.Init(node_state);
153   recv_manager_.Init(node_state);
154   curr_node_ = nullptr;
155 }
156 
AddNode(const NodeDef * node)157 void CompositeNodeManager::AddNode(const NodeDef* node) {
158   if (IsSend(*node)) {
159     send_manager_.AddNode(node);
160   } else if (IsRecv(*node)) {
161     recv_manager_.AddNode(node);
162   } else {
163     const auto& device = node_state_->at(node).device_name;
164     ops_lifo_map_[device].AddNode(node);
165   }
166 }
167 
GetCurrNode()168 const NodeDef* CompositeNodeManager::GetCurrNode() {
169   if (curr_node_) return curr_node_;
170 
171   // Per-device LIFO for normal ops (not _Send / _Recv),
172   // FirstReady for _Send and _Recv (separately),
173   // Globally (among the LIFO-selected ops from each device and _Send and
174   // _Recv) FirstReady,
175   // Priorty order: _Send, _Recv, and then the rest, if time_ready is equal.
176   std::vector<std::pair<const NodeDef*, Costs::Duration>> candidates;
177   for (auto& ops_lifo : ops_lifo_map_) {
178     if (!ops_lifo.second.Empty()) {
179       const auto* op = ops_lifo.second.GetCurrNode();
180       candidates.emplace_back(op, node_state_->at(op).time_ready);
181     }
182   }
183   if (!send_manager_.Empty()) {
184     const auto* send = send_manager_.GetCurrNode();
185     candidates.emplace_back(send, node_state_->at(send).time_ready);
186   }
187   if (!recv_manager_.Empty()) {
188     const auto* recv = recv_manager_.GetCurrNode();
189     candidates.emplace_back(recv, node_state_->at(recv).time_ready);
190   }
191   CHECK(!candidates.empty());
192   auto first_ready = std::min_element(
193       candidates.begin(), candidates.end(),
194       [](const std::pair<const NodeDef*, Costs::Duration>& a,
195          const std::pair<const NodeDef*, Costs::Duration>& b) {
196         if (a.second == b.second) {
197           // Note that there can be only 1 Send and only 1 Recv in candidates,
198           // at most; hence, score is 2 for Send, 1 for Recv, and 0 for a
199           // normap op, and a_score and b_score are equal only if both are
200           // normal ops.
201           int a_score = 2 * IsSend(*a.first) + IsRecv(*a.first);
202           int b_score = 2 * IsSend(*b.first) + IsRecv(*b.first);
203           if (a_score == b_score) {
204             // Both are normal ops; use node name as tie breaker.
205             return a.first->name().compare(b.first->name()) < 0;
206           } else {
207             // Priortize by op type: _Send, _Recv, and normap ops.
208             return a_score > b_score;
209           }
210         } else {
211           return a.second < b.second;
212         }
213       });
214   // Next time we call GetCurrNode(), it just returns the cached one,
215   // curr_node_ until we call RemovCurrNode().
216   curr_node_ = first_ready->first;
217 
218   return curr_node_;
219 }
220 
RemoveCurrNode()221 void CompositeNodeManager::RemoveCurrNode() {
222   const auto* node = GetCurrNode();
223   if (IsSend(*node)) {
224     send_manager_.RemoveCurrNode();
225   } else if (IsRecv(*node)) {
226     recv_manager_.RemoveCurrNode();
227   } else {
228     const auto device = node_state_->at(node).device_name;
229     ops_lifo_map_[device].RemoveCurrNode();
230   }
231   // Reset curr_node_ so that GetCurrNode() finds another node.
232   curr_node_ = nullptr;
233 }
234 
Empty() const235 bool CompositeNodeManager::Empty() const {
236   // Empty if all the ready managers are empty.
237   bool empty = true;
238   for (const auto& ops_lifo : ops_lifo_map_) {
239     empty &= ops_lifo.second.Empty();
240   }
241   return empty && send_manager_.Empty() && recv_manager_.Empty();
242 }
243 
ReadyNodeManagerFactory(const string & ready_node_manager)244 std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory(
245     const string& ready_node_manager) {
246   if (ready_node_manager == "FIFO") {
247     return absl::make_unique<FIFOManager>();
248   } else if (ready_node_manager == "LIFO") {
249     return absl::make_unique<LIFOManager>();
250   } else if (ready_node_manager == "FirstReady") {
251     return absl::make_unique<FirstReadyManager>();
252   } else if (ready_node_manager == "Composite") {
253     return absl::make_unique<CompositeNodeManager>();
254   }
255   LOG(FATAL) << "Not a valid ready node manager: " << ready_node_manager;
256   return nullptr;
257 }
258 
VirtualScheduler(const bool use_static_shapes,const bool use_aggressive_shape_inference,Cluster * cluster,ReadyNodeManager * ready_nodes)259 VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
260                                    const bool use_aggressive_shape_inference,
261                                    Cluster* cluster,
262                                    ReadyNodeManager* ready_nodes)
263     : ready_nodes_(ready_nodes),
264       graph_costs_(Costs::ZeroCosts()),
265       cluster_(cluster),
266       use_static_shapes_(use_static_shapes),
267       use_aggressive_shape_inference_(use_aggressive_shape_inference),
268       placer_(cluster) {
269   graph_costs_.num_ops_total = 0;
270   initialized_ = false;
271   track_mem_usage_snapshot_ = VLOG_IS_ON(1);
272 }
273 
Init(const GrapplerItem * item)274 Status VirtualScheduler::Init(const GrapplerItem* item) {
275   grappler_item_ = item;
276   graph_properties_ = absl::make_unique<GraphProperties>(*item);
277 
278   initialized_ = false;
279 
280   // Clear all internal states so that the VirtualScheduler is reusable for
281   // different GrapplerItems
282   node_map_.clear();
283   device_.clear();
284   additional_nodes_.clear();
285 
286   graph_costs_ = Costs::ZeroCosts();
287   graph_costs_.num_ops_total = 0;
288   op_to_cost_.clear();
289 
290   op_counts_.clear();
291   op_costs_.clear();
292 
293   // Init() preprocesses the input grappler_item and graph_properties to extract
294   // necessary information for emulating tensorflow op scheduling and
295   // construct internal data structures (NodeState and DeviceState) for virtual
296   // scheduling.
297   ready_nodes_->Init(GetNodeStates());
298 
299   // Construct graph properties.
300   if (use_static_shapes_) {
301     TF_RETURN_IF_ERROR(graph_properties_->InferStatically(
302         true, use_aggressive_shape_inference_));
303   } else {
304     TF_RETURN_IF_ERROR(graph_properties_->InferDynamically(cluster_));
305   }
306 
307   const auto& graph = grappler_item_->graph;
308   const auto& fetch_nodes = grappler_item_->fetch;
309   std::set<string> feed_nodes;
310   for (const auto& f : grappler_item_->feed) {
311     auto iter_and_inserted_flag = feed_nodes.insert(f.first);
312     QCHECK(iter_and_inserted_flag.second)
313         << "Duplicate feed node found: " << f.first;
314   }
315 
316   // Get the nodes that would run to output fetch_nodes.
317   bool ill_formed = false;
318   const std::vector<const NodeDef*> fetch_fanin_nodes =
319       ComputeTransitiveFanin(graph, fetch_nodes, &ill_formed);
320   if (ill_formed) {
321     return errors::InvalidArgument(
322         "Ill formed graph or invalid set of fetch nodes specified");
323   }
324 
325   // TODO(dyoon): this is a bit inefficient as name_to_node is already built in
326   // ComputeTransitiveFanin().
327   // Once ComputeTransitiveFanin is complete, only the nodes that can be reached
328   // from the fetch nodes are scheduled. So the scheduled nodes should be
329   // exactly the same as those executed for real. One possible discrepancy could
330   // be the control flow nodes, where tf only executes one path.
331   std::unordered_map<string, const NodeDef*> name_to_node;
332   for (const auto& node : fetch_fanin_nodes) {
333     name_to_node[node->name()] = node;
334   }
335 
336   // Traverses the graph to record _Send nodes.
337   // TODO(dyoon): Instead of identifying _Send node here manually, add _Send
338   // to _Recv as control dependency when creating GrapplerItem.
339   std::unordered_map<string, const NodeDef*> name_to_send;
340   for (const auto& node : graph.node()) {
341     if (IsSend(node)) {
342       const auto& attr = node.attr();
343       name_to_send[attr.at("tensor_name").s()] = &node;
344     }
345   }
346 
347   // To reuse _Recv ops.
348   std::unordered_map<RecvNodeDescriptor, const NodeDef*, RecvNodeDescriptorHash,
349                      RecvNodeDescriptorEqual>
350       cached_recv_nodes;
351 
352   // Build node_map; for each node, create its NodeState and connect its inputs
353   // and outputs.
354   for (const auto* curr_node : fetch_fanin_nodes) {
355     auto& curr_node_state = GetNodeStateOrCreateIt(curr_node);
356     const string curr_node_device = DeviceName(curr_node);
357     std::vector<string> inputs;
358     if (IsRecv(*curr_node)) {
359       const auto& attr = curr_node->attr();
360       if (attr.count("tensor_name")) {
361         const auto& send_node_name = attr.at("tensor_name").s();
362         auto it = name_to_send.find(send_node_name);
363         // If there is a _Send associated with the curr_node (_Recv), add it as
364         // input.
365         if (it != name_to_send.end()) {
366           const NodeDef* send = it->second;
367           inputs = {send->name()};
368         }
369       }
370     } else {
371       for (const string& input : curr_node->input()) {
372         inputs.push_back(input);
373       }
374     }
375     for (const string& input_node_name : inputs) {
376       // Note that input_node_name may be in <prefix><node_name>:<port_num>
377       // format, where <prefix> (e.g., "^" for control dependency) and
378       // ":<port_num>" may be omitted. NodeName() extracts only the node_name.
379       const NodeDef* input_node = name_to_node[NodeName(input_node_name)];
380 
381       CHECK(input_node);
382       const string in_device = DeviceName(input_node);
383       const auto input_node_port_num = NodePosition(input_node_name);
384 
385       if (curr_node_device == in_device) {
386         // Same device: connect input_node and curr_node directly.
387         curr_node_state.inputs.push_back(
388             std::make_pair(input_node, input_node_port_num));
389         auto& input_node_state = GetNodeStateOrCreateIt(input_node);
390         input_node_state.outputs[input_node_port_num].push_back(curr_node);
391       } else {
392         RecvNodeDescriptor recv_node(input_node, input_node_port_num,
393                                      curr_node_device);
394         auto it = cached_recv_nodes.find(recv_node);
395         if (it != cached_recv_nodes.end()) {
396           // Different device, but found an already-cached copy (a _Recv op);
397           // connect the _Recv to curr_node.
398           const NodeDef* recv_op = it->second;
399           // recv_op's output port is hard-coded to zero.
400           curr_node_state.inputs.push_back(std::make_pair(recv_op, 0));
401           auto& input_node_state = node_map_.at(recv_op);
402           input_node_state.outputs[0].push_back(curr_node);
403         } else {
404           // Different device, no cached copy; transfer input_node to the
405           // curr_node's device.
406           auto send_and_recv = CreateSendRecv(input_node, curr_node, input_node,
407                                               input_node_name);
408           // Note that CreateSendRecv() already connected input/output between
409           // _Send and _Recv ops.
410           const auto* send = send_and_recv.first;
411           const auto* recv = send_and_recv.second;
412           // recv_op's output port is hard-coded to zero.
413           curr_node_state.inputs.push_back(std::make_pair(recv, 0));
414           auto& input_node_state = GetNodeStateOrCreateIt(input_node);
415           input_node_state.outputs[input_node_port_num].push_back(send);
416 
417           // Cache the _Recv op for future use.
418           cached_recv_nodes[recv_node] = recv;
419         }
420       }
421     }
422 
423     // Special case: given feed nodes are ready at time 0.
424     const bool given_as_feed =
425         feed_nodes.find(curr_node->name()) != feed_nodes.end();
426 
427     // Default case: node without inputs are ready at time 0.
428     // Note that we check inputs vector which may be different to
429     // curr_node->input(); e.g., we add Send as input to Recv.
430     const bool has_no_inputs = inputs.empty();
431 
432     if (given_as_feed || has_no_inputs) {
433       curr_node_state.time_ready = Costs::Duration();
434       ready_nodes_->AddNode(curr_node);
435       VLOG(3) << "Added ready node: " << curr_node->name();
436     }
437 
438     feed_nodes.erase(curr_node->name());
439 
440     if (IsPersistentNode(curr_node)) {
441       auto& device_state = device_[curr_node_device];
442       for (int port_num = 0;
443            port_num < curr_node_state.output_properties.size(); ++port_num) {
444         device_state.persistent_nodes.insert(
445             std::make_pair(curr_node, port_num));
446       }
447     }
448   }
449 
450   if (ready_nodes_->Empty()) {
451     return errors::InvalidArgument("No ready nodes in the graph.");
452   }
453 
454   if (!feed_nodes.empty()) {
455     // This isn't always a bug: when the caller hasn't specified the exact list
456     // of feed and fetch nodes, by default we consider all placeholders as feed
457     // nodes, but some of them may not be needed for the default fetch node.
458     VLOG(1) << "Some feed nodes were not consumed by the fetch fanin: "
459             << str_util::Join(feed_nodes, ",");
460   }
461 
462   initialized_ = true;
463   return Status::OK();
464 }
465 
MaybeUpdateInputOutput(const NodeDef * node)466 void VirtualScheduler::MaybeUpdateInputOutput(const NodeDef* node) {
467   CHECK(!initialized_) << "MaybeUpdateInputOutput is called after Init().";
468   // This method is called when NodeState is created and adds input and output
469   // properties for a few exceptional cases that GraphProperties cannot provide
470   // input/output properties.
471   if ((IsSend(*node) || IsRecv(*node)) && node->attr().count(kAttrInputSrc)) {
472     // _Send and _Recv ops created from VirtualScheduler have kAttrInputSrc
473     // attr; normal _Send and _Recv ops (from the input graph) do not have that
474     // attr.
475     auto& node_state = node_map_[node];
476     auto& inputs = node_state.input_properties;
477     auto& outputs = node_state.output_properties;
478 
479     // _Send and _Recv ops are created from VirtualScheduler, so
480     // there should be no inputs TensorProperties.
481     CHECK(inputs.empty());
482     CHECK(outputs.empty());
483     const auto& attr = node->attr();
484     // This is the original input source to the _Send and _Recv, and this
485     // string includes "^" if it was control dependency, and output port
486     /// (e.g., ":2") if the input source had multiple outputs.
487     const auto& input_source_name = attr.at(kAttrInputSrc).s();
488     if (IsControlInput(input_source_name)) {
489       // Control dependency; regardless of the input source tensor size,
490       // send 4B.
491       OpInfo::TensorProperties control_message;
492       control_message.set_dtype(DT_FLOAT);
493       control_message.mutable_shape()->add_dim()->set_size(1);
494       auto* value = control_message.mutable_value();
495       value->add_float_val(1);
496       inputs.push_back(control_message);
497       outputs.push_back(control_message);
498     } else {
499       const auto& output_properties =
500           graph_properties_->GetOutputProperties(NodeName(input_source_name));
501       // Like with HasInputProperties, if a node does not have output
502       // properties, it's likely it was pruned during the shape inference run.
503       if (!output_properties.empty()) {
504         const auto input_node_port_num = NodePosition(input_source_name);
505         // Use the input source's output property as _Send and _Recv's input
506         // property.
507         CHECK_GT(output_properties.size(), input_node_port_num);
508         inputs.push_back(output_properties[input_node_port_num]);
509         outputs.push_back(output_properties[input_node_port_num]);
510       }
511     }
512   }
513 }
514 
Round2(const float x) const515 float VirtualScheduler::Round2(const float x) const {
516   // Not using std::round from <cmath> here because not all platforms seem to
517   // support that (specifically Android).
518   return ::round(100.0 * x) / 100.0;
519 }
520 
IsPersistentNode(const NodeDef * node) const521 bool VirtualScheduler::IsPersistentNode(const NodeDef* node) const {
522   // Variables are persistent nodes.
523   return IsVariable(*node);
524 }
525 
DeviceName(const NodeDef * node) const526 string VirtualScheduler::DeviceName(const NodeDef* node) const {
527   return placer_.get_canonical_device_name(*node);
528 }
529 
SanitizedDeviceName(const NodeDef * node) const530 string VirtualScheduler::SanitizedDeviceName(const NodeDef* node) const {
531   // Replace the ":" characters that may be present in the device name with "_".
532   // This makes it possible to then use the resulting string in a node name.
533   return str_util::StringReplace(placer_.get_canonical_device_name(*node), ":",
534                                  "_", true);
535 }
536 
ChannelDeviceName(const NodeDef * from,const NodeDef * to) const537 string VirtualScheduler::ChannelDeviceName(const NodeDef* from,
538                                            const NodeDef* to) const {
539   CHECK(!initialized_) << "ChannelDeviceName is called after Init().";
540   return kChannelDevice + "_from_" + SanitizedDeviceName(from) + "_to_" +
541          SanitizedDeviceName(to);
542 }
543 
CreateSendRecv(const NodeDef * from,const NodeDef * to,const NodeDef * input_node,const string & input_name)544 std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
545     const NodeDef* from, const NodeDef* to, const NodeDef* input_node,
546     const string& input_name) {
547   CHECK(!initialized_) << "CreateSendRecv is called after Init().";
548 
549   // Connect "from" node to "to" node with _Send and _Recv such that
550   // from -> _Send -> _Recv -> to.
551   // _Send is placed on "Channel" device, and _Recv is on the same device
552   // as "to" node.
553   // input_node_name is the string from the "to" node to identify which output
554   // we get from the "from" node.
555 
556   // Note that we use NodeState for scheduling, so _Send and _Recv
557   // NodeDefs created here need not be correct: in terms of name,
558   // input names, attrs, etc.
559 
560   auto input_node_port_num = NodePosition(input_name);
561   string src_name;
562   if (input_node_port_num >= 0) {
563     src_name = strings::StrCat(from->name(), "_", input_node_port_num);
564   } else {
565     src_name = strings::StrCat(from->name(), "_minus1");
566   }
567 
568   // _Send op.
569   auto* send = new NodeDef();
570   send->set_name("Send_" + src_name + "_from_" + SanitizedDeviceName(from) +
571                  "_to_" + SanitizedDeviceName(to));
572   send->set_op("_Send");
573   send->add_input(from->name());
574   send->set_device(ChannelDeviceName(from, to));
575   auto& send_attr = *(send->mutable_attr());
576   send_attr[kAttrInputSrc].set_s(input_name);
577   send_attr[kAttrSrcDevice].set_s(DeviceName(from));
578   send_attr[kAttrDstDevice].set_s(DeviceName(to));
579   // GraphDef generated by AutoGrappler has tensor_name field when removing
580   // _Send/_Recv nodes.
581   if (input_node->attr().count(kAttrTensorName)) {
582     send_attr[kAttrTensorName].set_s(
583         input_node->attr().at(kAttrTensorName).s());
584   }
585 
586   // _Recv op.
587   auto* recv = new NodeDef();
588   recv->set_name("Recv_" + src_name + "_on_" + SanitizedDeviceName(to));
589   recv->set_op("_Recv");
590   recv->add_input(send->name());
591   recv->set_device(DeviceName(to));
592   auto& recv_attr = *(recv->mutable_attr());
593   recv_attr[kAttrInputSrc].set_s(input_name);
594   if (input_node->attr().count(kAttrTensorName)) {
595     recv_attr[kAttrTensorName].set_s(
596         input_node->attr().at(kAttrTensorName).s());
597   }
598 
599   // NodeState for _Send op.
600   auto& send_node_state = GetNodeStateOrCreateIt(send);
601   send_node_state.device_name = send->device();  // Set Channel device.
602   send_node_state.inputs.push_back(std::make_pair(from, input_node_port_num));
603   send_node_state.outputs[0].push_back(recv);
604 
605   // NodeState for _Recv op.
606   auto& recv_node_state = GetNodeStateOrCreateIt(recv);
607   recv_node_state.inputs.push_back(std::make_pair(send, 0));
608   recv_node_state.outputs[0].push_back(to);
609 
610   // Keep the created nodes.
611   additional_nodes_.emplace_back(std::unique_ptr<NodeDef>(send));
612   additional_nodes_.emplace_back(std::unique_ptr<NodeDef>(recv));
613 
614   // Return _Send and _Recv.
615   return std::make_pair(send, recv);
616 }
617 
GetCurrNode() const618 OpContext VirtualScheduler::GetCurrNode() const {
619   const NodeDef* node = ready_nodes_->GetCurrNode();
620 
621   // Get the device from the placer.
622   DeviceProperties device;
623   device = placer_.get_device(*node);
624 
625   // Special case for _Send op.
626   if (IsSend(*node)) {
627     device.set_type(kChannelDevice);
628   }
629 
630   // Construct OpContext.
631   OpContext op_context;
632   const auto& node_state = node_map_.at(node);
633   op_context.name = node->name();
634   op_context.device_name = node_state.device_name;
635   auto& op_info = op_context.op_info;
636   op_info.set_op(node->op());
637   *op_info.mutable_attr() = node->attr();
638   for (auto& input : node_state.input_properties) {
639     *op_info.add_inputs() = input;
640   }
641   for (auto& output : node_state.output_properties) {
642     *op_info.add_outputs() = output;
643   }
644   op_info.mutable_device()->Swap(&device);
645 
646   if (grappler_item_->graph.has_library()) {
647     op_context.function_library = &grappler_item_->graph.library();
648   }
649   return op_context;
650 }
651 
GetNodeStateOrCreateIt(const NodeDef * node)652 NodeState& VirtualScheduler::GetNodeStateOrCreateIt(const NodeDef* node) {
653   CHECK(!initialized_) << "GetNodeStateOrCreateIt is called after Init().";
654 
655   auto it = node_map_.find(node);
656   if (it != node_map_.end()) {
657     return it->second;
658   }
659 
660   // Not found; create a NodeState for this node.
661   it = node_map_.emplace(node, NodeState()).first;
662   auto& node_state = it->second;
663   node_state.input_properties =
664       graph_properties_->GetInputProperties(node->name());
665   node_state.output_properties =
666       graph_properties_->GetOutputProperties(node->name());
667 
668   // Some ops may need further processing to the input / output properties:
669   // _Send and _Recv.
670   MaybeUpdateInputOutput(node);
671 
672   if (!IsSend(*node)) {
673     node_state.device_name = DeviceName(node);
674     // For _Send op, device_name will be set to Channel in CreateSendRecv().
675   }
676 
677   // Initialize output port related data:
678   // Assume the size of OutputProperties represents the number of output ports
679   // of this node.
680   for (size_t i = 0; i < node_state.output_properties.size(); ++i) {
681     node_state.time_no_references[i] = Costs::Duration::max();
682     node_state.num_outputs_executed[i] = 0;
683     // Populate an empty vector for each port. The caller will add nodes
684     // that use this port as input.
685     node_state.outputs[i] = {};
686   }
687   // Port_num -1 is for control dependency.
688   node_state.time_no_references[-1] = Costs::Duration::max();
689   node_state.num_outputs_executed[-1] = 0;
690   node_state.outputs[-1] = {};
691 
692   return it->second;
693 }
694 
FindOrCreateZero(const string & op_name,std::map<string,Costs> * op_cost)695 Costs& VirtualScheduler::FindOrCreateZero(const string& op_name,
696                                           std::map<string, Costs>* op_cost) {
697   auto it = op_cost->find(op_name);
698   if (it == op_cost->end()) {
699     // Note that default constructor of Costs sets some memory related fields
700     // to unknown values so we should explicitly initialize it with ZeroCosts.
701     it = op_cost->emplace(op_name, Costs::ZeroCosts()).first;
702   }
703   return it->second;
704 }
705 
AddOutputNodesToReadyQueue(const NodeDef * node,const Costs::Duration & curr_time)706 void VirtualScheduler::AddOutputNodesToReadyQueue(
707     const NodeDef* node, const Costs::Duration& curr_time) {
708   // Checks whether the Switch's output slots change over iterations.
709   int slot = -1;
710   if (IsSwitch(*node) && node->attr().count(kOutputSlots) > 0 &&
711       node->attr().at(kOutputSlots).list().i_size() > 0) {
712     slot = node->attr().at(kOutputSlots).list().i(0);
713     for (int i = 1; i < node->attr().at(kOutputSlots).list().i_size(); ++i) {
714       if (slot != node->attr().at(kOutputSlots).list().i(i)) {
715         slot = -1;
716         break;
717       }
718     }
719   }
720 
721   // Increment num_inputs_ready of the output nodes and maybe add to ready
722   // nodes.
723   auto& node_state = node_map_[node];
724   for (const auto& port_num_output_pair : node_state.outputs) {
725     // If Switch is annotated and its output slots are always the same, we only
726     // schedule the slot that was executed. Otherwise, scheduler both slots.
727     if (slot >= 0 && port_num_output_pair.first != slot) continue;
728 
729     for (auto* output_node : port_num_output_pair.second) {
730       auto& output_state = node_map_[output_node];
731       output_state.num_inputs_ready++;
732       // Execute a node as soon as all its inputs are ready. Merge nodes are
733       // special since they run as soon as one of their inputs becomes
734       // available.
735       if (output_state.num_inputs_ready == output_state.inputs.size() ||
736           IsMerge(*output_node)) {
737         // This output node is now ready.
738         output_state.time_ready = curr_time;
739         ready_nodes_->AddNode(output_node);
740         VLOG(3) << "  Add output: " << output_node->name();
741       }
742     }
743   }
744 }
745 
MarkCurrNodeExecuted(const Costs & node_costs)746 bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
747   // Update graph_costs_ and per-op costs.
748   const NodeDef* node = ready_nodes_->GetCurrNode();
749   auto& node_state = node_map_[node];
750   // If there is annotation in the graph about execution times, we use that
751   // number, otherwise, we assume the node is executed once.
752   node_state.execution_count = node->attr().count(kExecutionCount) == 0
753                                    ? 1
754                                    : node->attr().at(kExecutionCount).i();
755   Costs total_node_costs =
756       MultiplyCosts(node_costs, node_state.execution_count);
757   graph_costs_ = CombineCosts(graph_costs_, total_node_costs);
758   const string& op_name = node->op();
759 
760   auto& op_cost = FindOrCreateZero(op_name, &op_to_cost_);
761   op_cost = CombineCosts(op_cost, total_node_costs);
762 
763   if (VLOG_IS_ON(2)) {
764     // Also keep track of op counts and costs per op (with their shapes).
765     OpContext op_context = GetCurrNode();
766 
767     string node_description = GetOpDescription(op_context.op_info);
768     op_counts_[node_description] += 1;
769     op_costs_[node_description] =
770         std::make_pair(node_costs.execution_time.asMicroSeconds().count(),
771                        !node_costs.inaccurate);
772   }
773 
774   // Update node and device states.
775   auto& device = device_[node_state.device_name];
776   device.nodes_executed.push_back(node);
777   // Node is scheduled when the device is available AND all the inputs are
778   // ready; hence, time_scheduled is time_ready if time_ready > device curr
779   // time.
780   node_state.time_scheduled =
781       std::max(device.GetCurrTime(), node_state.time_ready);
782   // Override device curr time with the time_scheduled.
783   device.device_costs.execution_time = node_state.time_scheduled;
784   device.device_costs = CombineCosts(device.device_costs, total_node_costs);
785   auto curr_time = device.GetCurrTime();
786   node_state.time_finished = curr_time;
787 
788   // Update device memory usage.
789   if (!IsPersistentNode(node)) {
790     for (const auto& port_num_output_pair : node_state.outputs) {
791       int port_num = port_num_output_pair.first;
792       // There's a chance that a specific output is not used at all.
793       if (node_state.outputs[port_num].empty()) {
794         node_state.time_no_references[port_num] = curr_time;
795       } else {
796         device.memory_usage +=
797             CalculateOutputSize(node_state.output_properties, port_num) *
798             node_state.execution_count;
799         device.nodes_in_memory.insert(std::make_pair(node, port_num));
800       }
801     }
802   }
803 
804   // Update device's per-op cost.
805   auto& device_op_cost = FindOrCreateZero(op_name, &device.op_to_cost);
806   device_op_cost = CombineCosts(device_op_cost, total_node_costs);
807 
808   VLOG(3) << "Op scheduled -- name: " << node->name() << ", op: " << node->op()
809           << ", device: " << node->device()
810           << ", execution_count: " << node_state.execution_count
811           << ", ready: " << node_state.time_ready.count()
812           << ", scheduled: " << node_state.time_scheduled.count()
813           << ", finished: " << node_state.time_finished.count();
814 
815   // Checks outputs, and adds ready nodes to queue.
816   AddOutputNodesToReadyQueue(node, curr_time);
817 
818   // Increment num_outputs_executed of the input nodes and maybe update memory.
819   for (const auto& input_port : node_state.inputs) {
820     auto* input = input_port.first;
821     auto port = input_port.second;
822     auto& input_state = node_map_[input];
823     input_state.num_outputs_executed[port]++;
824     if (input_state.num_outputs_executed[port] ==
825             input_state.outputs[port].size() &&
826         !IsPersistentNode(input)) {
827       // All the outputs are executed; no reference to this output port of
828       // input node.
829       input_state.time_no_references[port] = curr_time;
830       auto& input_device = device_[input_state.device_name];
831       input_device.memory_usage -=
832           CalculateOutputSize(input_state.output_properties, port) *
833           node_state.execution_count;
834 
835       input_device.nodes_in_memory.erase(std::make_pair(input, port));
836     }
837   }
838 
839   if (!IsPersistentNode(node)) {
840     // Now that output memory is added and used up nodes are deallocated,
841     // check max memory usage.
842     if (device.memory_usage > device.max_memory_usage) {
843       device.max_memory_usage = device.memory_usage;
844 
845       if (track_mem_usage_snapshot_) {
846         device.mem_usage_snapshot_at_peak = device.nodes_in_memory;
847       }
848     }
849   }
850 
851   ready_nodes_->RemoveCurrNode();
852 
853   return !ready_nodes_->Empty();
854 }
855 
Summary() const856 Costs VirtualScheduler::Summary() const {
857   // Overall statement about accuracy
858   VLOG(1) << graph_costs_.num_ops_total << " ops processed in total, with "
859           << graph_costs_.num_ops_with_unknown_shapes
860           << " having unknown shapes";
861 
862   // Print out basic execution summary.
863   VLOG(1) << "Expected execution time: " << graph_costs_.execution_time.count();
864   VLOG(1) << "Expected compute time: " << graph_costs_.compute_time.count();
865   VLOG(1) << "Expected memory time: " << graph_costs_.memory_time.count();
866   VLOG(1) << "Expected intermediate memory time: "
867           << graph_costs_.intermediate_memory_time.count();
868   VLOG(1) << "Expected max memory: " << graph_costs_.max_memory;
869   VLOG(1) << "Expected max per-op buffers: " << graph_costs_.max_per_op_buffers;
870   VLOG(1) << "Expected max per-op streaming buffers: "
871           << graph_costs_.max_per_op_streaming;
872 
873   VLOG(1) << "Per-op execution time / compute time / memory time"
874           << " / intermediate memory time:";
875   for (const auto& op_cost_pair : op_to_cost_) {
876     const auto& op = op_cost_pair.first;
877     const auto& cost = op_cost_pair.second.execution_time.count();
878     const auto& compute_cost = op_cost_pair.second.compute_time.count();
879     const auto& memory_cost = op_cost_pair.second.memory_time.count();
880     const auto& intermediate_memory_cost =
881         op_cost_pair.second.intermediate_memory_time.count();
882     const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate;
883     if (cost) {  // Skip printing out zero-cost ops.
884       VLOG(1) << strings::Printf(
885           " + %30s : %c %10lld / %10lld / %10lld / %10lld", op.c_str(),
886           (is_op_cost_accurate ? ' ' : '~'), static_cast<int64>(cost),
887           static_cast<int64>(compute_cost), static_cast<int64>(memory_cost),
888           static_cast<int64>(intermediate_memory_cost));
889     }
890   }
891 
892   // Print per device summary
893   VLOG(1) << "Devices:";
894   Costs critical_path_costs = Costs::ZeroCosts();
895   std::vector<string> device_names;
896   device_names.reserve(device_.size());
897   for (auto& it : device_) {
898     device_names.push_back(it.first);
899   }
900   std::sort(device_names.begin(), device_names.end());
901 
902   for (const auto& name : device_names) {
903     const auto& state = device_.at(name);
904 
905     std::map<string, int64> op_to_memory;
906     // First profile only persistent memory usage.
907     int64 persistent_memory_usage = 0;
908     std::set<string> persisent_ops;
909     for (const auto& node_port : state.persistent_nodes) {
910       const auto* node = node_port.first;
911       const auto port = node_port.second;
912       const auto output_size =
913           CalculateOutputSize(node_map_.at(node).output_properties, port);
914       persistent_memory_usage += output_size;
915       op_to_memory[node->op()] += output_size;
916       persisent_ops.insert(node->op());
917     }
918     int64 max_memory_usage = persistent_memory_usage + state.max_memory_usage;
919     critical_path_costs.estimated_max_memory_per_device[name] =
920         max_memory_usage;
921 
922     const Costs::NanoSeconds wall_time_ns = state.GetCurrTime();
923     VLOG(1) << "Device = " << name
924             << ", num_nodes = " << state.nodes_executed.size()
925             << ", wall_time_ns = " << wall_time_ns.count() << ", memory usage: "
926             << "persistent = "
927             << strings::HumanReadableNumBytes(persistent_memory_usage)
928             << ", peak = "
929             << strings::HumanReadableNumBytes(state.max_memory_usage)
930             << ", total = " << strings::HumanReadableNumBytes(max_memory_usage)
931             << ", at the end: "
932             << strings::HumanReadableNumBytes(state.memory_usage);
933 
934     // Overall statement about accuracy
935     VLOG(1) << state.device_costs.num_ops_total
936             << " ops processed in total, with "
937             << state.device_costs.num_ops_with_unknown_shapes
938             << " having unknown shapes";
939 
940     VLOG(1) << "Per-op execution time / compute time / memory time "
941             << " / intermediate memory time"
942             << " (and memory usage at peak memory usage):";
943 
944     // Profile non-persistent op memory usage.
945     for (const auto& node_port : state.mem_usage_snapshot_at_peak) {
946       const auto* node = node_port.first;
947       const auto port = node_port.second;
948       op_to_memory[node->op()] +=
949           CalculateOutputSize(node_map_.at(node).output_properties, port);
950     }
951     Costs::NanoSeconds total_compute_time_ns;
952     bool is_total_cost_accurate = true;
953     for (const auto& op_cost_pair : state.op_to_cost) {
954       const auto& op = op_cost_pair.first;
955       const auto& cost = op_cost_pair.second.execution_time.count();
956       const auto& compute_cost = op_cost_pair.second.compute_time.count();
957       const auto& memory_cost = op_cost_pair.second.memory_time.count();
958       const auto& intermediate_memory_cost =
959           op_cost_pair.second.intermediate_memory_time.count();
960       total_compute_time_ns += op_cost_pair.second.execution_time;
961       const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate;
962       if (!is_op_cost_accurate) {
963         is_total_cost_accurate = false;
964       }
965 
966       int64 op_mem_usage = 0;
967       auto it = op_to_memory.find(op);
968       if (it != op_to_memory.end()) {
969         op_mem_usage = it->second;
970       }
971 
972       const float mem_usage_percent =
973           max_memory_usage > 0 ? Round2(100.0 * op_mem_usage / max_memory_usage)
974                                : 0.0;
975       if (cost || mem_usage_percent > 1.0) {
976         // Print out only non-zero cost ops or ops with > 1% memory usage.
977         VLOG(1) << strings::Printf(
978                        " + %30s : %c %10lld / %10lld / %10lld / %10lld",
979                        op.c_str(), (is_op_cost_accurate ? ' ' : '~'),
980                        static_cast<int64>(cost),
981                        static_cast<int64>(compute_cost),
982                        static_cast<int64>(memory_cost),
983                        static_cast<int64>(intermediate_memory_cost))
984                 << " (" << strings::HumanReadableNumBytes(op_mem_usage) << " ["
985                 << mem_usage_percent << "%] "
986                 << (persisent_ops.count(op) > 0 ? ": persistent op)" : ")");
987       }
988     }
989 
990     int utilization = 0;
991     if (wall_time_ns.count() > 0) {
992       utilization = total_compute_time_ns.count() * 100 / wall_time_ns.count();
993     }
994     VLOG(1) << "Device = " << name << ", total_compute_time_ns = "
995             << (is_total_cost_accurate ? "" : "~")
996             << total_compute_time_ns.count()
997             << ", utilization = " << utilization << "%";
998 
999     if (critical_path_costs.execution_time <= state.GetCurrTime()) {
1000       critical_path_costs = state.device_costs;
1001     }
1002   }
1003 
1004   if (VLOG_IS_ON(2)) {
1005     // Also log the op description and their corresponding counts.
1006     VLOG(2) << "Node description, counts, cost:";
1007     for (const auto& item : op_counts_) {
1008       int cost;
1009       bool is_cost_accurate;
1010       std::tie(cost, is_cost_accurate) = op_costs_.at(item.first);
1011       VLOG(2) << "Node: " << item.first << ", Count: " << item.second
1012               << ", Individual Cost: " << (is_cost_accurate ? "" : "~") << cost
1013               << " us";
1014     }
1015   }
1016 
1017   VLOG(1) << "Critical path execution time: "
1018           << critical_path_costs.execution_time.count();
1019   return critical_path_costs;
1020 }
1021 
Summary(RunMetadata * metadata)1022 Costs VirtualScheduler::Summary(RunMetadata* metadata) {
1023   if (metadata) GenerateRunMetadata(metadata);
1024   return Summary();
1025 }
1026 
GenerateRunMetadata(RunMetadata * metadata)1027 void VirtualScheduler::GenerateRunMetadata(RunMetadata* metadata) {
1028   // Fill RunMetadata's step_stats and partition_graphs fields.
1029   StepStats* stepstats = metadata->mutable_step_stats();
1030   for (const auto& device : device_) {
1031     GraphDef* device_partition_graph = metadata->add_partition_graphs();
1032     DeviceStepStats* device_stepstats = stepstats->add_dev_stats();
1033     device_stepstats->set_device(device.first);
1034     for (const auto& node_def : device.second.nodes_executed) {
1035       const NodeState& nodestate = node_map_.at(node_def);
1036       NodeExecStats* node_stats = device_stepstats->add_node_stats();
1037       uint64 total_output_size = 0;
1038       for (int slot = 0; slot < nodestate.output_properties.size(); slot++) {
1039         const auto& properties = nodestate.output_properties[slot];
1040         NodeOutput* no = node_stats->add_output();
1041         no->set_slot(slot);
1042         TensorDescription* tensor_descr = no->mutable_tensor_description();
1043         tensor_descr->set_dtype(properties.dtype());
1044         *tensor_descr->mutable_shape() = properties.shape();
1045         // Optional allocation description.
1046         const auto tensor_size =
1047             CalculateOutputSize(nodestate.output_properties, slot);
1048         total_output_size += tensor_size;
1049         tensor_descr->mutable_allocation_description()->set_requested_bytes(
1050             tensor_size);
1051         tensor_descr->mutable_allocation_description()->set_allocated_bytes(
1052             tensor_size);
1053       }
1054       node_stats->set_timeline_label(node_def->op());
1055       node_stats->set_node_name(node_def->name());
1056       node_stats->set_op_start_rel_micros(0);
1057       node_stats->set_all_start_micros(
1058           nodestate.time_scheduled.asMicroSeconds().count());
1059       node_stats->set_op_end_rel_micros(
1060           nodestate.time_finished.asMicroSeconds().count() -
1061           nodestate.time_scheduled.asMicroSeconds().count());
1062       node_stats->set_all_end_rel_micros(
1063           nodestate.time_finished.asMicroSeconds().count() -
1064           nodestate.time_scheduled.asMicroSeconds().count());
1065       auto* mem_stats = node_stats->mutable_memory_stats();
1066       // VirtualScheduler does not specify scratch pad memory usage.
1067       mem_stats->set_temp_memory_size(0);
1068       int64 persistent_memory_size = 0;
1069       if (IsPersistentNode(node_def)) {
1070         persistent_memory_size = total_output_size;
1071       }
1072       mem_stats->set_persistent_memory_size(persistent_memory_size);
1073       *device_partition_graph->add_node() = *node_def;
1074     }
1075   }
1076 }
1077 
GetPeakMemoryUsage() const1078 const std::unordered_map<string, int64> VirtualScheduler::GetPeakMemoryUsage()
1079     const {
1080   std::unordered_map<string, int64> result;
1081   for (const auto& device : device_) {
1082     const string& name = device.first;
1083     const DeviceState& state = device.second;
1084     result[name] = state.max_memory_usage;
1085   }
1086   return result;
1087 }
1088 
1089 }  // end namespace grappler
1090 }  // end namespace tensorflow
1091