/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/static_schedule.h"

#include <deque>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
#include "tensorflow/core/grappler/costs/virtual_placer.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"

namespace tensorflow {
namespace grappler {

PredictExecutionTime(const GraphProperties & properties,const OpLevelCostEstimator & estimator,const VirtualPlacer & placer,const NodeDef & node)32 static Costs::NanoSeconds PredictExecutionTime(
33     const GraphProperties& properties, const OpLevelCostEstimator& estimator,
34     const VirtualPlacer& placer, const NodeDef& node) {
35   OpContext op_context;
36   op_context.op_info.set_op(node.op());
37   *op_context.op_info.mutable_attr() = node.attr();
38 
39   std::vector<OpInfo::TensorProperties> inputs =
40       properties.GetInputProperties(node.name());
41   for (auto& input : inputs) {
42     op_context.op_info.add_inputs()->Swap(&input);
43   }
44 
45   std::vector<OpInfo::TensorProperties> outputs =
46       properties.GetOutputProperties(node.name());
47   for (auto& output : outputs) {
48     op_context.op_info.add_outputs()->Swap(&output);
49   }
50 
51   DeviceProperties device = placer.get_device(node);
52   op_context.op_info.mutable_device()->Swap(&device);
53 
54   Costs::NanoSeconds estimate =
55       estimator.PredictCosts(op_context).execution_time;
56 
57   // Make sure our estimates are at least one nanosecond per node.
58   return std::max(estimate, Costs::NanoSeconds(1));
59 }
EstimateEarliestExecutionTimes(const GrapplerItem & item,const Cluster * cluster,std::unordered_map<const NodeDef *,Costs::NanoSeconds> * completion_times)61 Status EstimateEarliestExecutionTimes(
62     const GrapplerItem& item, const Cluster* cluster,
63     std::unordered_map<const NodeDef*, Costs::NanoSeconds>* completion_times) {
64   std::unordered_map<string, const NodeDef*> name_map;
65   std::unordered_map<const NodeDef*, int> pending_inputs;
66   std::deque<const NodeDef*> ready_nodes;
67   for (const NodeDef& node : item.graph.node()) {
68     name_map[node.name()] = &node;
69     if (node.input_size() == 0) {
70       ready_nodes.push_back(&node);
71       (*completion_times)[&node] = 0;
72     } else if (IsMerge(node)) {
73       // Merge nodes are processed as soon as one of the input becomes
74       // available.
75       pending_inputs[&node] = 1;
76     } else {
77       pending_inputs[&node] = node.input_size();
78     }
79   }
80 
81   std::unordered_map<const NodeDef*, std::vector<const NodeDef*>> fanouts;
82   for (const NodeDef& node : item.graph.node()) {
83     for (const string& input : node.input()) {
84       string node_name = NodeName(input);
85       auto it = name_map.find(node_name);
86       if (it == name_map.end()) {
87         return errors::InvalidArgument(
88             strings::StrCat("Unknown input node ", input));
89       }
90       const NodeDef* fanin = it->second;
91       fanouts[fanin].push_back(&node);
92     }
93   }
94   name_map.clear();
95 
96   GraphProperties properties(item);
97   TF_RETURN_IF_ERROR(
98       properties.InferStatically(/*assume_valid_feeds=*/true,
99                                  /*aggressive_shape_inference=*/false,
100                                  /*include_tensor_values=*/false));
101   OpLevelCostEstimator estimator;
102   VirtualPlacer placer(cluster->GetDevices());
103 
104   while (!ready_nodes.empty()) {
105     const NodeDef* node = ready_nodes.front();
106     ready_nodes.pop_front();
107 
108     Costs::NanoSeconds execution_time =
109         PredictExecutionTime(properties, estimator, placer, *node);
110     Costs::NanoSeconds completion_time =
111         execution_time + (*completion_times)[node];
112     (*completion_times)[node] = completion_time;
113 
114     for (const NodeDef* fanout : fanouts[node]) {
115       int pending = pending_inputs[fanout];
116       if (pending == 0) {
117         // Already processed. Avoid going through loops more than once.
118         continue;
119       } else if (pending == 1) {
120         ready_nodes.push_back(fanout);
121       }
122       pending_inputs[fanout]--;
123 
124       Costs::NanoSeconds ready_time =
125           std::max(completion_time, (*completion_times)[fanout]);
126       (*completion_times)[fanout] = ready_time;
127     }
128   }
129 
130   return OkStatus();
131 }
EstimateRequiredTimes(const GrapplerItem & item,const Cluster * cluster,const std::unordered_map<const NodeDef *,Costs::NanoSeconds> & execution_times,std::unordered_map<const NodeDef *,Costs::NanoSeconds> * required_times)133 Status EstimateRequiredTimes(
134     const GrapplerItem& item, const Cluster* cluster,
135     const std::unordered_map<const NodeDef*, Costs::NanoSeconds>&
136         execution_times,
137     std::unordered_map<const NodeDef*, Costs::NanoSeconds>* required_times) {
138   std::unordered_map<string, const NodeDef*> name_map;
139   for (const NodeDef& node : item.graph.node()) {
140     name_map[node.name()] = &node;
141     (*required_times)[&node] = Costs::NanoSeconds::max();
142   }
143 
144   std::unordered_map<const NodeDef*, int> pending_fanouts;
145   for (const NodeDef& node : item.graph.node()) {
146     for (const string& input : node.input()) {
147       string node_name = NodeName(input);
148       auto it = name_map.find(node_name);
149       if (it == name_map.end()) {
150         return errors::InvalidArgument(
151             strings::StrCat("Unknown input node ", input));
152       }
153       const NodeDef* fanin = it->second;
154       pending_fanouts[fanin] += 1;
155     }
156   }
157   std::deque<const NodeDef*> ready_nodes;
158   for (const NodeDef& node : item.graph.node()) {
159     if (pending_fanouts[&node] == 0) {
160       auto it = execution_times.find(&node);
161       if (it != execution_times.end()) {
162         (*required_times)[&node] = it->second;
163       }
164       ready_nodes.push_back(&node);
165     }
166   }
167   GraphProperties properties(item);
168   TF_RETURN_IF_ERROR(
169       properties.InferStatically(/*assume_valid_feeds=*/true,
170                                  /*aggressive_shape_inference=*/false,
171                                  /*include_tensor_values=*/false));
172   OpLevelCostEstimator estimator;
173   VirtualPlacer placer(cluster->GetDevices());
174 
175   while (!ready_nodes.empty()) {
176     const NodeDef* node = ready_nodes.front();
177     ready_nodes.pop_front();
178 
179     Costs::NanoSeconds execution_time =
180         PredictExecutionTime(properties, estimator, placer, *node);
181     Costs::NanoSeconds required_time = (*required_times)[node] - execution_time;
182 
183     for (const string& fanin_name : node->input()) {
184       const NodeDef* fanin = name_map[NodeName(fanin_name)];
185       (*required_times)[fanin] =
186           std::min((*required_times)[fanin], required_time);
187 
188       int pending = pending_fanouts[fanin];
189       if (pending == 0) {
190         // Already processed. Avoid going through loops more than once.
191         continue;
192       } else if (pending == 1) {
193         ready_nodes.push_back(fanin);
194       }
195       pending_fanouts[fanin]--;
196     }
197   }
198 
199   return OkStatus();
200 }

}  // end namespace grappler
}  // end namespace tensorflow