/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/costs/virtual_scheduler.h"

#include <algorithm>

#include "absl/memory/memory.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/clusters/utils.h"
#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace grappler {

namespace {

// Key to the cached _Recv ops map, and its hash and predicate structures.
struct RecvNodeDescriptor {
  const NodeDef* node;
  const int port_num;
  const string device;

  RecvNodeDescriptor(const NodeDef* node_, const int port_num_,
                     const string& device_)
      : node(node_), port_num(port_num_), device(device_) {}
};

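// The hash below is a simple XOR of the per-field hashes; collisions are
// tolerable here since this only keys a small per-graph cache of _Recv ops.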
struct RecvNodeDescriptorHash {
  std::size_t operator()(const RecvNodeDescriptor& recv_node) const {
    return std::hash<const NodeDef*>()(recv_node.node) ^
           std::hash<int>()(recv_node.port_num) ^
           std::hash<string>()(recv_node.device);
  }
};

struct RecvNodeDescriptorEqual {
  bool operator()(const RecvNodeDescriptor& a,
                  const RecvNodeDescriptor& b) const {
    return a.node == b.node && a.port_num == b.port_num && a.device == b.device;
  }
};
}  // namespace

// ReadyNodeManager
const NodeDef* LIFOManager::GetCurrNode() {
  CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node";
  if (curr_pos_ == nodes_.end()) {
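    // nodes_.rbegin().base() equals nodes_.end(), so decrementing it yields
    // an iterator to the last element of the list.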
    curr_pos_ = --(nodes_.rbegin().base());  // Last one in the list.
  }
  // Once curr_pos_ is set to a valid entry in the list, we keep using the
  // cached curr_pos_ until RemoveCurrNode() is called. AddNode() will not
  // change the GetCurrNode() return value.
  return *curr_pos_;
}

void LIFOManager::RemoveCurrNode() {
  // Make sure we have curr_pos_ ready to be removed.
  GetCurrNode();
  // Note curr_pos_ may not be pointing to the last element if some nodes
  // were added after the current node was selected.
  nodes_.erase(curr_pos_);

  curr_pos_ = nodes_.end();  // Reset curr_pos_.
}

FirstReadyManager::FirstReadyManager() : ReadyNodeManager() {
  std::make_heap(nodes_.begin(), nodes_.end());
}

void FirstReadyManager::Init(
    const std::unordered_map<const NodeDef*, NodeState>* node_state) {
  // Reset the node state since different instances of the scheduler can reuse
  // the same node_manager.
  node_state_ = node_state;
  nodes_.clear();
  waiting_queue_.clear();
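  // greater_ orders nodes so that std::push_heap/std::pop_heap maintain a
  // min-heap on time_ready: nodes_.front() is always the earliest-ready node.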
  greater_ = [this](const NodeDef* a, const NodeDef* b) -> bool {
    if (node_state_->at(a).time_ready == node_state_->at(b).time_ready) {
      // Use node name as the tie-breaker for deterministic node scheduling.
      return a->name().compare(b->name()) > 0;
    } else {
      // Note: we need a node with minimum time_ready, not maximum; hence,
      // using a > b for the comparison function.
      return node_state_->at(a).time_ready > node_state_->at(b).time_ready;
    }
  };
}

const NodeDef* FirstReadyManager::GetCurrNode() {
  if (nodes_.empty()) {
    // Nothing in nodes_; probably, the very first call. Move waiting_queue_
    // to nodes_.
    DrainWaitingQueue();
    CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node";
  }
  return nodes_.front();
}

void FirstReadyManager::RemoveCurrNode() {
  if (nodes_.empty()) {
    // Make sure that there is a node to be removed at the front of nodes_.
    GetCurrNode();
  }
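  // pop_heap moves the minimum element to the back, so pop_back removes it.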
  std::pop_heap(nodes_.begin(), nodes_.end(), greater_);
  nodes_.pop_back();
  DrainWaitingQueue();
}

bool FirstReadyManager::Empty() const {
  return nodes_.empty() && waiting_queue_.empty();
}

void FirstReadyManager::DrainWaitingQueue() {
  for (const auto* node : waiting_queue_) {
    // push_heap in AddNode() and pop_heap in RemoveCurrNode() guarantee that
    // the first element is the node with minimum time_ready.
    nodes_.push_back(node);
    std::push_heap(nodes_.begin(), nodes_.end(), greater_);
  }
  waiting_queue_.clear();
}

CompositeNodeManager::CompositeNodeManager()
    : ReadyNodeManager(), send_manager_(), recv_manager_() {}

void CompositeNodeManager::Init(
    const std::unordered_map<const NodeDef*, NodeState>* node_state) {
  node_state_ = node_state;
  send_manager_.Init(node_state);
  recv_manager_.Init(node_state);
  curr_node_ = nullptr;
}

void CompositeNodeManager::AddNode(const NodeDef* node) {
  if (IsSend(*node)) {
    send_manager_.AddNode(node);
  } else if (IsRecv(*node)) {
    recv_manager_.AddNode(node);
  } else {
    const auto& device = node_state_->at(node).device_name;
    ops_lifo_map_[device].AddNode(node);
  }
}

const NodeDef* CompositeNodeManager::GetCurrNode() {
  if (curr_node_) return curr_node_;

  // Per-device LIFO for normal ops (not _Send / _Recv),
  // FirstReady for _Send and _Recv (separately),
  // Globally (among the LIFO-selected ops from each device and _Send and
  // _Recv) FirstReady,
  // Priority order: _Send, _Recv, and then the rest, if time_ready is equal.
  std::vector<std::pair<const NodeDef*, Costs::Duration>> candidates;
  for (auto& ops_lifo : ops_lifo_map_) {
    if (!ops_lifo.second.Empty()) {
      const auto* op = ops_lifo.second.GetCurrNode();
      candidates.emplace_back(op, node_state_->at(op).time_ready);
    }
  }
  if (!send_manager_.Empty()) {
    const auto* send = send_manager_.GetCurrNode();
    candidates.emplace_back(send, node_state_->at(send).time_ready);
  }
  if (!recv_manager_.Empty()) {
    const auto* recv = recv_manager_.GetCurrNode();
    candidates.emplace_back(recv, node_state_->at(recv).time_ready);
  }
  CHECK(!candidates.empty());
  auto first_ready = std::min_element(
      candidates.begin(), candidates.end(),
      [](const std::pair<const NodeDef*, Costs::Duration>& a,
         const std::pair<const NodeDef*, Costs::Duration>& b) {
        if (a.second == b.second) {
          // Note that candidates can hold at most one _Send and one _Recv;
          // hence, the score is 2 for _Send, 1 for _Recv, and 0 for a normal
          // op, and a_score and b_score are equal only if both are normal
          // ops.
          int a_score = 2 * IsSend(*a.first) + IsRecv(*a.first);
          int b_score = 2 * IsSend(*b.first) + IsRecv(*b.first);
          if (a_score == b_score) {
            // Both are normal ops; use node name as the tie-breaker.
            return a.first->name().compare(b.first->name()) < 0;
          } else {
            // Prioritize by op type: _Send, _Recv, and then normal ops.
            return a_score > b_score;
          }
        } else {
          return a.second < b.second;
        }
      });
  // Next time we call GetCurrNode(), it just returns the cached curr_node_
  // until we call RemoveCurrNode().
  curr_node_ = first_ready->first;

  return curr_node_;
}

void CompositeNodeManager::RemoveCurrNode() {
  const auto* node = GetCurrNode();
  if (IsSend(*node)) {
    send_manager_.RemoveCurrNode();
  } else if (IsRecv(*node)) {
    recv_manager_.RemoveCurrNode();
  } else {
    const auto device = node_state_->at(node).device_name;
    ops_lifo_map_[device].RemoveCurrNode();
  }
  // Reset curr_node_ so that GetCurrNode() finds another node.
  curr_node_ = nullptr;
}

bool CompositeNodeManager::Empty() const {
  // Empty if all the ready managers are empty.
  bool empty = true;
  for (const auto& ops_lifo : ops_lifo_map_) {
    empty &= ops_lifo.second.Empty();
  }
  return empty && send_manager_.Empty() && recv_manager_.Empty();
}

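// Creates a ReadyNodeManager by name. A minimal usage sketch (hypothetical
// caller; error handling elided):
//   std::unique_ptr<ReadyNodeManager> manager =
//       ReadyNodeManagerFactory("FirstReady");
// Unknown names abort via LOG(FATAL).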
std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory(
    const string& ready_node_manager) {
  if (ready_node_manager == "FIFO") {
    return absl::make_unique<FIFOManager>();
  } else if (ready_node_manager == "LIFO") {
    return absl::make_unique<LIFOManager>();
  } else if (ready_node_manager == "FirstReady") {
    return absl::make_unique<FirstReadyManager>();
  } else if (ready_node_manager == "Composite") {
    return absl::make_unique<CompositeNodeManager>();
  }
  LOG(FATAL) << "Not a valid ready node manager: " << ready_node_manager;
  return nullptr;
}

VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
                                   const bool use_aggressive_shape_inference,
                                   Cluster* cluster,
                                   ReadyNodeManager* ready_nodes)
    : ready_nodes_(ready_nodes),
      graph_costs_(Costs::ZeroCosts()),
      cluster_(cluster),
      use_static_shapes_(use_static_shapes),
      use_aggressive_shape_inference_(use_aggressive_shape_inference),
      placer_(cluster) {
  graph_costs_.num_ops_total = 0;
  initialized_ = false;
  track_mem_usage_snapshot_ = VLOG_IS_ON(1);
}

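// Typical driver loop (a sketch; the cost-estimating caller and `estimator`
// are hypothetical and not part of this file):
//   TF_RETURN_IF_ERROR(scheduler.Init(&item));
//   Costs node_costs;
//   do {
//     OpContext op_context = scheduler.GetCurrNode();
//     node_costs = estimator.PredictCosts(op_context);
//   } while (scheduler.MarkCurrNodeExecuted(node_costs));
//   Costs total = scheduler.Summary();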
Status VirtualScheduler::Init(const GrapplerItem* item) {
  grappler_item_ = item;
  graph_properties_ = absl::make_unique<GraphProperties>(*item);

  initialized_ = false;

  // Clear all internal states so that the VirtualScheduler is reusable for
  // different GrapplerItems.
  node_map_.clear();
  device_.clear();
  additional_nodes_.clear();

  graph_costs_ = Costs::ZeroCosts();
  graph_costs_.num_ops_total = 0;
  op_to_cost_.clear();

  op_counts_.clear();
  op_costs_.clear();

  // Init() preprocesses the input grappler_item and graph_properties to
  // extract necessary information for emulating tensorflow op scheduling and
  // construct internal data structures (NodeState and DeviceState) for
  // virtual scheduling.
  ready_nodes_->Init(GetNodeStates());

  // Construct graph properties.
  if (use_static_shapes_) {
    TF_RETURN_IF_ERROR(graph_properties_->InferStatically(
        true, use_aggressive_shape_inference_));
  } else {
    TF_RETURN_IF_ERROR(graph_properties_->InferDynamically(cluster_));
  }

  const auto& graph = grappler_item_->graph;
  const auto& fetch_nodes = grappler_item_->fetch;
  std::set<string> feed_nodes;
  for (const auto& f : grappler_item_->feed) {
    auto iter_and_inserted_flag = feed_nodes.insert(f.first);
    QCHECK(iter_and_inserted_flag.second)
        << "Duplicate feed node found: " << f.first;
  }

  // Get the nodes that would run to output fetch_nodes.
  bool ill_formed = false;
  const std::vector<const NodeDef*> fetch_fanin_nodes =
      ComputeTransitiveFanin(graph, fetch_nodes, &ill_formed);
  if (ill_formed) {
    return errors::InvalidArgument(
        "Ill formed graph or invalid set of fetch nodes specified");
  }

  // TODO(dyoon): this is a bit inefficient as name_to_node is already built in
  // ComputeTransitiveFanin().
  // Once ComputeTransitiveFanin is complete, only the nodes that can be
  // reached from the fetch nodes are scheduled. So the scheduled nodes should
  // be exactly the same as those executed for real. One possible discrepancy
  // could be the control flow nodes, where tf only executes one path.
  std::unordered_map<string, const NodeDef*> name_to_node;
  for (const auto& node : fetch_fanin_nodes) {
    name_to_node[node->name()] = node;
  }

  // Traverses the graph to record _Send nodes.
  // TODO(dyoon): Instead of identifying _Send node here manually, add _Send
  // to _Recv as control dependency when creating GrapplerItem.
  std::unordered_map<string, const NodeDef*> name_to_send;
  for (const auto& node : graph.node()) {
    if (IsSend(node)) {
      const auto& attr = node.attr();
      name_to_send[attr.at("tensor_name").s()] = &node;
    }
  }

  // To reuse _Recv ops.
  std::unordered_map<RecvNodeDescriptor, const NodeDef*,
                     RecvNodeDescriptorHash, RecvNodeDescriptorEqual>
      cached_recv_nodes;

  // Build node_map; for each node, create its NodeState and connect its inputs
  // and outputs.
  for (const auto* curr_node : fetch_fanin_nodes) {
    auto& curr_node_state = GetNodeStateOrCreateIt(curr_node);
    const string curr_node_device = DeviceName(curr_node);
    std::vector<string> inputs;
    if (IsRecv(*curr_node)) {
      const auto& attr = curr_node->attr();
      if (attr.count("tensor_name")) {
        const auto& send_node_name = attr.at("tensor_name").s();
        auto it = name_to_send.find(send_node_name);
        // If there is a _Send associated with the curr_node (_Recv), add it
        // as input.
        if (it != name_to_send.end()) {
          const NodeDef* send = it->second;
          inputs = {send->name()};
        }
      }
    } else {
      for (const string& input : curr_node->input()) {
        inputs.push_back(input);
      }
    }
    for (const string& input_node_name : inputs) {
      // Note that input_node_name may be in <prefix><node_name>:<port_num>
      // format, where <prefix> (e.g., "^" for control dependency) and
      // ":<port_num>" may be omitted. NodeName() extracts only the node_name.
      const NodeDef* input_node = name_to_node[NodeName(input_node_name)];

      CHECK(input_node);
      const string in_device = DeviceName(input_node);
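      // NodePosition() returns the output port encoded in input_node_name
      // (0 when the ":<port_num>" suffix is omitted, or -1 for a control
      // input such as "^name").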
      const auto input_node_port_num = NodePosition(input_node_name);

      if (curr_node_device == in_device) {
        // Same device: connect input_node and curr_node directly.
        curr_node_state.inputs.push_back(
            std::make_pair(input_node, input_node_port_num));
        auto& input_node_state = GetNodeStateOrCreateIt(input_node);
        input_node_state.outputs[input_node_port_num].push_back(curr_node);
      } else {
        RecvNodeDescriptor recv_node(input_node, input_node_port_num,
                                     curr_node_device);
        auto it = cached_recv_nodes.find(recv_node);
        if (it != cached_recv_nodes.end()) {
          // Different device, but found an already-cached copy (a _Recv op);
          // connect the _Recv to curr_node.
          const NodeDef* recv_op = it->second;
          // recv_op's output port is hard-coded to zero.
          curr_node_state.inputs.push_back(std::make_pair(recv_op, 0));
          auto& input_node_state = node_map_.at(recv_op);
          input_node_state.outputs[0].push_back(curr_node);
        } else {
          // Different device, no cached copy; transfer input_node to the
          // curr_node's device.
          auto send_and_recv = CreateSendRecv(input_node, curr_node,
                                              input_node, input_node_name);
          // Note that CreateSendRecv() already connected input/output between
          // _Send and _Recv ops.
          const auto* send = send_and_recv.first;
          const auto* recv = send_and_recv.second;
          // recv_op's output port is hard-coded to zero.
          curr_node_state.inputs.push_back(std::make_pair(recv, 0));
          auto& input_node_state = GetNodeStateOrCreateIt(input_node);
          input_node_state.outputs[input_node_port_num].push_back(send);

          // Cache the _Recv op for future use.
          cached_recv_nodes[recv_node] = recv;
        }
      }
    }

    // Special case: given feed nodes are ready at time 0.
    const bool given_as_feed =
        feed_nodes.find(curr_node->name()) != feed_nodes.end();

    // Default case: nodes without inputs are ready at time 0.
    // Note that we check the inputs vector, which may be different from
    // curr_node->input(); e.g., we add _Send as an input to _Recv.
    const bool has_no_inputs = inputs.empty();

    if (given_as_feed || has_no_inputs) {
      curr_node_state.time_ready = Costs::Duration();
      ready_nodes_->AddNode(curr_node);
      VLOG(3) << "Added ready node: " << curr_node->name();
    }

    feed_nodes.erase(curr_node->name());

    if (IsPersistentNode(curr_node)) {
      auto& device_state = device_[curr_node_device];
      for (int port_num = 0;
           port_num < curr_node_state.output_properties.size(); ++port_num) {
        device_state.persistent_nodes.insert(
            std::make_pair(curr_node, port_num));
      }
    }
  }

  if (ready_nodes_->Empty()) {
    return errors::InvalidArgument("No ready nodes in the graph.");
  }

  if (!feed_nodes.empty()) {
    // This isn't always a bug: when the caller hasn't specified the exact list
    // of feed and fetch nodes, by default we consider all placeholders as feed
    // nodes, but some of them may not be needed for the default fetch node.
    VLOG(1) << "Some feed nodes were not consumed by the fetch fanin: "
            << str_util::Join(feed_nodes, ",");
  }

  initialized_ = true;
  return Status::OK();
}

void VirtualScheduler::MaybeUpdateInputOutput(const NodeDef* node) {
  CHECK(!initialized_) << "MaybeUpdateInputOutput is called after Init().";
  // This method is called when NodeState is created and adds input and output
  // properties for a few exceptional cases where GraphProperties cannot
  // provide them.
  if ((IsSend(*node) || IsRecv(*node)) && node->attr().count(kAttrInputSrc)) {
    // _Send and _Recv ops created from VirtualScheduler have kAttrInputSrc
    // attr; normal _Send and _Recv ops (from the input graph) do not have that
    // attr.
    auto& node_state = node_map_[node];
    auto& inputs = node_state.input_properties;
    auto& outputs = node_state.output_properties;

    // _Send and _Recv ops are created from VirtualScheduler, so
    // there should be no input TensorProperties.
    CHECK(inputs.empty());
    CHECK(outputs.empty());
    const auto& attr = node->attr();
    // This is the original input source to the _Send and _Recv, and this
    // string includes "^" if it was a control dependency, and the output port
    // (e.g., ":2") if the input source had multiple outputs.
    const auto& input_source_name = attr.at(kAttrInputSrc).s();
    if (IsControlInput(input_source_name)) {
      // Control dependency; regardless of the input source tensor size,
      // send 4B.
      OpInfo::TensorProperties control_message;
      control_message.set_dtype(DT_FLOAT);
      control_message.mutable_shape()->add_dim()->set_size(1);
      auto* value = control_message.mutable_value();
      value->add_float_val(1);
      inputs.push_back(control_message);
      outputs.push_back(control_message);
    } else {
      const auto& output_properties =
          graph_properties_->GetOutputProperties(NodeName(input_source_name));
      // Like with HasInputProperties, if a node does not have output
      // properties, it's likely it was pruned during the shape inference run.
      if (!output_properties.empty()) {
        const auto input_node_port_num = NodePosition(input_source_name);
        // Use the input source's output property as _Send and _Recv's input
        // property.
        CHECK_GT(output_properties.size(), input_node_port_num);
        inputs.push_back(output_properties[input_node_port_num]);
        outputs.push_back(output_properties[input_node_port_num]);
      }
    }
  }
}

float VirtualScheduler::Round2(const float x) const {
  // Not using std::round from <cmath> here because not all platforms seem to
  // support that (specifically Android).
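  // Rounds to two decimal places; e.g., Round2(12.3456f) returns 12.35f.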
  return ::round(100.0 * x) / 100.0;
}

bool VirtualScheduler::IsPersistentNode(const NodeDef* node) const {
  // Variables are persistent nodes.
  return IsVariable(*node);
}

string VirtualScheduler::DeviceName(const NodeDef* node) const {
  return placer_.get_canonical_device_name(*node);
}

string VirtualScheduler::SanitizedDeviceName(const NodeDef* node) const {
  // Replace the ":" characters that may be present in the device name with "_".
  // This makes it possible to then use the resulting string in a node name.
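  // For example, "/job:ps/replica:0/task:0/device:CPU:0" becomes
  // "/job_ps/replica_0/task_0/device_CPU_0".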
  return str_util::StringReplace(placer_.get_canonical_device_name(*node), ":",
                                 "_", true);
}

string VirtualScheduler::ChannelDeviceName(const NodeDef* from,
                                           const NodeDef* to) const {
  CHECK(!initialized_) << "ChannelDeviceName is called after Init().";
  return kChannelDevice + "_from_" + SanitizedDeviceName(from) + "_to_" +
         SanitizedDeviceName(to);
}

std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
    const NodeDef* from, const NodeDef* to, const NodeDef* input_node,
    const string& input_name) {
  CHECK(!initialized_) << "CreateSendRecv is called after Init().";

  // Connect "from" node to "to" node with _Send and _Recv such that
  // from -> _Send -> _Recv -> to.
  // _Send is placed on "Channel" device, and _Recv is on the same device
  // as "to" node.
  // input_node_name is the string from the "to" node to identify which output
  // we get from the "from" node.

  // Note that we use NodeState for scheduling, so _Send and _Recv
  // NodeDefs created here need not be correct: in terms of name,
  // input names, attrs, etc.

  auto input_node_port_num = NodePosition(input_name);
  string src_name;
  if (input_node_port_num >= 0) {
    src_name = strings::StrCat(from->name(), "_", input_node_port_num);
  } else {
    src_name = strings::StrCat(from->name(), "_minus1");
  }

  // _Send op.
  auto* send = new NodeDef();
  send->set_name("Send_" + src_name + "_from_" + SanitizedDeviceName(from) +
                 "_to_" + SanitizedDeviceName(to));
  send->set_op("_Send");
  send->add_input(from->name());
  send->set_device(ChannelDeviceName(from, to));
  auto& send_attr = *(send->mutable_attr());
  send_attr[kAttrInputSrc].set_s(input_name);
  send_attr[kAttrSrcDevice].set_s(DeviceName(from));
  send_attr[kAttrDstDevice].set_s(DeviceName(to));
  // GraphDef generated by AutoGrappler has tensor_name field when removing
  // _Send/_Recv nodes.
  if (input_node->attr().count(kAttrTensorName)) {
    send_attr[kAttrTensorName].set_s(
        input_node->attr().at(kAttrTensorName).s());
  }

  // _Recv op.
  auto* recv = new NodeDef();
  recv->set_name("Recv_" + src_name + "_on_" + SanitizedDeviceName(to));
  recv->set_op("_Recv");
  recv->add_input(send->name());
  recv->set_device(DeviceName(to));
  auto& recv_attr = *(recv->mutable_attr());
  recv_attr[kAttrInputSrc].set_s(input_name);
  if (input_node->attr().count(kAttrTensorName)) {
    recv_attr[kAttrTensorName].set_s(
        input_node->attr().at(kAttrTensorName).s());
  }

  // NodeState for _Send op.
  auto& send_node_state = GetNodeStateOrCreateIt(send);
  send_node_state.device_name = send->device();  // Set Channel device.
  send_node_state.inputs.push_back(std::make_pair(from, input_node_port_num));
  send_node_state.outputs[0].push_back(recv);

  // NodeState for _Recv op.
  auto& recv_node_state = GetNodeStateOrCreateIt(recv);
  recv_node_state.inputs.push_back(std::make_pair(send, 0));
  recv_node_state.outputs[0].push_back(to);

  // Keep the created nodes.
  additional_nodes_.emplace_back(std::unique_ptr<NodeDef>(send));
  additional_nodes_.emplace_back(std::unique_ptr<NodeDef>(recv));

  // Return _Send and _Recv.
  return std::make_pair(send, recv);
}

OpContext VirtualScheduler::GetCurrNode() const {
  const NodeDef* node = ready_nodes_->GetCurrNode();

  // Get the device from the placer.
  DeviceProperties device;
  device = placer_.get_device(*node);

  // Special case for _Send op.
  if (IsSend(*node)) {
    device.set_type(kChannelDevice);
  }

  // Construct OpContext.
  OpContext op_context;
  const auto& node_state = node_map_.at(node);
  op_context.name = node->name();
  op_context.device_name = node_state.device_name;
  auto& op_info = op_context.op_info;
  op_info.set_op(node->op());
  *op_info.mutable_attr() = node->attr();
  for (auto& input : node_state.input_properties) {
    *op_info.add_inputs() = input;
  }
  for (auto& output : node_state.output_properties) {
    *op_info.add_outputs() = output;
  }
  op_info.mutable_device()->Swap(&device);

  if (grappler_item_->graph.has_library()) {
    op_context.function_library = &grappler_item_->graph.library();
  }
  return op_context;
}

NodeState& VirtualScheduler::GetNodeStateOrCreateIt(const NodeDef* node) {
  CHECK(!initialized_) << "GetNodeStateOrCreateIt is called after Init().";

  auto it = node_map_.find(node);
  if (it != node_map_.end()) {
    return it->second;
  }

  // Not found; create a NodeState for this node.
  it = node_map_.emplace(node, NodeState()).first;
  auto& node_state = it->second;
  node_state.input_properties =
      graph_properties_->GetInputProperties(node->name());
  node_state.output_properties =
      graph_properties_->GetOutputProperties(node->name());

  // Some ops may need further processing of the input / output properties:
  // _Send and _Recv.
  MaybeUpdateInputOutput(node);

  if (!IsSend(*node)) {
    node_state.device_name = DeviceName(node);
    // For _Send op, device_name will be set to Channel in CreateSendRecv().
  }

  // Initialize output port related data:
  // Assume the size of OutputProperties represents the number of output ports
  // of this node.
  for (size_t i = 0; i < node_state.output_properties.size(); ++i) {
    node_state.time_no_references[i] = Costs::Duration::max();
    node_state.num_outputs_executed[i] = 0;
    // Populate an empty vector for each port. The caller will add nodes
    // that use this port as input.
    node_state.outputs[i] = {};
  }
  // Port_num -1 is for control dependency.
  node_state.time_no_references[-1] = Costs::Duration::max();
  node_state.num_outputs_executed[-1] = 0;
  node_state.outputs[-1] = {};

  return it->second;
}

Costs& VirtualScheduler::FindOrCreateZero(const string& op_name,
                                          std::map<string, Costs>* op_cost) {
  auto it = op_cost->find(op_name);
  if (it == op_cost->end()) {
    // Note that the default constructor of Costs sets some memory related
    // fields to unknown values, so we should explicitly initialize it with
    // ZeroCosts.
    it = op_cost->emplace(op_name, Costs::ZeroCosts()).first;
  }
  return it->second;
}

void VirtualScheduler::AddOutputNodesToReadyQueue(
    const NodeDef* node, const Costs::Duration& curr_time) {
  // Checks whether the Switch's output slots change over iterations.
  int slot = -1;
  if (IsSwitch(*node) && node->attr().count(kOutputSlots) > 0 &&
      node->attr().at(kOutputSlots).list().i_size() > 0) {
    slot = node->attr().at(kOutputSlots).list().i(0);
    for (int i = 1; i < node->attr().at(kOutputSlots).list().i_size(); ++i) {
      if (slot != node->attr().at(kOutputSlots).list().i(i)) {
        slot = -1;
        break;
      }
    }
  }

  // Increment num_inputs_ready of the output nodes and maybe add to ready
  // nodes.
  auto& node_state = node_map_[node];
  for (const auto& port_num_output_pair : node_state.outputs) {
    // If Switch is annotated and its output slots are always the same, we only
    // schedule the slot that was executed. Otherwise, schedule both slots.
    if (slot >= 0 && port_num_output_pair.first != slot) continue;

    for (auto* output_node : port_num_output_pair.second) {
      auto& output_state = node_map_[output_node];
      output_state.num_inputs_ready++;
      // Execute a node as soon as all its inputs are ready. Merge nodes are
      // special since they run as soon as one of their inputs becomes
      // available.
      if (output_state.num_inputs_ready == output_state.inputs.size() ||
          IsMerge(*output_node)) {
        // This output node is now ready.
        output_state.time_ready = curr_time;
        ready_nodes_->AddNode(output_node);
        VLOG(3) << " Add output: " << output_node->name();
      }
    }
  }
}

bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
  // Update graph_costs_ and per-op costs.
  const NodeDef* node = ready_nodes_->GetCurrNode();
  auto& node_state = node_map_[node];
  // If there is an annotation in the graph about execution times, we use that
  // number; otherwise, we assume the node is executed once.
  node_state.execution_count = node->attr().count(kExecutionCount) == 0
                                   ? 1
                                   : node->attr().at(kExecutionCount).i();
  Costs total_node_costs =
      MultiplyCosts(node_costs, node_state.execution_count);
  graph_costs_ = CombineCosts(graph_costs_, total_node_costs);
  const string& op_name = node->op();

  auto& op_cost = FindOrCreateZero(op_name, &op_to_cost_);
  op_cost = CombineCosts(op_cost, total_node_costs);

  if (VLOG_IS_ON(2)) {
    // Also keep track of op counts and costs per op (with their shapes).
    OpContext op_context = GetCurrNode();

    string node_description = GetOpDescription(op_context.op_info);
    op_counts_[node_description] += 1;
    op_costs_[node_description] =
        std::make_pair(node_costs.execution_time.asMicroSeconds().count(),
                       !node_costs.inaccurate);
  }

  // Update node and device states.
  auto& device = device_[node_state.device_name];
  device.nodes_executed.push_back(node);
  // Node is scheduled when the device is available AND all the inputs are
  // ready; hence, time_scheduled is time_ready if time_ready > device curr
  // time.
  node_state.time_scheduled =
      std::max(device.GetCurrTime(), node_state.time_ready);
  // Override device curr time with the time_scheduled.
  device.device_costs.execution_time = node_state.time_scheduled;
  device.device_costs = CombineCosts(device.device_costs, total_node_costs);
  auto curr_time = device.GetCurrTime();
  node_state.time_finished = curr_time;

  // Update device memory usage.
  if (!IsPersistentNode(node)) {
    for (const auto& port_num_output_pair : node_state.outputs) {
      int port_num = port_num_output_pair.first;
      // There's a chance that a specific output is not used at all.
      if (node_state.outputs[port_num].empty()) {
        node_state.time_no_references[port_num] = curr_time;
      } else {
        device.memory_usage +=
            CalculateOutputSize(node_state.output_properties, port_num) *
            node_state.execution_count;
        device.nodes_in_memory.insert(std::make_pair(node, port_num));
      }
    }
  }

  // Update the device's per-op cost.
  auto& device_op_cost = FindOrCreateZero(op_name, &device.op_to_cost);
  device_op_cost = CombineCosts(device_op_cost, total_node_costs);

  VLOG(3) << "Op scheduled -- name: " << node->name() << ", op: " << node->op()
          << ", device: " << node->device()
          << ", execution_count: " << node_state.execution_count
          << ", ready: " << node_state.time_ready.count()
          << ", scheduled: " << node_state.time_scheduled.count()
          << ", finished: " << node_state.time_finished.count();

  // Checks outputs, and adds ready nodes to queue.
  AddOutputNodesToReadyQueue(node, curr_time);

  // Increment num_outputs_executed of the input nodes and maybe update memory.
  for (const auto& input_port : node_state.inputs) {
    auto* input = input_port.first;
    auto port = input_port.second;
    auto& input_state = node_map_[input];
    input_state.num_outputs_executed[port]++;
    if (input_state.num_outputs_executed[port] ==
            input_state.outputs[port].size() &&
        !IsPersistentNode(input)) {
      // All the outputs are executed; no reference to this output port of
      // input node.
      input_state.time_no_references[port] = curr_time;
      auto& input_device = device_[input_state.device_name];
      input_device.memory_usage -=
          CalculateOutputSize(input_state.output_properties, port) *
          node_state.execution_count;

      input_device.nodes_in_memory.erase(std::make_pair(input, port));
    }
  }

  if (!IsPersistentNode(node)) {
    // Now that output memory is added and used up nodes are deallocated,
    // check max memory usage.
    if (device.memory_usage > device.max_memory_usage) {
      device.max_memory_usage = device.memory_usage;

      if (track_mem_usage_snapshot_) {
        device.mem_usage_snapshot_at_peak = device.nodes_in_memory;
      }
    }
  }

  ready_nodes_->RemoveCurrNode();

  return !ready_nodes_->Empty();
}

Costs VirtualScheduler::Summary() const {
  // Overall statement about accuracy.
  VLOG(1) << graph_costs_.num_ops_total << " ops processed in total, with "
          << graph_costs_.num_ops_with_unknown_shapes
          << " having unknown shapes";

  // Print out the basic execution summary.
  VLOG(1) << "Expected execution time: " << graph_costs_.execution_time.count();
  VLOG(1) << "Expected compute time: " << graph_costs_.compute_time.count();
  VLOG(1) << "Expected memory time: " << graph_costs_.memory_time.count();
  VLOG(1) << "Expected intermediate memory time: "
          << graph_costs_.intermediate_memory_time.count();
  VLOG(1) << "Expected max memory: " << graph_costs_.max_memory;
  VLOG(1) << "Expected max per-op buffers: " << graph_costs_.max_per_op_buffers;
  VLOG(1) << "Expected max per-op streaming buffers: "
          << graph_costs_.max_per_op_streaming;

  VLOG(1) << "Per-op execution time / compute time / memory time"
          << " / intermediate memory time:";
  for (const auto& op_cost_pair : op_to_cost_) {
    const auto& op = op_cost_pair.first;
    const auto& cost = op_cost_pair.second.execution_time.count();
    const auto& compute_cost = op_cost_pair.second.compute_time.count();
    const auto& memory_cost = op_cost_pair.second.memory_time.count();
    const auto& intermediate_memory_cost =
        op_cost_pair.second.intermediate_memory_time.count();
    const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate;
    if (cost) {  // Skip printing out zero-cost ops.
      VLOG(1) << strings::Printf(
          " + %30s : %c %10lld / %10lld / %10lld / %10lld", op.c_str(),
          (is_op_cost_accurate ? ' ' : '~'), static_cast<int64>(cost),
          static_cast<int64>(compute_cost), static_cast<int64>(memory_cost),
          static_cast<int64>(intermediate_memory_cost));
    }
  }

  // Print the per-device summary.
  VLOG(1) << "Devices:";
  Costs critical_path_costs = Costs::ZeroCosts();
  std::vector<string> device_names;
  device_names.reserve(device_.size());
  for (auto& it : device_) {
    device_names.push_back(it.first);
  }
  std::sort(device_names.begin(), device_names.end());

  for (const auto& name : device_names) {
    const auto& state = device_.at(name);

    std::map<string, int64> op_to_memory;
    // First profile only persistent memory usage.
    int64 persistent_memory_usage = 0;
    std::set<string> persistent_ops;
    for (const auto& node_port : state.persistent_nodes) {
      const auto* node = node_port.first;
      const auto port = node_port.second;
      const auto output_size =
          CalculateOutputSize(node_map_.at(node).output_properties, port);
      persistent_memory_usage += output_size;
      op_to_memory[node->op()] += output_size;
      persistent_ops.insert(node->op());
    }
    int64 max_memory_usage = persistent_memory_usage + state.max_memory_usage;
    critical_path_costs.estimated_max_memory_per_device[name] =
        max_memory_usage;

    const Costs::NanoSeconds wall_time_ns = state.GetCurrTime();
    VLOG(1) << "Device = " << name
            << ", num_nodes = " << state.nodes_executed.size()
            << ", wall_time_ns = " << wall_time_ns.count() << ", memory usage: "
            << "persistent = "
            << strings::HumanReadableNumBytes(persistent_memory_usage)
            << ", peak = "
            << strings::HumanReadableNumBytes(state.max_memory_usage)
            << ", total = " << strings::HumanReadableNumBytes(max_memory_usage)
            << ", at the end: "
            << strings::HumanReadableNumBytes(state.memory_usage);

    // Overall statement about accuracy.
    VLOG(1) << state.device_costs.num_ops_total
            << " ops processed in total, with "
            << state.device_costs.num_ops_with_unknown_shapes
            << " having unknown shapes";

    VLOG(1) << "Per-op execution time / compute time / memory time"
            << " / intermediate memory time"
            << " (and memory usage at peak memory usage):";

    // Profile non-persistent op memory usage.
    for (const auto& node_port : state.mem_usage_snapshot_at_peak) {
      const auto* node = node_port.first;
      const auto port = node_port.second;
      op_to_memory[node->op()] +=
          CalculateOutputSize(node_map_.at(node).output_properties, port);
    }
    Costs::NanoSeconds total_compute_time_ns;
    bool is_total_cost_accurate = true;
    for (const auto& op_cost_pair : state.op_to_cost) {
      const auto& op = op_cost_pair.first;
      const auto& cost = op_cost_pair.second.execution_time.count();
      const auto& compute_cost = op_cost_pair.second.compute_time.count();
      const auto& memory_cost = op_cost_pair.second.memory_time.count();
      const auto& intermediate_memory_cost =
          op_cost_pair.second.intermediate_memory_time.count();
      total_compute_time_ns += op_cost_pair.second.execution_time;
      const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate;
      if (!is_op_cost_accurate) {
        is_total_cost_accurate = false;
      }

      int64 op_mem_usage = 0;
      auto it = op_to_memory.find(op);
      if (it != op_to_memory.end()) {
        op_mem_usage = it->second;
      }

      const float mem_usage_percent =
          max_memory_usage > 0
              ? Round2(100.0 * op_mem_usage / max_memory_usage)
              : 0.0;
      if (cost || mem_usage_percent > 1.0) {
        // Print out only non-zero cost ops or ops with > 1% memory usage.
        VLOG(1) << strings::Printf(
                       " + %30s : %c %10lld / %10lld / %10lld / %10lld",
                       op.c_str(), (is_op_cost_accurate ? ' ' : '~'),
                       static_cast<int64>(cost),
                       static_cast<int64>(compute_cost),
                       static_cast<int64>(memory_cost),
                       static_cast<int64>(intermediate_memory_cost))
                << " (" << strings::HumanReadableNumBytes(op_mem_usage) << " ["
                << mem_usage_percent << "%] "
                << (persistent_ops.count(op) > 0 ? ": persistent op)" : ")");
      }
    }

    int utilization = 0;
    if (wall_time_ns.count() > 0) {
      utilization = total_compute_time_ns.count() * 100 / wall_time_ns.count();
    }
    VLOG(1) << "Device = " << name << ", total_compute_time_ns = "
            << (is_total_cost_accurate ? "" : "~")
            << total_compute_time_ns.count()
            << ", utilization = " << utilization << "%";

    if (critical_path_costs.execution_time <= state.GetCurrTime()) {
      critical_path_costs = state.device_costs;
    }
  }

  if (VLOG_IS_ON(2)) {
    // Also log the op descriptions and their corresponding counts.
    VLOG(2) << "Node description, counts, cost:";
    for (const auto& item : op_counts_) {
      int cost;
      bool is_cost_accurate;
      std::tie(cost, is_cost_accurate) = op_costs_.at(item.first);
      VLOG(2) << "Node: " << item.first << ", Count: " << item.second
              << ", Individual Cost: " << (is_cost_accurate ? "" : "~") << cost
              << " us";
    }
  }

  VLOG(1) << "Critical path execution time: "
          << critical_path_costs.execution_time.count();
  return critical_path_costs;
}

Costs VirtualScheduler::Summary(RunMetadata* metadata) {
  if (metadata) GenerateRunMetadata(metadata);
  return Summary();
}

void VirtualScheduler::GenerateRunMetadata(RunMetadata* metadata) {
  // Fill RunMetadata's step_stats and partition_graphs fields.
  StepStats* stepstats = metadata->mutable_step_stats();
  for (const auto& device : device_) {
    GraphDef* device_partition_graph = metadata->add_partition_graphs();
    DeviceStepStats* device_stepstats = stepstats->add_dev_stats();
    device_stepstats->set_device(device.first);
    for (const auto& node_def : device.second.nodes_executed) {
      const NodeState& nodestate = node_map_.at(node_def);
      NodeExecStats* node_stats = device_stepstats->add_node_stats();
      uint64 total_output_size = 0;
      for (int slot = 0; slot < nodestate.output_properties.size(); slot++) {
        const auto& properties = nodestate.output_properties[slot];
        NodeOutput* no = node_stats->add_output();
        no->set_slot(slot);
        TensorDescription* tensor_descr = no->mutable_tensor_description();
        tensor_descr->set_dtype(properties.dtype());
        *tensor_descr->mutable_shape() = properties.shape();
        // Optional allocation description.
        const auto tensor_size =
            CalculateOutputSize(nodestate.output_properties, slot);
        total_output_size += tensor_size;
        tensor_descr->mutable_allocation_description()->set_requested_bytes(
            tensor_size);
        tensor_descr->mutable_allocation_description()->set_allocated_bytes(
            tensor_size);
      }
      node_stats->set_timeline_label(node_def->op());
      node_stats->set_node_name(node_def->name());
      node_stats->set_op_start_rel_micros(0);
      node_stats->set_all_start_micros(
          nodestate.time_scheduled.asMicroSeconds().count());
      node_stats->set_op_end_rel_micros(
          nodestate.time_finished.asMicroSeconds().count() -
          nodestate.time_scheduled.asMicroSeconds().count());
      node_stats->set_all_end_rel_micros(
          nodestate.time_finished.asMicroSeconds().count() -
          nodestate.time_scheduled.asMicroSeconds().count());
      auto* mem_stats = node_stats->mutable_memory_stats();
      // VirtualScheduler does not specify scratch pad memory usage.
      mem_stats->set_temp_memory_size(0);
      int64 persistent_memory_size = 0;
      if (IsPersistentNode(node_def)) {
        persistent_memory_size = total_output_size;
      }
      mem_stats->set_persistent_memory_size(persistent_memory_size);
      *device_partition_graph->add_node() = *node_def;
    }
  }
}

const std::unordered_map<string, int64> VirtualScheduler::GetPeakMemoryUsage()
    const {
  std::unordered_map<string, int64> result;
  for (const auto& device : device_) {
    const string& name = device.first;
    const DeviceState& state = device.second;
    result[name] = state.max_memory_usage;
  }
  return result;
}

}  // end namespace grappler
}  // end namespace tensorflow