1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/grappler/optimizers/static_schedule.h"

#include <algorithm>
#include <deque>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
#include "tensorflow/core/grappler/costs/virtual_placer.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
28
29 namespace tensorflow {
30 namespace grappler {
31
PredictExecutionTime(const GraphProperties & properties,const OpLevelCostEstimator & estimator,const VirtualPlacer & placer,const NodeDef & node)32 static Costs::NanoSeconds PredictExecutionTime(
33 const GraphProperties& properties, const OpLevelCostEstimator& estimator,
34 const VirtualPlacer& placer, const NodeDef& node) {
35 OpContext op_context;
36 op_context.op_info.set_op(node.op());
37 *op_context.op_info.mutable_attr() = node.attr();
38
39 std::vector<OpInfo::TensorProperties> inputs =
40 properties.GetInputProperties(node.name());
41 for (auto& input : inputs) {
42 op_context.op_info.add_inputs()->Swap(&input);
43 }
44
45 std::vector<OpInfo::TensorProperties> outputs =
46 properties.GetOutputProperties(node.name());
47 for (auto& output : outputs) {
48 op_context.op_info.add_outputs()->Swap(&output);
49 }
50
51 DeviceProperties device = placer.get_device(node);
52 op_context.op_info.mutable_device()->Swap(&device);
53
54 Costs::NanoSeconds estimate =
55 estimator.PredictCosts(op_context).execution_time;
56
57 // Make sure our estimates are at least one nanosecond per node.
58 return std::max(estimate, Costs::NanoSeconds(1));
59 }
60
EstimateEarliestExecutionTimes(const GrapplerItem & item,const Cluster * cluster,std::unordered_map<const NodeDef *,Costs::NanoSeconds> * completion_times)61 Status EstimateEarliestExecutionTimes(
62 const GrapplerItem& item, const Cluster* cluster,
63 std::unordered_map<const NodeDef*, Costs::NanoSeconds>* completion_times) {
64 std::unordered_map<string, const NodeDef*> name_map;
65 std::unordered_map<const NodeDef*, int> pending_inputs;
66 std::deque<const NodeDef*> ready_nodes;
67 for (const NodeDef& node : item.graph.node()) {
68 name_map[node.name()] = &node;
69 if (node.input_size() == 0) {
70 ready_nodes.push_back(&node);
71 (*completion_times)[&node] = 0;
72 } else if (IsMerge(node)) {
73 // Merge nodes are processed as soon as one of the input becomes
74 // available.
75 pending_inputs[&node] = 1;
76 } else {
77 pending_inputs[&node] = node.input_size();
78 }
79 }
80
81 std::unordered_map<const NodeDef*, std::vector<const NodeDef*>> fanouts;
82 for (const NodeDef& node : item.graph.node()) {
83 for (const string& input : node.input()) {
84 string node_name = NodeName(input);
85 auto it = name_map.find(node_name);
86 if (it == name_map.end()) {
87 return errors::InvalidArgument(
88 strings::StrCat("Unknown input node ", input));
89 }
90 const NodeDef* fanin = it->second;
91 fanouts[fanin].push_back(&node);
92 }
93 }
94 name_map.clear();
95
96 GraphProperties properties(item);
97 TF_RETURN_IF_ERROR(
98 properties.InferStatically(/*assume_valid_feeds=*/true,
99 /*aggressive_shape_inference=*/false,
100 /*include_tensor_values=*/false));
101 OpLevelCostEstimator estimator;
102 VirtualPlacer placer(cluster->GetDevices());
103
104 while (!ready_nodes.empty()) {
105 const NodeDef* node = ready_nodes.front();
106 ready_nodes.pop_front();
107
108 Costs::NanoSeconds execution_time =
109 PredictExecutionTime(properties, estimator, placer, *node);
110 Costs::NanoSeconds completion_time =
111 execution_time + (*completion_times)[node];
112 (*completion_times)[node] = completion_time;
113
114 for (const NodeDef* fanout : fanouts[node]) {
115 int pending = pending_inputs[fanout];
116 if (pending == 0) {
117 // Already processed. Avoid going through loops more than once.
118 continue;
119 } else if (pending == 1) {
120 ready_nodes.push_back(fanout);
121 }
122 pending_inputs[fanout]--;
123
124 Costs::NanoSeconds ready_time =
125 std::max(completion_time, (*completion_times)[fanout]);
126 (*completion_times)[fanout] = ready_time;
127 }
128 }
129
130 return OkStatus();
131 }
132
EstimateRequiredTimes(const GrapplerItem & item,const Cluster * cluster,const std::unordered_map<const NodeDef *,Costs::NanoSeconds> & execution_times,std::unordered_map<const NodeDef *,Costs::NanoSeconds> * required_times)133 Status EstimateRequiredTimes(
134 const GrapplerItem& item, const Cluster* cluster,
135 const std::unordered_map<const NodeDef*, Costs::NanoSeconds>&
136 execution_times,
137 std::unordered_map<const NodeDef*, Costs::NanoSeconds>* required_times) {
138 std::unordered_map<string, const NodeDef*> name_map;
139 for (const NodeDef& node : item.graph.node()) {
140 name_map[node.name()] = &node;
141 (*required_times)[&node] = Costs::NanoSeconds::max();
142 }
143
144 std::unordered_map<const NodeDef*, int> pending_fanouts;
145 for (const NodeDef& node : item.graph.node()) {
146 for (const string& input : node.input()) {
147 string node_name = NodeName(input);
148 auto it = name_map.find(node_name);
149 if (it == name_map.end()) {
150 return errors::InvalidArgument(
151 strings::StrCat("Unknown input node ", input));
152 }
153 const NodeDef* fanin = it->second;
154 pending_fanouts[fanin] += 1;
155 }
156 }
157 std::deque<const NodeDef*> ready_nodes;
158 for (const NodeDef& node : item.graph.node()) {
159 if (pending_fanouts[&node] == 0) {
160 auto it = execution_times.find(&node);
161 if (it != execution_times.end()) {
162 (*required_times)[&node] = it->second;
163 }
164 ready_nodes.push_back(&node);
165 }
166 }
167 GraphProperties properties(item);
168 TF_RETURN_IF_ERROR(
169 properties.InferStatically(/*assume_valid_feeds=*/true,
170 /*aggressive_shape_inference=*/false,
171 /*include_tensor_values=*/false));
172 OpLevelCostEstimator estimator;
173 VirtualPlacer placer(cluster->GetDevices());
174
175 while (!ready_nodes.empty()) {
176 const NodeDef* node = ready_nodes.front();
177 ready_nodes.pop_front();
178
179 Costs::NanoSeconds execution_time =
180 PredictExecutionTime(properties, estimator, placer, *node);
181 Costs::NanoSeconds required_time = (*required_times)[node] - execution_time;
182
183 for (const string& fanin_name : node->input()) {
184 const NodeDef* fanin = name_map[NodeName(fanin_name)];
185 (*required_times)[fanin] =
186 std::min((*required_times)[fanin], required_time);
187
188 int pending = pending_fanouts[fanin];
189 if (pending == 0) {
190 // Already processed. Avoid going through loops more than once.
191 continue;
192 } else if (pending == 1) {
193 ready_nodes.push_back(fanin);
194 }
195 pending_fanouts[fanin]--;
196 }
197 }
198
199 return OkStatus();
200 }
201
202 } // end namespace grappler
203 } // end namespace tensorflow
204