1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/static_schedule.h"
17 #include <deque>
18 #include "tensorflow/core/framework/attr_value.pb.h"
19 #include "tensorflow/core/grappler/costs/graph_properties.h"
20 #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
21 #include "tensorflow/core/grappler/costs/virtual_placer.h"
22 #include "tensorflow/core/grappler/op_types.h"
23 #include "tensorflow/core/grappler/utils.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/strings/strcat.h"
26
27 namespace tensorflow {
28 namespace grappler {
29
PredictExecutionTime(const GraphProperties & properties,const OpLevelCostEstimator & estimator,const VirtualPlacer & placer,const NodeDef & node)30 static Costs::NanoSeconds PredictExecutionTime(
31 const GraphProperties& properties, const OpLevelCostEstimator& estimator,
32 const VirtualPlacer& placer, const NodeDef& node) {
33 OpContext op_context;
34 op_context.op_info.set_op(node.op());
35 *op_context.op_info.mutable_attr() = node.attr();
36
37 std::vector<OpInfo::TensorProperties> inputs =
38 properties.GetInputProperties(node.name());
39 for (auto& input : inputs) {
40 op_context.op_info.add_inputs()->Swap(&input);
41 }
42
43 std::vector<OpInfo::TensorProperties> outputs =
44 properties.GetOutputProperties(node.name());
45 for (auto& output : outputs) {
46 op_context.op_info.add_outputs()->Swap(&output);
47 }
48
49 DeviceProperties device = placer.get_device(node);
50 op_context.op_info.mutable_device()->Swap(&device);
51
52 Costs::NanoSeconds estimate =
53 estimator.PredictCosts(op_context).execution_time;
54
55 // Make sure our estimates are at least one nanosecond per node.
56 return std::max(estimate, Costs::NanoSeconds(1));
57 }
58
EstimateEarliestExecutionTimes(const GrapplerItem & item,const Cluster * cluster,std::unordered_map<const NodeDef *,Costs::NanoSeconds> * completion_times)59 Status EstimateEarliestExecutionTimes(
60 const GrapplerItem& item, const Cluster* cluster,
61 std::unordered_map<const NodeDef*, Costs::NanoSeconds>* completion_times) {
62 std::unordered_map<string, const NodeDef*> name_map;
63 std::unordered_map<const NodeDef*, int> pending_inputs;
64 std::deque<const NodeDef*> ready_nodes;
65 for (const NodeDef& node : item.graph.node()) {
66 name_map[node.name()] = &node;
67 if (node.input_size() == 0) {
68 ready_nodes.push_back(&node);
69 (*completion_times)[&node] = 0;
70 } else if (IsMerge(node)) {
71 // Merge nodes are processed as soon as one of the input becomes
72 // available.
73 pending_inputs[&node] = 1;
74 } else {
75 pending_inputs[&node] = node.input_size();
76 }
77 }
78
79 std::unordered_map<const NodeDef*, std::vector<const NodeDef*>> fanouts;
80 for (const NodeDef& node : item.graph.node()) {
81 for (const string& input : node.input()) {
82 string node_name = NodeName(input);
83 auto it = name_map.find(node_name);
84 if (it == name_map.end()) {
85 return errors::InvalidArgument(
86 strings::StrCat("Unknown input node ", input));
87 }
88 const NodeDef* fanin = it->second;
89 fanouts[fanin].push_back(&node);
90 }
91 }
92 name_map.clear();
93
94 GraphProperties properties(item);
95 TF_RETURN_IF_ERROR(properties.InferStatically(true));
96 OpLevelCostEstimator estimator;
97 VirtualPlacer placer(cluster);
98
99 while (!ready_nodes.empty()) {
100 const NodeDef* node = ready_nodes.front();
101 ready_nodes.pop_front();
102
103 Costs::NanoSeconds execution_time =
104 PredictExecutionTime(properties, estimator, placer, *node);
105 Costs::NanoSeconds completion_time =
106 execution_time + (*completion_times)[node];
107 (*completion_times)[node] = completion_time;
108
109 for (const NodeDef* fanout : fanouts[node]) {
110 int pending = pending_inputs[fanout];
111 if (pending == 0) {
112 // Already processed. Avoid going through loops more than once.
113 continue;
114 } else if (pending == 1) {
115 ready_nodes.push_back(fanout);
116 }
117 pending_inputs[fanout]--;
118
119 Costs::NanoSeconds ready_time =
120 std::max(completion_time, (*completion_times)[fanout]);
121 (*completion_times)[fanout] = ready_time;
122 }
123 }
124
125 return Status::OK();
126 }
127
EstimateRequiredTimes(const GrapplerItem & item,const Cluster * cluster,const std::unordered_map<const NodeDef *,Costs::NanoSeconds> & execution_times,std::unordered_map<const NodeDef *,Costs::NanoSeconds> * required_times)128 Status EstimateRequiredTimes(
129 const GrapplerItem& item, const Cluster* cluster,
130 const std::unordered_map<const NodeDef*, Costs::NanoSeconds>&
131 execution_times,
132 std::unordered_map<const NodeDef*, Costs::NanoSeconds>* required_times) {
133 std::unordered_map<string, const NodeDef*> name_map;
134 for (const NodeDef& node : item.graph.node()) {
135 name_map[node.name()] = &node;
136 (*required_times)[&node] = Costs::NanoSeconds::max();
137 }
138
139 std::unordered_map<const NodeDef*, int> pending_fanouts;
140 for (const NodeDef& node : item.graph.node()) {
141 for (const string& input : node.input()) {
142 string node_name = NodeName(input);
143 auto it = name_map.find(node_name);
144 if (it == name_map.end()) {
145 return errors::InvalidArgument(
146 strings::StrCat("Unknown input node ", input));
147 }
148 const NodeDef* fanin = it->second;
149 pending_fanouts[fanin] += 1;
150 }
151 }
152 std::deque<const NodeDef*> ready_nodes;
153 for (const NodeDef& node : item.graph.node()) {
154 if (pending_fanouts[&node] == 0) {
155 auto it = execution_times.find(&node);
156 if (it != execution_times.end()) {
157 (*required_times)[&node] = it->second;
158 }
159 ready_nodes.push_back(&node);
160 }
161 }
162 GraphProperties properties(item);
163 TF_RETURN_IF_ERROR(properties.InferStatically(true));
164 OpLevelCostEstimator estimator;
165 VirtualPlacer placer(cluster);
166
167 while (!ready_nodes.empty()) {
168 const NodeDef* node = ready_nodes.front();
169 ready_nodes.pop_front();
170
171 Costs::NanoSeconds execution_time =
172 PredictExecutionTime(properties, estimator, placer, *node);
173 Costs::NanoSeconds required_time = (*required_times)[node] - execution_time;
174
175 for (const string& fanin_name : node->input()) {
176 const NodeDef* fanin = name_map[NodeName(fanin_name)];
177 (*required_times)[fanin] =
178 std::min((*required_times)[fanin], required_time);
179
180 int pending = pending_fanouts[fanin];
181 if (pending == 0) {
182 // Already processed. Avoid going through loops more than once.
183 continue;
184 } else if (pending == 1) {
185 ready_nodes.push_back(fanin);
186 }
187 pending_fanouts[fanin]--;
188 }
189 }
190
191 return Status::OK();
192 }
193
194 } // end namespace grappler
195 } // end namespace tensorflow
196