• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
17 
18 #include "tensorflow/core/framework/device_base.h"
19 #include "tensorflow/core/framework/op_def.pb.h"
20 #include "tensorflow/core/lib/gtl/map_util.h"
21 #include "tensorflow/core/util/ptr_util.h"
22 
23 namespace tensorflow {
24 namespace grappler {
25 namespace graph_utils {
26 namespace {
27 
// Op type used for every scalar constant node created in this file.
constexpr char kConstOpName[] = "Const";
29 
// Returns, in increasing order, the positions of all elements of `collection`
// for which `predicate` returns true. Positions count from 0 in the
// collection's iteration order.
template <typename Predicate, typename Collection>
std::vector<int> GetElementIndicesWithPredicate(const Predicate& predicate,
                                                const Collection& collection) {
  std::vector<int> matches;
  int position = 0;
  for (const auto& element : collection) {
    if (predicate(element)) matches.push_back(position);
    ++position;
  }
  return matches;
}
43 
CreateNameIndex(const GraphDef & graph)44 std::vector<int> CreateNameIndex(const GraphDef& graph) {
45   std::map<string, int> names;
46   for (int i = 0; i < graph.node_size(); ++i) {
47     names[graph.node(i).name()] = i;
48   }
49   std::vector<int> index(graph.node_size());
50   int i = 0;
51   for (const auto& pair : names) {
52     index[i++] = pair.second;
53   }
54   return index;
55 }
56 
CreateInputIndex(const NodeDef & node)57 std::vector<int> CreateInputIndex(const NodeDef& node) {
58   std::map<string, int> inputs;
59   for (int i = 0; i < node.input_size(); ++i) {
60     inputs[node.input(i)] = i;
61   }
62   std::vector<int> index(node.input_size());
63   int i = 0;
64   for (const auto& pair : inputs) {
65     index[i++] = pair.second;
66   }
67   return index;
68 }
69 
AddScalarConstNodeHelper(DataType dtype,const std::function<void (TensorProto *)> & add_value,MutableGraphView * graph)70 NodeDef* AddScalarConstNodeHelper(
71     DataType dtype, const std::function<void(TensorProto*)>& add_value,
72     MutableGraphView* graph) {
73   NodeDef node;
74   node.set_op(kConstOpName);
75   SetUniqueGraphNodeName(kConstOpName, graph->graph(), &node);
76 
77   (*node.mutable_attr())["dtype"].set_type(dtype);
78   std::unique_ptr<tensorflow::TensorProto> tensor =
79       tensorflow::MakeUnique<tensorflow::TensorProto>();
80   std::unique_ptr<tensorflow::TensorShapeProto> tensor_shape =
81       tensorflow::MakeUnique<tensorflow::TensorShapeProto>();
82   tensor->set_allocated_tensor_shape(tensor_shape.release());
83   tensor->set_dtype(dtype);
84   add_value(tensor.get());
85   (*node.mutable_attr())["value"].set_allocated_tensor(tensor.release());
86 
87   return graph->AddNode(std::move(node));
88 }
89 
90 }  // namespace
91 
AddScalarPlaceholder(DataType dtype,MutableGraphView * graph)92 NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph) {
93   NodeDef node;
94   node.set_op("Placeholder");
95   SetUniqueGraphNodeName(node.op(), graph->graph(), &node);
96   (*node.mutable_attr())["dtype"].set_type(dtype);
97   TensorShapeProto* shape = (*node.mutable_attr())["shape"].mutable_shape();
98   shape->set_unknown_rank(false);
99   return graph->AddNode(std::move(node));
100 }
101 
AddNode(StringPiece name,StringPiece op,const std::vector<string> & inputs,const std::vector<std::pair<string,AttrValue>> & attributes,MutableGraphView * graph)102 NodeDef* AddNode(StringPiece name, StringPiece op,
103                  const std::vector<string>& inputs,
104                  const std::vector<std::pair<string, AttrValue>>& attributes,
105                  MutableGraphView* graph) {
106   NodeDef node;
107   if (!name.empty()) {
108     node.set_name(string(name));
109   } else {
110     SetUniqueGraphNodeName(op, graph->graph(), &node);
111   }
112   node.set_op(string(op));
113   for (const string& input : inputs) {
114     node.add_input(input);
115   }
116   for (auto attr : attributes) {
117     (*node.mutable_attr())[attr.first] = attr.second;
118   }
119   return graph->AddNode(std::move(node));
120 }
121 
122 template <>
AddScalarConstNode(bool v,MutableGraphView * graph)123 NodeDef* AddScalarConstNode(bool v, MutableGraphView* graph) {
124   return AddScalarConstNodeHelper(
125       DT_BOOL, [v](TensorProto* proto) { proto->add_bool_val(v); }, graph);
126 }
127 
128 template <>
AddScalarConstNode(double v,MutableGraphView * graph)129 NodeDef* AddScalarConstNode(double v, MutableGraphView* graph) {
130   return AddScalarConstNodeHelper(
131       DT_DOUBLE, [v](TensorProto* proto) { proto->add_double_val(v); }, graph);
132 }
133 
134 template <>
AddScalarConstNode(float v,MutableGraphView * graph)135 NodeDef* AddScalarConstNode(float v, MutableGraphView* graph) {
136   return AddScalarConstNodeHelper(
137       DT_FLOAT, [v](TensorProto* proto) { proto->add_float_val(v); }, graph);
138 }
139 
140 template <>
AddScalarConstNode(int v,MutableGraphView * graph)141 NodeDef* AddScalarConstNode(int v, MutableGraphView* graph) {
142   return AddScalarConstNodeHelper(
143       DT_INT32, [v](TensorProto* proto) { proto->add_int_val(v); }, graph);
144 }
145 
146 template <>
AddScalarConstNode(int64 v,MutableGraphView * graph)147 NodeDef* AddScalarConstNode(int64 v, MutableGraphView* graph) {
148   return AddScalarConstNodeHelper(
149       DT_INT64, [v](TensorProto* proto) { proto->add_int64_val(v); }, graph);
150 }
151 
152 template <>
AddScalarConstNode(StringPiece v,MutableGraphView * graph)153 NodeDef* AddScalarConstNode(StringPiece v, MutableGraphView* graph) {
154   return AddScalarConstNodeHelper(
155       DT_STRING,
156       [v](TensorProto* proto) { proto->add_string_val(v.data(), v.size()); },
157       graph);
158 }
159 
Compare(const GraphDef & g1,const GraphDef & g2)160 bool Compare(const GraphDef& g1, const GraphDef& g2) {
161   if (g1.node_size() != g2.node_size()) {
162     return false;
163   }
164   std::vector<int> name_index1 = CreateNameIndex(g1);
165   std::vector<int> name_index2 = CreateNameIndex(g2);
166   for (int i = 0; i < g1.node_size(); ++i) {
167     int idx1 = name_index1[i];
168     int idx2 = name_index2[i];
169     if (g1.node(idx1).op() != g2.node(idx2).op()) {
170       return false;
171     }
172     if (g1.node(idx1).name() != g2.node(idx2).name()) {
173       return false;
174     }
175     if (g1.node(idx1).input_size() != g2.node(idx2).input_size()) {
176       return false;
177     }
178     std::vector<int> input_index1 = CreateInputIndex(g1.node(idx1));
179     std::vector<int> input_index2 = CreateInputIndex(g2.node(idx2));
180     for (int j = 0; j < g1.node(idx1).input_size(); ++j) {
181       if (!IsSameInput(g1.node(idx1).input(input_index1[j]),
182                        g2.node(idx2).input(input_index2[j]))) {
183         return false;
184       }
185     }
186   }
187   return true;
188 }
189 
ContainsGraphFunctionWithName(StringPiece name,const FunctionDefLibrary & library)190 bool ContainsGraphFunctionWithName(StringPiece name,
191                                    const FunctionDefLibrary& library) {
192   return FindGraphFunctionWithName(name, library) != -1;
193 }
194 
ContainsGraphNodeWithName(StringPiece name,const GraphDef & graph)195 bool ContainsGraphNodeWithName(StringPiece name, const GraphDef& graph) {
196   return FindGraphNodeWithName(name, graph) != -1;
197 }
198 
ContainsNodeWithOp(StringPiece op,const GraphDef & graph)199 bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph) {
200   return FindGraphNodeWithOp(op, graph) != -1;
201 }
202 
FindGraphFunctionWithName(StringPiece name,const FunctionDefLibrary & library)203 int FindGraphFunctionWithName(StringPiece name,
204                               const FunctionDefLibrary& library) {
205   return GetFirstElementIndexWithPredicate(
206       [&name](const FunctionDef& function) {
207         return function.signature().name() == name;
208       },
209       library.function());
210 }
211 
FindGraphNodeWithName(StringPiece name,const GraphDef & graph)212 int FindGraphNodeWithName(StringPiece name, const GraphDef& graph) {
213   return GetFirstElementIndexWithPredicate(
214       [&name](const NodeDef& node) { return node.name() == name; },
215       graph.node());
216 }
217 
FindGraphNodeWithOp(StringPiece op,const GraphDef & graph)218 int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph) {
219   return GetFirstElementIndexWithPredicate(
220       [&op](const NodeDef& node) { return node.op() == op; }, graph.node());
221 }
222 
FindAllGraphNodesWithOp(const string & op,const GraphDef & graph)223 std::vector<int> FindAllGraphNodesWithOp(const string& op,
224                                          const GraphDef& graph) {
225   return GetElementIndicesWithPredicate(
226       [&op](const NodeDef& node) { return node.op() == op; }, graph.node());
227 }
228 
// Returns the node feeding input 0 of `node`, as resolved by
// `graph.GetRegularFanin`, or nullptr when `node` has no inputs.
NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph) {
  if (node.input_size() > 0) {
    MutableGraphView::InputPort first_input =
        graph.GetInputPort(node.name(), 0);
    return graph.GetRegularFanin(first_input).node;
  }
  return nullptr;
}
234 
// Returns the node feeding input `i` of `node`, as resolved by
// `graph.GetRegularFanin`, or nullptr when `i` is not a valid position in
// `node.input()`.
NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph,
                      int64 i) {
  // Negative positions are not valid indices into node.input() and must not
  // be forwarded as a port id. The upper-bound check also guarantees that
  // narrowing `i` to int below is lossless.
  if (i < 0 || i >= node.input_size()) return nullptr;
  MutableGraphView::InputPort input_port =
      graph.GetInputPort(node.name(), static_cast<int>(i));
  return graph.GetRegularFanin(input_port).node;
}
241 
SetUniqueGraphNodeName(StringPiece prefix,GraphDef * graph,NodeDef * node)242 void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph,
243                             NodeDef* node) {
244   string name = string(prefix);
245   int id = graph->node_size();
246   while (ContainsGraphNodeWithName(name, *graph)) {
247     if (name.rfind("_generated") != string::npos &&
248         (name.rfind("_generated") == (name.size() - strlen("_generated")))) {
249       name.insert(name.rfind("_generated"), strings::StrCat("/_", id));
250     } else {
251       name = strings::StrCat(prefix, "/_", id);
252     }
253     ++id;
254   }
255   node->set_name(std::move(name));
256 }
257 
SetUniqueGraphFunctionName(StringPiece prefix,FunctionDefLibrary * library,FunctionDef * function)258 void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
259                                 FunctionDef* function) {
260   string name = string(prefix);
261   int id = library->function_size();
262   while (ContainsGraphFunctionWithName(name, *library)) {
263     name = strings::StrCat(prefix, "/_", id);
264     ++id;
265   }
266   function->mutable_signature()->set_name(std::move(name));
267 }
268 
CopyAttribute(const string & attribute_name,const NodeDef & from,NodeDef * to_node)269 void CopyAttribute(const string& attribute_name, const NodeDef& from,
270                    NodeDef* to_node) {
271   (*to_node->mutable_attr())[attribute_name] = from.attr().at(attribute_name);
272 }
273 
ConcatAttributeList(const string & attribute_name,const NodeDef & first,const NodeDef & second,NodeDef * to_node)274 void ConcatAttributeList(const string& attribute_name, const NodeDef& first,
275                          const NodeDef& second, NodeDef* to_node) {
276   CopyAttribute(attribute_name, first, to_node);
277   (*to_node->mutable_attr())
278       .at(attribute_name)
279       .mutable_list()
280       ->MergeFrom(second.attr().at(attribute_name).list());
281 }
282 
EnsureNodeNamesUnique(Graph * g)283 Status EnsureNodeNamesUnique(Graph* g) {
284   // Modeled after Scope::Impl::GetUniqueName
285   std::unordered_map<string, int> name_map;
286 
287   for (auto node : g->op_nodes()) {
288     const string& prefix = node->name();
289     if (auto entry = gtl::FindOrNull(name_map, prefix)) {
290       string unique_name;
291       do {
292         unique_name = strings::StrCat(prefix, "_", ++(*entry));
293       } while (name_map.find(unique_name) != name_map.end());
294       name_map.insert({unique_name, 0});
295       node->set_name(std::move(unique_name));
296     } else {
297       name_map.insert({node->name(), 0});
298     }
299   }
300 
301   return Status::OK();
302 }
303 
304 // Tries to find a Sink node in the graph. A sink node is defined as a node
305 // that has at least one input and no outputs. If there are multiple of these,
306 // this might return any one of them. This is useful to identify the final
307 // Dataset op in the graph but in some cases there might be multiple Identity
308 // ops added to the end and this would return the last Identity op in that case.
309 
FindSinkNode(const GraphDef & graph_def,NodeDef * sink_node)310 Status FindSinkNode(const GraphDef& graph_def, NodeDef* sink_node) {
311   absl::flat_hash_map<string, int> all_node_names;
312   absl::flat_hash_map<string, int> node_input_map;
313   for (int i = 0; i < graph_def.node_size(); ++i) {
314     all_node_names.insert_or_assign(graph_def.node(i).name(), i);
315     node_input_map.insert_or_assign(graph_def.node(i).name(), 0);
316   }
317   // Counts how many graph nodes for each input name. Candidate sink
318   // nodes are ones which are inputs into zero nodes.
319   for (const NodeDef& node : graph_def.node()) {
320     for (const string& input_name : node.input()) {
321       node_input_map[input_name]++;
322     }
323   }
324   for (const auto& it : node_input_map) {
325     if (it.second == 0) {
326       const NodeDef& sink_graph_node = graph_def.node(all_node_names[it.first]);
327       if (sink_graph_node.input_size() == 0) {
328         continue;
329       }
330       *sink_node = sink_graph_node;
331       return Status::OK();
332     }
333   }
334   return errors::InvalidArgument("Failed to find a sink node");
335 }
336 
337 }  // namespace graph_utils
338 }  // namespace grappler
339 }  // namespace tensorflow
340