• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/static_schedule.h"
17 #include <deque>
18 #include "tensorflow/core/framework/attr_value.pb.h"
19 #include "tensorflow/core/grappler/costs/graph_properties.h"
20 #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
21 #include "tensorflow/core/grappler/costs/virtual_placer.h"
22 #include "tensorflow/core/grappler/op_types.h"
23 #include "tensorflow/core/grappler/utils.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/strings/strcat.h"
26 
27 namespace tensorflow {
28 namespace grappler {
29 
PredictExecutionTime(const GraphProperties & properties,const OpLevelCostEstimator & estimator,const VirtualPlacer & placer,const NodeDef & node)30 static Costs::NanoSeconds PredictExecutionTime(
31     const GraphProperties& properties, const OpLevelCostEstimator& estimator,
32     const VirtualPlacer& placer, const NodeDef& node) {
33   OpContext op_context;
34   op_context.op_info.set_op(node.op());
35   *op_context.op_info.mutable_attr() = node.attr();
36 
37   std::vector<OpInfo::TensorProperties> inputs =
38       properties.GetInputProperties(node.name());
39   for (auto& input : inputs) {
40     op_context.op_info.add_inputs()->Swap(&input);
41   }
42 
43   std::vector<OpInfo::TensorProperties> outputs =
44       properties.GetOutputProperties(node.name());
45   for (auto& output : outputs) {
46     op_context.op_info.add_outputs()->Swap(&output);
47   }
48 
49   DeviceProperties device = placer.get_device(node);
50   op_context.op_info.mutable_device()->Swap(&device);
51 
52   Costs::NanoSeconds estimate =
53       estimator.PredictCosts(op_context).execution_time;
54 
55   // Make sure our estimates are at least one nanosecond per node.
56   return std::max(estimate, Costs::NanoSeconds(1));
57 }
58 
EstimateEarliestExecutionTimes(const GrapplerItem & item,const Cluster * cluster,std::unordered_map<const NodeDef *,Costs::NanoSeconds> * completion_times)59 Status EstimateEarliestExecutionTimes(
60     const GrapplerItem& item, const Cluster* cluster,
61     std::unordered_map<const NodeDef*, Costs::NanoSeconds>* completion_times) {
62   std::unordered_map<string, const NodeDef*> name_map;
63   std::unordered_map<const NodeDef*, int> pending_inputs;
64   std::deque<const NodeDef*> ready_nodes;
65   for (const NodeDef& node : item.graph.node()) {
66     name_map[node.name()] = &node;
67     if (node.input_size() == 0) {
68       ready_nodes.push_back(&node);
69       (*completion_times)[&node] = 0;
70     } else if (IsMerge(node)) {
71       // Merge nodes are processed as soon as one of the input becomes
72       // available.
73       pending_inputs[&node] = 1;
74     } else {
75       pending_inputs[&node] = node.input_size();
76     }
77   }
78 
79   std::unordered_map<const NodeDef*, std::vector<const NodeDef*>> fanouts;
80   for (const NodeDef& node : item.graph.node()) {
81     for (const string& input : node.input()) {
82       string node_name = NodeName(input);
83       auto it = name_map.find(node_name);
84       if (it == name_map.end()) {
85         return errors::InvalidArgument(
86             strings::StrCat("Unknown input node ", input));
87       }
88       const NodeDef* fanin = it->second;
89       fanouts[fanin].push_back(&node);
90     }
91   }
92   name_map.clear();
93 
94   GraphProperties properties(item);
95   TF_RETURN_IF_ERROR(properties.InferStatically(true));
96   OpLevelCostEstimator estimator;
97   VirtualPlacer placer(cluster);
98 
99   while (!ready_nodes.empty()) {
100     const NodeDef* node = ready_nodes.front();
101     ready_nodes.pop_front();
102 
103     Costs::NanoSeconds execution_time =
104         PredictExecutionTime(properties, estimator, placer, *node);
105     Costs::NanoSeconds completion_time =
106         execution_time + (*completion_times)[node];
107     (*completion_times)[node] = completion_time;
108 
109     for (const NodeDef* fanout : fanouts[node]) {
110       int pending = pending_inputs[fanout];
111       if (pending == 0) {
112         // Already processed. Avoid going through loops more than once.
113         continue;
114       } else if (pending == 1) {
115         ready_nodes.push_back(fanout);
116       }
117       pending_inputs[fanout]--;
118 
119       Costs::NanoSeconds ready_time =
120           std::max(completion_time, (*completion_times)[fanout]);
121       (*completion_times)[fanout] = ready_time;
122     }
123   }
124 
125   return Status::OK();
126 }
127 
EstimateRequiredTimes(const GrapplerItem & item,const Cluster * cluster,const std::unordered_map<const NodeDef *,Costs::NanoSeconds> & execution_times,std::unordered_map<const NodeDef *,Costs::NanoSeconds> * required_times)128 Status EstimateRequiredTimes(
129     const GrapplerItem& item, const Cluster* cluster,
130     const std::unordered_map<const NodeDef*, Costs::NanoSeconds>&
131         execution_times,
132     std::unordered_map<const NodeDef*, Costs::NanoSeconds>* required_times) {
133   std::unordered_map<string, const NodeDef*> name_map;
134   for (const NodeDef& node : item.graph.node()) {
135     name_map[node.name()] = &node;
136     (*required_times)[&node] = Costs::NanoSeconds::max();
137   }
138 
139   std::unordered_map<const NodeDef*, int> pending_fanouts;
140   for (const NodeDef& node : item.graph.node()) {
141     for (const string& input : node.input()) {
142       string node_name = NodeName(input);
143       auto it = name_map.find(node_name);
144       if (it == name_map.end()) {
145         return errors::InvalidArgument(
146             strings::StrCat("Unknown input node ", input));
147       }
148       const NodeDef* fanin = it->second;
149       pending_fanouts[fanin] += 1;
150     }
151   }
152   std::deque<const NodeDef*> ready_nodes;
153   for (const NodeDef& node : item.graph.node()) {
154     if (pending_fanouts[&node] == 0) {
155       auto it = execution_times.find(&node);
156       if (it != execution_times.end()) {
157         (*required_times)[&node] = it->second;
158       }
159       ready_nodes.push_back(&node);
160     }
161   }
162   GraphProperties properties(item);
163   TF_RETURN_IF_ERROR(properties.InferStatically(true));
164   OpLevelCostEstimator estimator;
165   VirtualPlacer placer(cluster);
166 
167   while (!ready_nodes.empty()) {
168     const NodeDef* node = ready_nodes.front();
169     ready_nodes.pop_front();
170 
171     Costs::NanoSeconds execution_time =
172         PredictExecutionTime(properties, estimator, placer, *node);
173     Costs::NanoSeconds required_time = (*required_times)[node] - execution_time;
174 
175     for (const string& fanin_name : node->input()) {
176       const NodeDef* fanin = name_map[NodeName(fanin_name)];
177       (*required_times)[fanin] =
178           std::min((*required_times)[fanin], required_time);
179 
180       int pending = pending_fanouts[fanin];
181       if (pending == 0) {
182         // Already processed. Avoid going through loops more than once.
183         continue;
184       } else if (pending == 1) {
185         ready_nodes.push_back(fanin);
186       }
187       pending_fanouts[fanin]--;
188     }
189   }
190 
191   return Status::OK();
192 }
193 
194 }  // end namespace grappler
195 }  // end namespace tensorflow
196