1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/costs/virtual_scheduler.h"
17
18 #include "absl/strings/str_format.h"
19 #include "absl/strings/str_replace.h"
20 #include "tensorflow/core/framework/allocation_description.pb.h"
21 #include "tensorflow/core/framework/attr_value.pb.h"
22 #include "tensorflow/core/framework/node_def.pb.h"
23 #include "tensorflow/core/framework/tensor.pb.h"
24 #include "tensorflow/core/framework/tensor_description.pb.h"
25 #include "tensorflow/core/framework/tensor_shape.pb.h"
26 #include "tensorflow/core/grappler/clusters/utils.h"
27 #include "tensorflow/core/grappler/costs/utils.h"
28 #include "tensorflow/core/grappler/op_types.h"
29 #include "tensorflow/core/grappler/utils.h"
30 #include "tensorflow/core/grappler/utils/transitive_fanin.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/lib/strings/numbers.h"
33 #include "tensorflow/core/platform/logging.h"
34 #include "tensorflow/core/util/device_name_utils.h"
35
36 namespace tensorflow {
37 namespace grappler {
38
39 const char kAttrInputSrc[] = "input_source_";
40 const char kAttrSrcDevice[] = "send_device";
41 const char kAttrDstDevice[] = "recv_device";
42 const char kAttrTensorName[] = "tensor_name";
43 const char kChannelDevice[] = "Channel";
44 const char kStreaming[] = "_streaming";
45
46 namespace {
47
48 using ::tensorflow::strings::HumanReadableNumBytes;
49
50 float Round2(const float x) {
51 // Not using std::round from <cmath> here because not all platforms seem to
52 // support that (specifically Android).
53 return ::round(100.0 * x) / 100.0;
54 }
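// Illustrative behavior of Round2(), derived from the definition above:
//   Round2(3.14159f) == 3.14f and Round2(66.666f) == 66.67f (two decimals).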
55
56 Costs& FindOrCreateZero(const string& op_name,
57 std::map<string, Costs>* op_cost) {
58 auto it = op_cost->find(op_name);
59 if (it == op_cost->end()) {
60 // Note that the default constructor of Costs sets some memory-related fields
61 // to unknown values, so we explicitly initialize it with ZeroCosts here.
62 it = op_cost->emplace(op_name, Costs::ZeroCosts()).first;
63 }
64 return it->second;
65 }
66
67 // Key to the cached _Recv ops map, and its hash and predicate structures.
68 struct RecvNodeDescriptor {
69 const NodeDef* node;
70 const int port_num;
71 const string device;
72
73 RecvNodeDescriptor(const NodeDef* node_, const int port_num_,
74 const string& device_)
75 : node(node_), port_num(port_num_), device(device_) {}
76 };
77
78 struct RecvNodeDescriptorHash {
79 std::size_t operator()(const RecvNodeDescriptor& recv_node) const {
80 return std::hash<const NodeDef*>()(recv_node.node) ^
81 std::hash<int>()(recv_node.port_num) ^
82 std::hash<string>()(recv_node.device);
83 }
84 };
85
86 struct RecvNodeDescriptorEqual {
87 bool operator()(const RecvNodeDescriptor& a,
88 const RecvNodeDescriptor& b) const {
89 return a.node == b.node && a.port_num == b.port_num && a.device == b.device;
90 }
91 };
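// The three types above are used together as the key, hash, and equality
// functors of an unordered_map (see cached_recv_nodes in SchedulerState::Init()
// below), e.g.:
//   std::unordered_map<RecvNodeDescriptor, const NodeDef*,
//                      RecvNodeDescriptorHash, RecvNodeDescriptorEqual> cache;
// so that at most one _Recv is created per (node, port, device) triple.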
92
93 void UpdateDeviceAnnotationState(const NodeDef* node,
94 const NodeState& node_state,
95 DeviceState* device) {
96 if (node->attr().count(kOutputShapes) == 0) return;
97
98 int64 execution_count = node->attr().count(kExecutionCount) == 0
99 ? 1
100 : node->attr().at(kExecutionCount).i();
101
102 auto& shape_annotation_stats = device->shape_annotation_stats;
103 shape_annotation_stats.num_ops_annotated += 1;
104 shape_annotation_stats.num_ops_executed += execution_count;
105 shape_annotation_stats.num_ops_executed_more_than_once +=
106 execution_count > 1 ? 1 : 0;
107 shape_annotation_stats.num_ops_with_incompatible_shapes +=
108 node_state.shape_incompatible ? 1 : 0;
109 shape_annotation_stats.num_ops_with_dynamic_shapes +=
110 (execution_count > 1 && node->attr().count(kOutputSame) == 0) ? 1 : 0;
111 }
112
113 bool IsStreamingPort(const NodeDef& node, const int port) {
114 if (!node.attr().contains(kStreaming)) return false;
115
116 auto& attr_list = node.attr().at(kStreaming).list();
117 bool is_streaming_port = false;
118 if (port >= 0 && port < attr_list.b().size()) {
119 is_streaming_port = attr_list.b(port);
120 }
121 return is_streaming_port;
122 }
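// Illustration (attr layout assumed from the lookup above): a node whose
// "_streaming" attr is list { b: [false, true] } yields
// IsStreamingPort(node, 1) == true, and false for port 0 or any
// out-of-range or negative port.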
123
124 } // namespace
125
126 void LIFOManager::AddNode(const NodeDef* node) {
127 // Merge nodes are scheduled with the lowest priority in LIFO manager; virtual
128 // scheduler may run multiple input nodes of Merge (when we don't have
129 // annotation, which is quite common); simply scheduling Merge after one of
130 its inputs may break scheduling constraints; some inputs of Merge may be
131 // scheduled after the Merge. So, we place Merge at the beginning of the queue
132 // to guarantee all the inputs of Merge are scheduled before the Merge.
133 if (IsMerge(*node)) {
134 nodes_.push_front(node);
135 } else {
136 nodes_.push_back(node);
137 }
138 }
139
140 const NodeDef* LIFOManager::GetCurrNode() {
141 CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node";
142 if (curr_pos_ == nodes_.end()) {
143 curr_pos_ = --(nodes_.rbegin().base()); // Last one in the list.
144 }
145 // Once curr_pos_ is set to a valid entry in the list, we keep using the
146 // cached curr_pos_ until RemoveCurrNode() is called. AddNode() will not
147 // change the GetCurrNode() return value.
148 return *curr_pos_;
149 }
150
151 void LIFOManager::RemoveCurrNode() {
152 // Make sure we have curr_pos_ ready to be removed.
153 GetCurrNode();
154 // Note curr_pos_ may not be pointing to the last element if some nodes are
155 // added.
156 nodes_.erase(curr_pos_);
157
158 curr_pos_ = nodes_.end(); // Reset curr_pos_.
159 }
160
161 HeapReadyManager::HeapReadyManager() : ReadyNodeManager() {
162 std::make_heap(nodes_.begin(), nodes_.end());
163 }
164
165 Status HeapReadyManager::Init(
166 const std::unordered_map<const NodeDef*, NodeState>* node_map) {
167 // Resets the node state since different instances of the scheduler can reuse
168 // the same node_manager.
169 node_map_ = node_map;
170 nodes_.clear();
171 curr_node_ = nullptr;
172
173 // Sets up the comparator for the heap.
174 greater_ = Greater();
175
176 return Status::OK();
177 }
178
179 void HeapReadyManager::AddNode(const NodeDef* node) {
180 // push_heap in AddNode() and pop_heap in RemoveCurrNode() guarantee that the
181 // first element is the node with minimum time_ready.
182 nodes_.push_back(node);
183 std::push_heap(nodes_.begin(), nodes_.end(), greater_);
184 }
185
186 const NodeDef* HeapReadyManager::GetCurrNode() {
187 if (curr_node_) return curr_node_;
188 if (nodes_.empty()) {
189 CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node";
190 }
191 const std::string node_name = nodes_.front()->name();
192 // Next time we call GetCurrNode(), it just returns the cached copy
193 // curr_node_, until we call RemoveCurrNode().
194 curr_node_ = nodes_.front();
195 // Remove the current node from the heap immediately, because the heap could
196 // be re-organized if AddNode() is called before then. The current node is
197 // cached anyway, in case GetCurrNode() is called again.
198 std::pop_heap(nodes_.begin(), nodes_.end(), greater_);
199 nodes_.pop_back();
200 return curr_node_;
201 }
202
203 void HeapReadyManager::RemoveCurrNode() {
204 if (curr_node_) {
205 // If cached copy exists, remove that.
206 // Reset curr_node_ so that GetCurrNode() finds another node.
207 curr_node_ = nullptr;
208 } else {
209 // If cached copy not present, then remove entry from the heap queue.
210 std::pop_heap(nodes_.begin(), nodes_.end(), greater_);
211 nodes_.pop_back();
212 }
213 }
214
215 bool HeapReadyManager::Empty() const {
216 return nodes_.empty() && curr_node_ == nullptr;
217 }
218
219 bool FirstReadyCmp(
220 const std::unordered_map<const NodeDef*, NodeState>* node_map,
221 const NodeDef* a, const NodeDef* b) {
222 if (node_map->at(a).time_ready == node_map->at(b).time_ready) {
223 // Use Node name as tie-breaker for deterministic node scheduling.
224 return a->name().compare(b->name()) > 0;
225 } else {
226 // Note: we need the node with minimum time_ready, not maximum; hence, we
227 // use a > b as the comparison function.
228 return node_map->at(a).time_ready > node_map->at(b).time_ready;
229 }
230 }
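// When FirstReadyCmp() is used as the "greater" comparator with std::push_heap
// and std::pop_heap (see HeapReadyManager above), the front of nodes_ is the
// node with the smallest time_ready, with the node name as a tie-breaker.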
231
232 std::function<bool(const NodeDef*, const NodeDef*)>
233 FirstReadyManager::Greater() {
234 auto greater = [this](const NodeDef* a, const NodeDef* b) -> bool {
235 return FirstReadyCmp(node_map_, a, b);
236 };
237 return greater;
238 }
239
240 std::function<bool(const NodeDef*, const NodeDef*)>
241 PriorityReadyManager::Greater() {
242 auto greater = [this](const NodeDef* a, const NodeDef* b) -> bool {
243 auto pri_a = node_priority_.at(a->name());
244 auto pri_b = node_priority_.at(b->name());
245 if (pri_a == pri_b) {
246 // Fallback to default (FirstReady) behaviour.
247 return FirstReadyCmp(node_map_, a, b);
248 }
249 return pri_a > pri_b;
250 };
251 return greater;
252 }
253
254 void PriorityReadyManager::AddNode(const NodeDef* node) {
255 if (node_priority_.count(node->name()) == 0) {
256 VLOG(3) << "Priority of node " << node->name() << " not found.";
257 node_priority_[node->name()] = 0;
258 }
259 HeapReadyManager::AddNode(node);
260 }
261
262 Status PriorityReadyManager::SetPriority(
263 const std::unordered_map<string, int>& node_priority) {
264 node_priority_ = node_priority;
265 return Status::OK();
266 }
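// Minimal usage sketch (hypothetical node_map / NodeDef variables; priorities
// are set before nodes are added so AddNode() finds them):
//   PriorityReadyManager manager;
//   TF_CHECK_OK(manager.Init(&node_map));
//   TF_CHECK_OK(manager.SetPriority({{"node_a", 1}, {"node_b", 2}}));
//   manager.AddNode(&node_a);
// With Greater() above, a node with a smaller priority value is scheduled
// earlier; equal priorities fall back to FirstReady ordering.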
267
268 CompositeNodeManager::CompositeNodeManager()
269 : ReadyNodeManager(), send_manager_(), recv_manager_() {}
270
271 Status CompositeNodeManager::Init(
272 const std::unordered_map<const NodeDef*, NodeState>* node_map) {
273 node_map_ = node_map;
274 TF_RETURN_IF_ERROR(send_manager_.Init(node_map));
275 TF_RETURN_IF_ERROR(recv_manager_.Init(node_map));
276 curr_node_ = nullptr;
277 return Status::OK();
278 }
279
280 void CompositeNodeManager::AddNode(const NodeDef* node) {
281 if (IsSend(*node)) {
282 send_manager_.AddNode(node);
283 } else if (IsRecv(*node)) {
284 recv_manager_.AddNode(node);
285 } else {
286 const auto& device = node_map_->at(node).device_name;
287 ops_lifo_map_[device].AddNode(node);
288 }
289 }
290
291 const NodeDef* CompositeNodeManager::GetCurrNode() {
292 if (curr_node_) return curr_node_;
293
294 // Per-device LIFO for normal ops (not _Send / _Recv),
295 // FirstReady for _Send and _Recv (separately),
296 // Globally (among the LIFO-selected ops from each device and _Send and
297 // _Recv) FirstReady,
298 // Priority order: _Send, _Recv, and then the rest, if time_ready is equal.
299 std::vector<std::pair<const NodeDef*, Costs::Duration>> candidates;
300 for (auto& ops_lifo : ops_lifo_map_) {
301 if (!ops_lifo.second.Empty()) {
302 const auto* op = ops_lifo.second.GetCurrNode();
303 candidates.emplace_back(op, node_map_->at(op).time_ready);
304 }
305 }
306 if (!send_manager_.Empty()) {
307 const auto* send = send_manager_.GetCurrNode();
308 candidates.emplace_back(send, node_map_->at(send).time_ready);
309 }
310 if (!recv_manager_.Empty()) {
311 const auto* recv = recv_manager_.GetCurrNode();
312 candidates.emplace_back(recv, node_map_->at(recv).time_ready);
313 }
314 CHECK(!candidates.empty());
315 auto first_ready = std::min_element(
316 candidates.begin(), candidates.end(),
317 [](const std::pair<const NodeDef*, Costs::Duration>& a,
318 const std::pair<const NodeDef*, Costs::Duration>& b) {
319 if (a.second == b.second) {
320 // Note that there can be at most one _Send and one _Recv in candidates;
321 // hence, the score is 2 for _Send, 1 for _Recv, and 0 for a
322 // normal op, and a_score and b_score are equal only if both are
323 // normal ops.
324 int a_score = 2 * IsSend(*a.first) + IsRecv(*a.first);
325 int b_score = 2 * IsSend(*b.first) + IsRecv(*b.first);
326 if (a_score == b_score) {
327 // Both are normal ops; use node name as tie breaker.
328 return a.first->name().compare(b.first->name()) < 0;
329 } else {
330 // Prioritize by op type: _Send, _Recv, and then normal ops.
331 return a_score > b_score;
332 }
333 } else {
334 return a.second < b.second;
335 }
336 });
337 // Next time we call GetCurrNode(), it just returns the cached one,
338 // curr_node_, until we call RemoveCurrNode().
339 curr_node_ = first_ready->first;
340
341 return curr_node_;
342 }
343
344 void CompositeNodeManager::RemoveCurrNode() {
345 const auto* node = GetCurrNode();
346 if (IsSend(*node)) {
347 send_manager_.RemoveCurrNode();
348 } else if (IsRecv(*node)) {
349 recv_manager_.RemoveCurrNode();
350 } else {
351 const auto device = node_map_->at(node).device_name;
352 ops_lifo_map_[device].RemoveCurrNode();
353 }
354 // Reset curr_node_ so that GetCurrNode() finds another node.
355 curr_node_ = nullptr;
356 }
357
358 bool CompositeNodeManager::Empty() const {
359 // Empty if all the ready managers are empty.
360 bool empty = true;
361 for (const auto& ops_lifo : ops_lifo_map_) {
362 empty &= ops_lifo.second.Empty();
363 }
364 return empty && send_manager_.Empty() && recv_manager_.Empty();
365 }
366
367 std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory(
368 const string& ready_node_manager) {
369 if (ready_node_manager == "FIFO") {
370 return absl::make_unique<FIFOManager>();
371 } else if (ready_node_manager == "LIFO") {
372 return absl::make_unique<LIFOManager>();
373 } else if (ready_node_manager == "FirstReady") {
374 return absl::make_unique<FirstReadyManager>();
375 } else if (ready_node_manager == "Composite") {
376 return absl::make_unique<CompositeNodeManager>();
377 }
378 LOG(FATAL) << "Not a valid ready node manager: " << ready_node_manager;
379 return nullptr;
380 }
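// Example: ReadyNodeManagerFactory("Composite") returns a CompositeNodeManager;
// any name other than "FIFO", "LIFO", "FirstReady", or "Composite" aborts via
// LOG(FATAL).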
381
382 SchedulerState::~SchedulerState() {}
383
384 SchedulerState::SchedulerState(const bool use_static_shapes,
385 const bool use_aggressive_shape_inference,
386 Cluster* cluster,
387 std::unique_ptr<VirtualPlacer> placer)
388 : graph_costs_(Costs::ZeroCosts()),
389 cluster_(cluster),
390 use_static_shapes_(use_static_shapes),
391 use_aggressive_shape_inference_(use_aggressive_shape_inference),
392 placer_(std::move(placer)) {
393 DCHECK(placer_); // check if the pointer is valid.
394 graph_costs_.num_ops_total = 0;
395 initialized_ = false;
396 track_mem_usage_snapshot_ = VLOG_IS_ON(1);
397 }
398
399 Status SchedulerState::Init(const GrapplerItem* item,
400 std::vector<const NodeDef*>* initial_nodes,
401 bool create_explicit_channel_device) {
402 initialized_ = false;
403
404 // Clear all internal states so that the SchedulerState is reusable for
405 // different GrapplerItems
406 node_map_.clear();
407 device_.clear();
408 additional_nodes_.clear();
409
410 graph_costs_ = Costs::ZeroCosts();
411 graph_costs_.num_ops_total = 0;
412 op_to_cost_.clear();
413
414 op_counts_.clear();
415 op_costs_.clear();
416
417 initial_nodes->clear();
418
419 // Constructs graph properties and performs shape inference.
420 graph_properties_ = absl::make_unique<GraphProperties>(*item);
421 // TODO(safeen,dyoon): Will we ever use InferDynamically? If not, we may want
422 // to get rid of use_static_shapes_ and cluster_.
423 if (use_static_shapes_) {
424 TF_RETURN_IF_ERROR(graph_properties_->InferStatically(
425 true, use_aggressive_shape_inference_, true));
426 } else {
427 TF_RETURN_IF_ERROR(graph_properties_->InferDynamically(cluster_));
428 }
429
430 grappler_item_ = item;
431 const auto& graph = grappler_item_->graph;
432 const auto& fetch_nodes = grappler_item_->fetch;
433 std::set<string> feed_nodes;
434
435 for (const auto& f : grappler_item_->feed) {
436 auto iter_and_inserted_flag = feed_nodes.insert(f.first);
437 QCHECK(iter_and_inserted_flag.second)
438 << "Duplicate feed node found: " << f.first;
439 }
440
441 // Get the nodes that would run to output fetch_nodes.
442 std::unordered_map<string, const NodeDef*> name_to_node;
443 std::vector<const NodeDef*> fetch_fanin_nodes;
444 TF_RETURN_IF_ERROR(ComputeTransitiveFanin(graph, fetch_nodes, &name_to_node,
445 &fetch_fanin_nodes));
446
447 // Once ComputeTransitiveFanin is complete, only the nodes that can be reached
448 // from the fetch nodes are scheduled. So the scheduled nodes should be
449 // exactly the same as those executed for real. One possible discrepancy is
450 // the control flow nodes, where TF executes only one path.
451
452 // Traverses the graph to record _Send nodes.
453 // TODO(dyoon): Instead of identifying _Send node here manually, add _Send
454 // to _Recv as control dependency when creating GrapplerItem.
455 std::unordered_map<string, const NodeDef*> name_to_send;
456 for (const auto& node : graph.node()) {
457 if (IsSend(node)) {
458 const auto& attr = node.attr();
459 name_to_send[attr.at("tensor_name").s()] = &node;
460 }
461 }
462
463 // To reuse _Recv ops.
464 std::unordered_map<RecvNodeDescriptor, const NodeDef*, RecvNodeDescriptorHash,
465 RecvNodeDescriptorEqual>
466 cached_recv_nodes;
467
468 // Build node_map; for each node, create its NodeState and connect its inputs
469 // and outputs.
470 for (const auto* curr_node : fetch_fanin_nodes) {
471 auto& curr_node_state = GetNodeStateOrCreateIt(curr_node);
472 const string curr_node_device = DeviceName(curr_node);
473 std::vector<string> inputs;
474 if (IsRecv(*curr_node)) {
475 const auto& attr = curr_node->attr();
476 if (attr.count("tensor_name")) {
477 const auto& send_node_name = attr.at("tensor_name").s();
478 auto it = name_to_send.find(send_node_name);
479 // If there is a _Send associated with the curr_node (_Recv), add it as
480 // input.
481 if (it != name_to_send.end()) {
482 const NodeDef* send = it->second;
483 inputs = {send->name()};
484 }
485 }
486 } else {
487 for (const string& input : curr_node->input()) {
488 inputs.push_back(input);
489 }
490 }
491 for (const string& input_node_name : inputs) {
492 // Note that input_node_name may be in <prefix><node_name>:<port_num>
493 // format, where <prefix> (e.g., "^" for control dependency) and
494 // ":<port_num>" may be omitted. NodeName() extracts only the node_name.
495 const NodeDef* input_node = name_to_node[NodeName(input_node_name)];
496
497 CHECK(input_node);
498 const string in_device = DeviceName(input_node);
499 const auto input_node_port_num = NodePosition(input_node_name);
500
501 // Control dependencies should be treated as high priority. The current
502 // Channel device doesn't model a separate virtual channel for control vs.
503 // data transfers. So in the interim, it may be okay to let control
504 // dependencies magically flow across devices, bypassing the channel
505 // device.
506 if (curr_node_device == in_device || IsControlInput(input_node_name)) {
507 // Same device: connect input_node and curr_node directly.
508 curr_node_state.inputs.push_back(
509 std::make_pair(input_node, input_node_port_num));
510 auto& input_node_state = GetNodeStateOrCreateIt(input_node);
511 input_node_state.outputs[input_node_port_num].push_back(curr_node);
512 } else {
513 RecvNodeDescriptor recv_node(input_node, input_node_port_num,
514 curr_node_device);
515 auto it = cached_recv_nodes.find(recv_node);
516 if (it != cached_recv_nodes.end()) {
517 // Different device, but found an already-cached copy (a _Recv op);
518 // connect the _Recv to curr_node.
519 const NodeDef* recv_op = it->second;
520 // recv_op's output port is hard-coded to zero.
521 curr_node_state.inputs.push_back(std::make_pair(recv_op, 0));
522 auto& input_node_state = node_map_.at(recv_op);
523 input_node_state.outputs[0].push_back(curr_node);
524 } else {
525 // Different device, no cached copy; transfer input_node to the
526 // curr_node's device.
527 auto send_and_recv =
528 CreateSendRecv(input_node, curr_node, input_node, input_node_name,
529 create_explicit_channel_device);
530 // Note that CreateSendRecv() already connected input/output between
531 // _Send and _Recv ops.
532 const auto* send = send_and_recv.first;
533 const auto* recv = send_and_recv.second;
534 // recv_op's output port is hard-coded to zero.
535 curr_node_state.inputs.push_back(std::make_pair(recv, 0));
536 auto& input_node_state = GetNodeStateOrCreateIt(input_node);
537 input_node_state.outputs[input_node_port_num].push_back(send);
538
539 // Cache the _Recv op for future use.
540 cached_recv_nodes[recv_node] = recv;
541 }
542 }
543 }
544
545 // Special case: given feed nodes are ready at time 0.
546 const bool given_as_feed =
547 feed_nodes.find(curr_node->name()) != feed_nodes.end();
548
549 // Default case: nodes without inputs are ready at time 0.
550 // Note that we check the inputs vector, which may differ from
551 // curr_node->input(); e.g., we add _Send as an input to _Recv.
552 const bool has_no_inputs = inputs.empty();
553
554 if (given_as_feed || has_no_inputs) {
555 curr_node_state.time_ready = Costs::Duration();
556 initial_nodes->push_back(curr_node);
557 VLOG(3) << "Added ready node: " << curr_node->name();
558 }
559 feed_nodes.erase(curr_node->name());
560
561 if (IsPersistent(*curr_node)) {
562 auto& device_state = device_[curr_node_device];
563 for (int port_num = 0,
564 port_num_end = curr_node_state.output_properties.size();
565 port_num < port_num_end; ++port_num) {
566 device_state.persistent_nodes.insert(
567 std::make_pair(curr_node, port_num));
568 }
569 }
570 }
571
572 if (initial_nodes->empty()) {
573 return errors::InvalidArgument("No ready nodes in the graph.");
574 }
575
576 if (!feed_nodes.empty()) {
577 // This isn't always a bug: when the caller hasn't specified the exact list
578 // of feed and fetch nodes, by default we consider all placeholders as feed
579 // nodes, but some of them may not be needed for the default fetch node.
580 VLOG(1) << "Some feed nodes were not consumed by the fetch fanin: "
581 << absl::StrJoin(feed_nodes, ",");
582 }
583
584 initialized_ = true;
585 return Status::OK();
586 }
587
588 void SchedulerState::MaybeUpdateInputOutput(const NodeDef* node) {
589 CHECK(!initialized_) << "MaybeUpdateInputOutput is called after Init().";
590 // This method is called when NodeState is created and adds input and output
591 // properties for a few exceptional cases that GraphProperties cannot provide
592 // input/output properties.
593 if ((IsSend(*node) || IsRecv(*node)) && node->attr().count(kAttrInputSrc)) {
594 // _Send and _Recv ops created from SchedulerState have kAttrInputSrc
595 // attr; normal _Send and _Recv ops (from the input graph) do not have that
596 // attr.
597 auto& node_state = node_map_[node];
598 auto& inputs = node_state.input_properties;
599 auto& outputs = node_state.output_properties;
600
601 // _Send and _Recv ops are created from SchedulerState, so
602 // there should be no input TensorProperties.
603 CHECK(inputs.empty());
604 CHECK(outputs.empty());
605 const auto& attr = node->attr();
606 // This is the original input source to the _Send and _Recv, and this
607 // string includes "^" if it was a control dependency, and the output port
608 // (e.g., ":2") if the input source had multiple outputs.
609 const auto& input_source_name = attr.at(kAttrInputSrc).s();
610 if (IsControlInput(input_source_name)) {
611 // Control dependency; regardless of the input source tensor size,
612 // send 4B.
613 OpInfo::TensorProperties control_message;
614 control_message.set_dtype(DT_FLOAT);
615 control_message.mutable_shape()->add_dim()->set_size(1);
616 auto* value = control_message.mutable_value();
617 value->add_float_val(1);
618 inputs.push_back(control_message);
619 outputs.push_back(control_message);
620 } else {
621 const auto& output_properties =
622 graph_properties_->GetOutputProperties(NodeName(input_source_name));
623 // Like with HasInputProperties, if a node does not have output
624 // properties, it's likely it was pruned during the shape inference run.
625 if (!output_properties.empty()) {
626 const auto input_node_port_num = NodePosition(input_source_name);
627 // Use the input source's output property as _Send and _Recv's input
628 // property.
629 CHECK_GT(output_properties.size(), input_node_port_num);
630 inputs.push_back(output_properties[input_node_port_num]);
631 outputs.push_back(output_properties[input_node_port_num]);
632 }
633 }
634 }
635 }
636
637 string SchedulerState::DeviceName(const NodeDef* node) const {
638 return placer_->get_canonical_device_name(*node);
639 }
640
641 string SchedulerState::SanitizedDeviceName(const NodeDef* node) const {
642 // Replace the ":" characters that may be present in the device name with "_".
643 // This makes it possible to then use the resulting string in a node name.
644 return absl::StrReplaceAll(placer_->get_canonical_device_name(*node),
645 {{":", "_"}});
646 }
647
648 string SchedulerState::ChannelDeviceName(const NodeDef* from,
649 const NodeDef* to) const {
650 CHECK(!initialized_) << "ChannelDeviceName is called after Init().";
651 return absl::StrCat(kChannelDevice, "_from_", SanitizedDeviceName(from),
652 "_to_", SanitizedDeviceName(to));
653 }
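// For example, sending between hypothetical canonical devices
// "/job:a/device:CPU:0" and "/job:a/device:GPU:0" yields the name
// "Channel_from_/job_a/device_CPU_0_to_/job_a/device_GPU_0".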
654
655 std::pair<const NodeDef*, const NodeDef*> SchedulerState::CreateSendRecv(
656 const NodeDef* from, const NodeDef* to, const NodeDef* input_node,
657 const string& input_name, bool create_channel_device) {
658 CHECK(!initialized_) << "CreateSendRecv is called after Init().";
659
660 // Connect "from" node to "to" node with _Send and _Recv such that
661 // from -> _Send -> _Recv -> to.
662 // _Send is placed on "Channel" device, and _Recv is on the same device
663 // as "to" node.
664 // input_name is the string from the "to" node that identifies which output
665 // we take from the "from" node.
666
667 // Note that we use NodeState for scheduling, so _Send and _Recv
668 // NodeDefs created here need not be correct in terms of name,
669 // input names, attrs, etc.
670
671 auto input_node_port_num = NodePosition(input_name);
672 string src_name;
673 bool control_input = false;
674 if (input_node_port_num >= 0) {
675 src_name = absl::StrCat(from->name(), "_", input_node_port_num);
676 } else {
677 src_name = absl::StrCat(from->name(), "_minus1");
678 control_input = true;
679 }
680
681 // _Send op.
682 auto* send = new NodeDef();
683 send->set_name("Send_" + src_name + "_from_" + SanitizedDeviceName(from) +
684 "_to_" + SanitizedDeviceName(to));
685 send->set_op("_Send");
686 send->add_input(from->name());
687 auto send_device =
688 create_channel_device ? ChannelDeviceName(from, to) : DeviceName(from);
689 send->set_device(send_device);
690 auto& send_attr = *(send->mutable_attr());
691 send_attr[kAttrInputSrc].set_s(input_name);
692 send_attr[kAttrSrcDevice].set_s(DeviceName(from));
693 send_attr[kAttrDstDevice].set_s(DeviceName(to));
694 // GraphDef generated by AutoGrappler has tensor_name field when removing
695 // _Send/_Recv nodes.
696 if (input_node->attr().count(kAttrTensorName)) {
697 send_attr[kAttrTensorName].set_s(
698 input_node->attr().at(kAttrTensorName).s());
699 }
700
701 // _Recv op.
702 auto* recv = new NodeDef();
703 recv->set_name("Recv_" + src_name + "_on_" + SanitizedDeviceName(to));
704 recv->set_op("_Recv");
705 recv->add_input(send->name());
706 recv->set_device(DeviceName(to));
707 auto& recv_attr = *(recv->mutable_attr());
708 recv_attr[kAttrInputSrc].set_s(input_name);
709 if (input_node->attr().count(kAttrTensorName)) {
710 recv_attr[kAttrTensorName].set_s(
711 input_node->attr().at(kAttrTensorName).s());
712 }
713
714 // Propagate the streaming attribute to the send/recv nodes.
715 if (from->attr().contains(kStreaming) && !control_input) {
716 if (input_node_port_num >= from->attr().at(kStreaming).list().b_size()) {
717 LOG(ERROR)
718 << from->name()
719 << " port index larger than length of _streaming attribute list.";
720 } else if (from->attr().at(kStreaming).list().b(input_node_port_num)) {
721 send_attr[kStreaming].mutable_list()->add_b(true);
722 recv_attr[kStreaming].mutable_list()->add_b(true);
723 }
724 }
725
726 // NodeState for _Send op.
727 auto& send_node_state = GetNodeStateOrCreateIt(send);
728 send_node_state.device_name = send->device(); // Set Channel device.
729 send_node_state.inputs.push_back(std::make_pair(from, input_node_port_num));
730 send_node_state.outputs[0].push_back(recv);
731
732 // NodeState for _Recv op.
733 auto& recv_node_state = GetNodeStateOrCreateIt(recv);
734 recv_node_state.inputs.push_back(std::make_pair(send, 0));
735 recv_node_state.outputs[0].push_back(to);
736
737 // Keep the created nodes.
738 additional_nodes_.emplace_back(std::unique_ptr<NodeDef>(send));
739 additional_nodes_.emplace_back(std::unique_ptr<NodeDef>(recv));
740
741 // Return _Send and _Recv.
742 return std::make_pair(send, recv);
743 }
744
745 OpContext SchedulerState::CreateOpContext(const NodeDef* node) const {
746 // Get the device from the placer.
747 DeviceProperties device;
748 device = placer_->get_device(*node);
749
750 // Special case for _Send op.
751 if (IsSend(*node)) {
752 device.set_type(kChannelDevice);
753 }
754
755 // Construct OpContext.
756 OpContext op_context;
757 const auto& node_state = node_map_.at(node);
758 op_context.name = node->name();
759 op_context.device_name = node_state.device_name;
760 auto& op_info = op_context.op_info;
761 op_info.set_op(node->op());
762 *op_info.mutable_attr() = node->attr();
763 for (auto& input : node_state.input_properties) {
764 *op_info.add_inputs() = input;
765 }
766 for (auto& output : node_state.output_properties) {
767 *op_info.add_outputs() = output;
768 }
769 op_info.mutable_device()->Swap(&device);
770
771 if (grappler_item_->graph.has_library()) {
772 op_context.function_library = &grappler_item_->graph.library();
773 }
774 return op_context;
775 }
776
777 NodeState& SchedulerState::GetNodeStateOrCreateIt(const NodeDef* node) {
778 CHECK(!initialized_) << "GetNodeStateOrCreateIt is called after Init().";
779
780 auto it = node_map_.find(node);
781 if (it != node_map_.end()) {
782 return it->second;
783 }
784
785 // Not found; create a NodeState for this node.
786 it = node_map_.emplace(node, NodeState()).first;
787 auto& node_state = it->second;
788 node_state.input_properties =
789 graph_properties_->GetInputProperties(node->name());
790 node_state.output_properties =
791 graph_properties_->GetOutputProperties(node->name());
792 node_state.shape_incompatible =
793 graph_properties_->CheckShapeIncompatible(node->name());
794
795 // Some ops may need further processing of the input/output properties:
796 // _Send and _Recv.
797 MaybeUpdateInputOutput(node);
798
799 if (!IsSend(*node)) {
800 node_state.device_name = DeviceName(node);
801 // For _Send op, device_name will be set to Channel in CreateSendRecv().
802 }
803
804 // Initialize output port related data:
805 // Assume the size of OutputProperties represents the number of output ports
806 // of this node.
807 for (size_t i = 0; i < node_state.output_properties.size(); ++i) {
808 node_state.time_no_references[i] = Costs::Duration::max();
809 node_state.num_outputs_executed[i] = 0;
810 // Populate an empty vector for each port. The caller will add nodes
811 // that use this port as input.
812 node_state.outputs[i] = {};
813 }
814 // Port_num -1 is for control dependency.
815 node_state.time_no_references[-1] = Costs::Duration::max();
816 node_state.num_outputs_executed[-1] = 0;
817 node_state.outputs[-1] = {};
818
819 // Initialize time_scheduled to infinity, so we know whether it has been
820 // assigned a non-default value later.
821 node_state.time_scheduled = Costs::Duration().infinity();
822
823 return it->second;
824 }
825
826 void SchedulerState::GetOutputNodes(const NodeDef* node,
827 const Costs::Duration& curr_time,
828 std::vector<const NodeDef*>* output_nodes) {
829 // Checks whether the Switch's output slots change over iterations.
830 int slot = -1;
831 if (IsSwitch(*node) && node->attr().count(kOutputSlots) > 0 &&
832 node->attr().at(kOutputSlots).list().i_size() > 0) {
833 slot = node->attr().at(kOutputSlots).list().i(0);
834 for (int i = 1; i < node->attr().at(kOutputSlots).list().i_size(); ++i) {
835 if (slot != node->attr().at(kOutputSlots).list().i(i)) {
836 slot = -1;
837 break;
838 }
839 }
840 }
841 // Increment num_inputs_ready of the output nodes and maybe add to ready
842 // nodes.
843 auto& node_state = node_map_[node];
844 for (const auto& port_num_output_pair : node_state.outputs) {
845 // If Switch is annotated and its output slots are always the same, we only
846 // schedule the slot that was executed. Otherwise, schedule both slots.
847 if (slot >= 0 && port_num_output_pair.first != slot) continue;
848
849 for (auto* output_node : port_num_output_pair.second) {
850 auto& output_state = node_map_[output_node];
851 output_state.num_inputs_ready++;
852 // Execute a node as soon as all its inputs are ready. Merge nodes are
853 // special since they run as soon as one of their inputs becomes
854 // available.
855 int output_state_inputs_size = output_state.inputs.size();
856 if (output_state.num_inputs_ready == output_state_inputs_size ||
857 IsMerge(*output_node)) {
858 // This output node is now ready.
859 output_state.time_ready = curr_time;
860 output_nodes->push_back(output_node);
861 VLOG(3) << " Add output: " << output_node->name();
862 }
863 }
864 }
865 }
866
867 std::vector<const NodeDef*> SchedulerState::MarkNodeExecuted(
868 const NodeDef* node, const Costs& node_costs, const OpContext& op_context) {
869 auto& node_state = node_map_[node];
870 // TODO(dyoon, andiryxu): Consider revisiting node execution w.r.t. Switch and
871 // Merge -- it can create a loop which may include loop-carried dependency,
872 // diverge-merge, and other complex execution patterns.
873 bool previously_executed_merge =
874 IsMerge(*node) && (node_state.time_finished != Costs::Duration::max());
875
876 // If there is an annotation in the graph about execution times, we use that
877 // number; otherwise, we assume the node is executed once.
878 node_state.execution_count = node->attr().count(kExecutionCount) == 0
879 ? 1
880 : node->attr().at(kExecutionCount).i();
881
882 node_state.node_costs = node_costs;
883 // TotalNodeCosts() should be called after node_costs and execution_count are set.
884 Costs total_node_costs = node_state.TotalNodeCosts();
885
886 graph_costs_ = CombineCosts(graph_costs_, total_node_costs);
887 const string& op_name = node->op();
888
889 auto& op_cost = FindOrCreateZero(op_name, &op_to_cost_);
890 op_cost = CombineCosts(op_cost, total_node_costs);
891
892 if (VLOG_IS_ON(2)) {
893 // Also keep track of op counts and costs per op (with their shapes).
894 string node_description = GetOpDescription(op_context.op_info);
895 op_counts_[node_description] += 1;
896 op_costs_[node_description] =
897 std::make_pair(total_node_costs.execution_time.asMicroSeconds().count(),
898 !node_costs.inaccurate);
899 }
900
901 // Update node and device states.
902 auto& device = device_[node_state.device_name];
903 device.nodes_executed.push_back(node);
904 // Node is scheduled when the device is available AND all the inputs are
905 // ready; hence, time_scheduled is time_ready if time_ready > device curr
906 // time.
907 // NodeState times are assigned infinity at initialization. If they are
908 // still infinity here, we need to assign them. If not, it has been assigned
909 // already, so skip. This latter case may occur when a scheduler inlines
910 // function calls, and thus schedules only function sub-nodes.
911 if (node_state.time_scheduled == Costs::Duration().infinity()) {
912 node_state.time_scheduled =
913 std::max(device.GetCurrTime(), node_state.time_ready);
914 // Override device curr time with the time_scheduled.
915 device.device_costs.execution_time = node_state.time_scheduled;
916 }
917 device.device_costs = CombineCosts(device.device_costs, total_node_costs);
918 auto curr_time = device.GetCurrTime();
919 node_state.time_finished = curr_time;
920
921 // Update shape annotation states.
922 UpdateDeviceAnnotationState(node, node_state, &device);
923
924 // Update device memory usage.
925 if (!IsPersistent(*node)) {
926 for (const auto& port_num_output_pair : node_state.outputs) {
927 int port_num = port_num_output_pair.first;
928
929 // There's a chance that a specific output is not used at all.
930 if (node_state.outputs[port_num].empty()) {
931 node_state.time_no_references[port_num] = curr_time;
932 } else {
933 // Streaming outputs do not allocate memory; they are directly consumed
934 // by the target node.
935 if (!IsStreamingPort(*node, port_num)) {
936 device.memory_usage +=
937 CalculateOutputSize(node_state.output_properties, port_num) *
938 node_state.execution_count;
939 }
940 device.nodes_in_memory.insert(std::make_pair(node, port_num));
941 }
942 }
943 }
944
945 // Update device's per-op cost.
946 auto& device_op_cost = FindOrCreateZero(op_name, &device.op_to_cost);
947 device_op_cost = CombineCosts(device_op_cost, total_node_costs);
948
949 VLOG(3) << "Op scheduled -- name: " << node->name() << ", op: " << node->op()
950 << ", device: " << node->device()
951 << ", execution_count: " << node_state.execution_count
952 << ", ready: " << node_state.time_ready.count()
953 << ", scheduled: " << node_state.time_scheduled.count()
954 << ", finished: " << node_state.time_finished.count();
955 std::vector<const NodeDef*> new_nodes;
956 if (previously_executed_merge) {
957 // Skip AddOutputNodesToReadyQueue; this is due to Switch-Merge.
958 VLOG(1) << "node [ " << node->name() << ", " << node->op() << " ] "
959 << "is executed more than once. "
960 << "Skip scheduling its output nodes.";
961 } else {
962 // Checks outputs, and adds ready nodes to queue.
963 GetOutputNodes(node, curr_time, &new_nodes);
964 }
965
966 // When an op is scheduled, both input and output tensors must be allocated in
967 // memory. Now that output memory is added, check max memory usage.
968 if (!IsPersistent(*node)) {
969 if (device.memory_usage > device.max_memory_usage) {
970 device.max_memory_usage = device.memory_usage;
971
972 if (track_mem_usage_snapshot_) {
973 device.mem_usage_snapshot_at_peak = device.nodes_in_memory;
974 }
975 }
976 }
977
978 // Increment num_outputs_executed of the input nodes and maybe update memory.
979 for (const auto& input_port : node_state.inputs) {
980 auto* input = input_port.first;
981 auto port = input_port.second;
982
983 auto& input_state = node_map_[input];
984 input_state.num_outputs_executed[port]++;
985 int input_state_outputs_size_ = input_state.outputs[port].size();
986 if (input_state.num_outputs_executed[port] == input_state_outputs_size_ &&
987 !IsPersistent(*input)) {
988 // All the outputs are executed; no reference to this output port of
989 // input node.
990 input_state.time_no_references[port] = curr_time;
991 auto& input_device = device_[input_state.device_name];
992 // If the node input is marked as streaming, then it wasn't allocated
993 // in memory. A streaming input is still reference counted, but it doesn't
994 // de-allocate memory.
995 if (!IsStreamingPort(*input, port)) {
996 input_device.memory_usage -=
997 CalculateOutputSize(input_state.output_properties, port) *
998 node_state.execution_count;
999 }
1000
1001 input_device.nodes_in_memory.erase(std::make_pair(input, port));
1002 }
1003 }
1004
1005 return new_nodes;
1006 }
1007
1008 Costs SchedulerState::Summary() const {
1009 // Overall statement about accuracy
1010 VLOG(1) << graph_costs_.num_ops_total << " ops processed in total, with "
1011 << graph_costs_.num_ops_with_unknown_shapes
1012 << " having unknown shapes";
1013
1014 // Print out basic execution summary.
1015 VLOG(1) << "Expected execution time: " << graph_costs_.execution_time.count();
1016 VLOG(1) << "Expected compute time: " << graph_costs_.compute_time.count();
1017 VLOG(1) << "Expected memory time: " << graph_costs_.memory_time.count();
1018 VLOG(1) << "Expected intermediate memory time: "
1019 << graph_costs_.intermediate_memory_time.count();
1020 VLOG(1) << "Expected max memory: " << graph_costs_.max_memory;
1021 VLOG(1) << "Expected max per-op buffers: " << graph_costs_.max_per_op_buffers;
1022 VLOG(1) << "Expected max per-op streaming buffers: "
1023 << graph_costs_.max_per_op_streaming;
1024
1025 VLOG(1) << "Per-op execution time / compute time / memory time"
1026 << " / intermediate memory time:";
1027 for (const auto& op_cost_pair : op_to_cost_) {
1028 const auto& op = op_cost_pair.first;
1029 const auto& cost = op_cost_pair.second.execution_time.count();
1030 const auto& compute_cost = op_cost_pair.second.compute_time.count();
1031 const auto& memory_cost = op_cost_pair.second.memory_time.count();
1032 const auto& intermediate_memory_cost =
1033 op_cost_pair.second.intermediate_memory_time.count();
1034 const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate;
1035 if (cost) { // Skip printing out zero-cost ops.
1036 VLOG(1) << absl::StrFormat(" + %30s : %c %10d / %10d / %10d / %10d", op,
1037 (is_op_cost_accurate ? ' ' : '~'), cost,
1038 compute_cost, memory_cost,
1039 intermediate_memory_cost);
1040 }
1041 }
1042
1043 // Print per device summary
1044 VLOG(1) << "Devices:";
1045 Costs critical_path_costs = Costs::ZeroCosts();
1046 std::vector<string> device_names;
1047 device_names.reserve(device_.size());
1048 for (auto& it : device_) {
1049 device_names.push_back(it.first);
1050 }
1051 std::sort(device_names.begin(), device_names.end());
1052
1053 for (const auto& name : device_names) {
1054 const auto& state = device_.at(name);
1055
1056 std::map<string, int64> op_to_memory;
1057 // First profile only persistent memory usage.
1058 int64 persistent_memory_usage = 0;
1059 std::set<string> persistent_ops;
1060 for (const auto& node_port : state.persistent_nodes) {
1061 const auto* node = node_port.first;
1062 const auto port = node_port.second;
1063 auto output_size = 0;
1064 // Check if the node is in the node_map. It may be that the node executed
1065 // on this device was executed by a different Scheduler.
1066 if (node_map_.find(node) != node_map_.end()) {
1067 output_size =
1068 CalculateOutputSize(node_map_.at(node).output_properties, port);
1069 }
1070 persistent_memory_usage += output_size;
1071 op_to_memory[node->op()] += output_size;
1072 persistent_ops.insert(node->op());
1073 }
1074 int64 max_memory_usage = persistent_memory_usage + state.max_memory_usage;
1075 critical_path_costs.estimated_max_memory_per_device[name] =
1076 max_memory_usage;
1077
1078 const Costs::NanoSeconds wall_time_ns = state.GetCurrTime();
1079 VLOG(1) << "Device = " << name
1080 << ", num_nodes = " << state.nodes_executed.size()
1081 << ", wall_time_ns = " << wall_time_ns.count() << ", memory usage: "
1082 << "persistent = " << HumanReadableNumBytes(persistent_memory_usage)
1083 << ", peak = " << HumanReadableNumBytes(state.max_memory_usage)
1084 << ", total = " << HumanReadableNumBytes(max_memory_usage)
1085 << ", at the end: " << HumanReadableNumBytes(state.memory_usage);
1086
1087 // Overall statement about accuracy
1088 VLOG(1) << state.device_costs.num_ops_total
1089 << " ops processed in total, with "
1090 << state.device_costs.num_ops_with_unknown_shapes
1091 << " having unknown shapes";
1092
1093 // Device shape annotation statistics.
1094 const auto& device_annotation_stats = state.shape_annotation_stats;
1095 if (device_annotation_stats.num_ops_annotated > 0) {
1096 VLOG(1) << device_annotation_stats.num_ops_annotated
1097 << " ops with shape annotation, with "
1098 << device_annotation_stats.num_ops_executed_more_than_once
1099 << " executed more than once, "
1100 << device_annotation_stats.num_ops_with_dynamic_shapes
1101 << " with dynamic shapes, "
1102 << device_annotation_stats.num_ops_with_incompatible_shapes
1103 << " with incompatible shapes, "
1104 << device_annotation_stats.num_ops_executed
1105 << " ops executed in total.";
1106 }
1107
1108 VLOG(1) << "Per-op execution time / compute time / memory time "
1109 << " / intermediate memory time"
1110 << " (and memory usage at peak memory usage):";
1111
1112 // Profile non-persistent op memory usage.
1113 for (const auto& node_port : state.mem_usage_snapshot_at_peak) {
1114 const auto* node = node_port.first;
1115 const auto port = node_port.second;
1116 // Check if the node is in the node_map. It may be that the node executed
1117 // on this device was executed by a different Scheduler.
1118 if (node_map_.find(node) != node_map_.end()) {
1119 op_to_memory[node->op()] +=
1120 CalculateOutputSize(node_map_.at(node).output_properties, port);
1121 }
1122 }
1123 Costs::NanoSeconds total_compute_time_ns;
1124 bool is_total_cost_accurate = true;
1125 for (const auto& op_cost_pair : state.op_to_cost) {
1126 const auto& op = op_cost_pair.first;
1127 const auto& cost = op_cost_pair.second.execution_time.count();
1128 const auto& compute_cost = op_cost_pair.second.compute_time.count();
1129 const auto& memory_cost = op_cost_pair.second.memory_time.count();
1130 const auto& intermediate_memory_cost =
1131 op_cost_pair.second.intermediate_memory_time.count();
1132 total_compute_time_ns += op_cost_pair.second.execution_time;
1133 const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate;
1134 if (!is_op_cost_accurate) {
1135 is_total_cost_accurate = false;
1136 }
1137
1138 int64 op_mem_usage = 0;
1139 auto it = op_to_memory.find(op);
1140 if (it != op_to_memory.end()) {
1141 op_mem_usage = it->second;
1142 }
1143
1144 const float mem_usage_percent =
1145 max_memory_usage > 0 ? Round2(100.0 * op_mem_usage / max_memory_usage)
1146 : 0.0;
1147 if (cost || mem_usage_percent > 1.0) {
1148 // Print out only non-zero cost ops or ops with > 1% memory usage.
1149 VLOG(1) << absl::StrFormat(
1150 " + %30s : %c %10d / %10d / %10d / %10d", op.c_str(),
1151 (is_op_cost_accurate ? ' ' : '~'), cost, compute_cost,
1152 memory_cost, intermediate_memory_cost)
1153 << " (" << HumanReadableNumBytes(op_mem_usage) << " ["
1154 << mem_usage_percent << "%] "
1155 << (persistent_ops.count(op) > 0 ? ": persistent op)" : ")");
1156 }
1157 }
1158
1159 int utilization = 0;
1160 if (wall_time_ns.count() > 0) {
1161 utilization = total_compute_time_ns.count() * 100 / wall_time_ns.count();
1162 }
1163 VLOG(1) << "Device = " << name << ", total_compute_time_ns = "
1164 << (is_total_cost_accurate ? "" : "~")
1165 << total_compute_time_ns.count()
1166 << ", utilization = " << utilization << "%";
1167
1168 if (critical_path_costs.execution_time <= state.GetCurrTime()) {
1169 critical_path_costs = state.device_costs;
1170 }
1171 }
1172
1173 if (VLOG_IS_ON(2)) {
1174 // Also log the op description and their corresponding counts.
1175 VLOG(2) << "Node description, counts, cost:";
1176 for (const auto& item : op_counts_) {
1177 int cost;
1178 bool is_cost_accurate;
1179 std::tie(cost, is_cost_accurate) = op_costs_.at(item.first);
1180 VLOG(2) << "Node: " << item.first << ", Count: " << item.second
1181 << ", Individual Cost: " << (is_cost_accurate ? "" : "~") << cost
1182 << " us";
1183 }
1184 }
1185
1186 VLOG(1) << "Critical path execution time: "
1187 << critical_path_costs.execution_time.count();
1188 return critical_path_costs;
1189 }
1190
1191 Costs SchedulerState::Summary(RunMetadata* metadata) {
1192 if (metadata) GenerateRunMetadata(metadata);
1193 return Summary();
1194 }
1195
1196 void SchedulerState::GenerateRunMetadata(RunMetadata* metadata) {
1197 // Fill RunMetadata's step_stats and partition_graphs fields.
1198 StepStats* stepstats = metadata->mutable_step_stats();
1199 for (const auto& device : device_) {
1200 GraphDef* device_partition_graph = metadata->add_partition_graphs();
1201 DeviceStepStats* device_stepstats = stepstats->add_dev_stats();
1202 device_stepstats->set_device(device.first);
1203 for (const auto& node_def : device.second.nodes_executed) {
1204 // Only proceed if the node is in the node_map. This is to cover the case
1205 // where a device has executed a node that is not in the node_map of
1206 // this scheduler.
1207 if (node_map_.find(node_def) == node_map_.end()) {
1208 continue;
1209 }
1210 const NodeState& nodestate = node_map_.at(node_def);
1211 NodeExecStats* node_stats = device_stepstats->add_node_stats();
1212 uint64 total_output_size = 0;
1213 for (int slot = 0, slot_end = nodestate.output_properties.size();
1214 slot < slot_end; slot++) {
1215 const auto& properties = nodestate.output_properties[slot];
1216 NodeOutput* no = node_stats->add_output();
1217 no->set_slot(slot);
1218 TensorDescription* tensor_descr = no->mutable_tensor_description();
1219 tensor_descr->set_dtype(properties.dtype());
1220 *tensor_descr->mutable_shape() = properties.shape();
1221 // Optional allocation description.
1222 const auto tensor_size =
1223 CalculateOutputSize(nodestate.output_properties, slot);
1224 total_output_size += tensor_size;
1225 tensor_descr->mutable_allocation_description()->set_requested_bytes(
1226 tensor_size);
1227 tensor_descr->mutable_allocation_description()->set_allocated_bytes(
1228 tensor_size);
1229 }
1230 if (node_def->op() != "HloGenericOp") {
1231 node_stats->set_timeline_label(node_def->op());
1232 } else {
1233 // For HloGenericOp, display hlo_opcode as timeline label.
1234 string timeline_label;
1235 if (node_def->attr().count("hlo_opcode") > 0) {
1236 absl::StrAppend(&timeline_label,
1237 node_def->attr().at("hlo_opcode").s());
1238 }
1239 if (node_def->attr().count("_hlo_metadata_op_type") > 0) {
1240 absl::StrAppend(&timeline_label, "/",
1241 node_def->attr().at("_hlo_metadata_op_type").s());
1242 }
1243 node_stats->set_timeline_label(timeline_label);
1244 }
1245 node_stats->set_node_name(node_def->name());
1246 // Timestamps in microseconds (can be used by timeline_server).
1247 node_stats->set_op_start_rel_micros(0);
1248 node_stats->set_all_start_micros(
1249 nodestate.time_scheduled.asMicroSeconds().count());
1250 node_stats->set_op_end_rel_micros(
1251 nodestate.time_finished.asMicroSeconds().count() -
1252 nodestate.time_scheduled.asMicroSeconds().count());
1253 node_stats->set_all_end_rel_micros(
1254 nodestate.time_finished.asMicroSeconds().count() -
1255 nodestate.time_scheduled.asMicroSeconds().count());
1256 // Timestamps in nanoseconds (can be used by xprof trace).
1257 node_stats->set_op_start_rel_nanos(0);
1258 node_stats->set_all_start_nanos(nodestate.time_scheduled.count());
1259 node_stats->set_op_end_rel_nanos(nodestate.time_finished.count() -
1260 nodestate.time_scheduled.count());
1261 node_stats->set_all_end_rel_nanos(nodestate.time_finished.count() -
1262 nodestate.time_scheduled.count());
1263
1264 auto* mem_stats = node_stats->mutable_memory_stats();
1265 // SchedulerState does not specify scratch pad memory usage.
1266 mem_stats->set_temp_memory_size(0);
1267 int64 persistent_memory_size = 0;
1268 if (IsPersistent(*node_def)) {
1269 persistent_memory_size = total_output_size;
1270 }
1271 mem_stats->set_persistent_memory_size(persistent_memory_size);
1272 *device_partition_graph->add_node() = *node_def;
1273 }
1274 }
1275 }
1276
1277 const std::unordered_map<string, int64> SchedulerState::GetPeakMemoryUsage()
1278 const {
1279 std::unordered_map<string, int64> result;
1280 for (const auto& device : device_) {
1281 const string& name = device.first;
1282 const DeviceState& state = device.second;
1283 result[name] = state.max_memory_usage;
1284 }
1285 return result;
1286 }
1287
1288 const std::unordered_map<string, int64>
1289 SchedulerState::GetPersistentMemoryUsage() const {
1290 std::unordered_map<string, int64> result;
1291 for (const auto& device : device_) {
1292 const string& name = device.first;
1293 const DeviceState& state = device.second;
1294 int64 persistent_memory_usage = 0;
1295 for (const auto& node_port : state.persistent_nodes) {
1296 const auto* node = node_port.first;
1297 const auto port = node_port.second;
1298 const auto output_size =
1299 CalculateOutputSize(node_map_.at(node).output_properties, port);
1300 persistent_memory_usage += output_size;
1301 }
1302 result[name] = persistent_memory_usage;
1303 }
1304 return result;
1305 }
1306
1307 void SchedulerState::SetNodeStateTimeScheduled(const NodeDef* node) {
1308 auto& node_state = node_map_.at(node);
1309 auto& device = device_[node_state.device_name];
1310 node_state.time_scheduled = device.GetCurrTime();
1311 }
1312
1313 VirtualScheduler::~VirtualScheduler() {}
1314
1315 VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
1316 const bool use_aggressive_shape_inference,
1317 Cluster* cluster,
1318 ReadyNodeManager* ready_nodes,
1319 std::unique_ptr<VirtualPlacer> placer)
1320 : scheduler_state_(absl::make_unique<SchedulerState>(
1321 use_static_shapes, use_aggressive_shape_inference, cluster,
1322 std::move(placer))),
1323 ready_nodes_(ready_nodes) {}
1324
1325 VirtualScheduler::VirtualScheduler(
1326 ReadyNodeManager* ready_nodes,
1327 std::unique_ptr<SchedulerState> scheduler_state)
1328 : scheduler_state_(std::move(scheduler_state)), ready_nodes_(ready_nodes) {}
1329
1330 Status VirtualScheduler::Init(const GrapplerItem* item) {
1331 // SchedulerState::Init() preprocesses the input grappler_item and
1332 // graph_properties to extract the necessary information for emulating
1333 // TensorFlow op scheduling, and constructs internal data structures
1334 // (NodeState and DeviceState) for virtual scheduling.
1335 TF_RETURN_IF_ERROR(ready_nodes_->Init(GetNodeStates()));
1336 std::vector<const NodeDef*> initial_nodes;
1337 auto status = scheduler_state_->Init(item, &initial_nodes);
1338 if (status.ok()) {
1339 // Add the set of initial nodes to ready_nodes_
1340 for (auto node : initial_nodes) {
1341 ready_nodes_->AddNode(node);
1342 }
1343 }
1344 return status;
1345 }
1346
1347 OpContext VirtualScheduler::GetCurrNode() {
1348 const NodeDef* node = ready_nodes_->GetCurrNode();
1349 return scheduler_state_->CreateOpContext(node);
1350 }
1351
1352 bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
1353 // Update graph_costs_ and per-op costs.
1354 const NodeDef* node = ready_nodes_->GetCurrNode();
1355 auto new_nodes = scheduler_state_->MarkNodeExecuted(
1356 node, node_costs,
1357 scheduler_state_->CreateOpContext(ready_nodes_->GetCurrNode()));
1358 // Add the set of new nodes obtained from MarkNodeExecuted() to ready_nodes_.
1359 for (auto node : new_nodes) {
1360 ready_nodes_->AddNode(node);
1361 }
1362 ready_nodes_->RemoveCurrNode();
1363 return !ready_nodes_->Empty();
1364 }
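// A typical simulation loop over this interface (sketch; `estimator` stands in
// for an op-level cost estimator and is not defined in this file):
//   bool more_nodes = true;
//   while (more_nodes) {
//     OpContext op_context = scheduler.GetCurrNode();
//     Costs node_costs = estimator.PredictCosts(op_context);
//     more_nodes = scheduler.MarkCurrNodeExecuted(node_costs);
//   }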
1365
1366 } // end namespace grappler
1367 } // end namespace tensorflow
1368