• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
17 
18 #include "tensorflow/core/framework/device_base.h"
19 #include "tensorflow/core/framework/op_def.pb.h"
20 #include "tensorflow/core/lib/core/errors.h"
21 #include "tensorflow/core/lib/gtl/map_util.h"
22 #include "tensorflow/core/util/ptr_util.h"
23 
24 namespace tensorflow {
25 namespace grappler {
26 namespace graph_utils {
27 namespace {
28 
// Op string that this file sets on, and requires of, constant nodes.
29 constexpr char kConstOpName[] = "Const";
// Op string that fetch nodes are compared against in
// IsItemDerivedFromFunctionDef.
30 constexpr char kRetValOp[] = "_Retval";
31 
// Returns the 0-based positions, in iteration order, of every element of
// `collection` for which `predicate` returns true. The result is empty when no
// element matches or when `collection` is empty.
template <typename Predicate, typename Collection>
std::vector<int> GetElementIndicesWithPredicate(const Predicate& predicate,
                                                const Collection& collection) {
  std::vector<int> indices;
  // Signed counter matching the element type of `indices`, so that push_back
  // does not perform an implicit unsigned-to-int conversion.
  int idx = 0;
  for (auto&& element : collection) {
    if (predicate(element)) {
      indices.push_back(idx);
    }
    ++idx;
  }
  return indices;
}
45 
CreateNameIndex(const GraphDef & graph)46 std::vector<int> CreateNameIndex(const GraphDef& graph) {
47   std::map<string, int> names;
48   for (int i = 0; i < graph.node_size(); ++i) {
49     names[graph.node(i).name()] = i;
50   }
51   std::vector<int> index(graph.node_size());
52   int i = 0;
53   for (const auto& pair : names) {
54     index[i++] = pair.second;
55   }
56   return index;
57 }
58 
CreateInputIndex(const NodeDef & node)59 std::vector<int> CreateInputIndex(const NodeDef& node) {
60   std::map<string, int> inputs;
61   for (int i = 0; i < node.input_size(); ++i) {
62     inputs[node.input(i)] = i;
63   }
64   std::vector<int> index(node.input_size());
65   int i = 0;
66   for (const auto& pair : inputs) {
67     index[i++] = pair.second;
68   }
69   return index;
70 }
71 
AddScalarConstNodeHelper(DataType dtype,const std::function<void (TensorProto *)> & add_value,MutableGraphView * graph)72 NodeDef* AddScalarConstNodeHelper(
73     DataType dtype, const std::function<void(TensorProto*)>& add_value,
74     MutableGraphView* graph) {
75   NodeDef node;
76   node.set_op(kConstOpName);
77   SetUniqueGraphNodeName(kConstOpName, graph->graph(), &node);
78 
79   (*node.mutable_attr())["dtype"].set_type(dtype);
80   std::unique_ptr<tensorflow::TensorProto> tensor =
81       tensorflow::MakeUnique<tensorflow::TensorProto>();
82   std::unique_ptr<tensorflow::TensorShapeProto> tensor_shape =
83       tensorflow::MakeUnique<tensorflow::TensorShapeProto>();
84   tensor->set_allocated_tensor_shape(tensor_shape.release());
85   tensor->set_dtype(dtype);
86   add_value(tensor.get());
87   (*node.mutable_attr())["value"].set_allocated_tensor(tensor.release());
88 
89   return graph->AddNode(std::move(node));
90 }
91 
92 }  // namespace
93 
AddScalarPlaceholder(DataType dtype,MutableGraphView * graph)94 NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph) {
95   NodeDef node;
96   node.set_op("Placeholder");
97   SetUniqueGraphNodeName(node.op(), graph->graph(), &node);
98   (*node.mutable_attr())["dtype"].set_type(dtype);
99   TensorShapeProto* shape = (*node.mutable_attr())["shape"].mutable_shape();
100   shape->set_unknown_rank(false);
101   return graph->AddNode(std::move(node));
102 }
103 
AddNode(StringPiece name,StringPiece op,const std::vector<string> & inputs,const std::vector<std::pair<string,AttrValue>> & attributes,MutableGraphView * graph)104 NodeDef* AddNode(StringPiece name, StringPiece op,
105                  const std::vector<string>& inputs,
106                  const std::vector<std::pair<string, AttrValue>>& attributes,
107                  MutableGraphView* graph) {
108   NodeDef node;
109   if (!name.empty()) {
110     node.set_name(string(name));
111   } else {
112     SetUniqueGraphNodeName(op, graph->graph(), &node);
113   }
114   node.set_op(string(op));
115   for (const string& input : inputs) {
116     node.add_input(input);
117   }
118   for (const auto& attr : attributes) {
119     (*node.mutable_attr())[attr.first] = attr.second;
120   }
121   return graph->AddNode(std::move(node));
122 }
123 
124 template <>
AddScalarConstNode(bool v,MutableGraphView * graph)125 NodeDef* AddScalarConstNode(bool v, MutableGraphView* graph) {
126   return AddScalarConstNodeHelper(
127       DT_BOOL, [v](TensorProto* proto) { proto->add_bool_val(v); }, graph);
128 }
129 
130 template <>
AddScalarConstNode(double v,MutableGraphView * graph)131 NodeDef* AddScalarConstNode(double v, MutableGraphView* graph) {
132   return AddScalarConstNodeHelper(
133       DT_DOUBLE, [v](TensorProto* proto) { proto->add_double_val(v); }, graph);
134 }
135 
136 template <>
AddScalarConstNode(float v,MutableGraphView * graph)137 NodeDef* AddScalarConstNode(float v, MutableGraphView* graph) {
138   return AddScalarConstNodeHelper(
139       DT_FLOAT, [v](TensorProto* proto) { proto->add_float_val(v); }, graph);
140 }
141 
142 template <>
AddScalarConstNode(int v,MutableGraphView * graph)143 NodeDef* AddScalarConstNode(int v, MutableGraphView* graph) {
144   return AddScalarConstNodeHelper(
145       DT_INT32, [v](TensorProto* proto) { proto->add_int_val(v); }, graph);
146 }
147 
148 template <>
AddScalarConstNode(int64 v,MutableGraphView * graph)149 NodeDef* AddScalarConstNode(int64 v, MutableGraphView* graph) {
150   return AddScalarConstNodeHelper(
151       DT_INT64, [v](TensorProto* proto) { proto->add_int64_val(v); }, graph);
152 }
153 
154 template <>
AddScalarConstNode(StringPiece v,MutableGraphView * graph)155 NodeDef* AddScalarConstNode(StringPiece v, MutableGraphView* graph) {
156   return AddScalarConstNodeHelper(
157       DT_STRING,
158       [v](TensorProto* proto) { proto->add_string_val(v.data(), v.size()); },
159       graph);
160 }
161 
GetScalarConstNodeValueHelper(const NodeDef & node,DataType dtype,const std::function<void (const Tensor &)> & get_value)162 Status GetScalarConstNodeValueHelper(
163     const NodeDef& node, DataType dtype,
164     const std::function<void(const Tensor&)>& get_value) {
165   if (node.op() != kConstOpName)
166     return errors::InvalidArgument("Node ", node.name(),
167                                    " is not a Const node. Op: ", node.op());
168 
169   Tensor tensor;
170   TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &tensor));
171   if (!TensorShapeUtils::IsScalar(tensor.shape())) {
172     return errors::InvalidArgument(
173         "Node ", node.name(),
174         " should be a scalar but has shape: ", tensor.shape());
175   }
176 
177   if (tensor.dtype() != dtype) {
178     return errors::InvalidArgument(
179         "Node ", node.name(), " should have type ", DataTypeString(dtype),
180         " but has type: ", DataTypeString(tensor.dtype()));
181   }
182 
183   get_value(tensor);
184 
185   return Status::OK();
186 }
187 
188 template <>
GetScalarConstNodeValue(const NodeDef & node,int64 * value)189 Status GetScalarConstNodeValue(const NodeDef& node, int64* value) {
190   return GetScalarConstNodeValueHelper(
191       node, DT_INT64,
192       [value](const Tensor& tensor) { *value = tensor.scalar<int64>()(); });
193 }
194 
195 template <>
GetScalarConstNodeValue(const NodeDef & node,bool * value)196 Status GetScalarConstNodeValue(const NodeDef& node, bool* value) {
197   return GetScalarConstNodeValueHelper(
198       node, DT_BOOL,
199       [value](const Tensor& tensor) { *value = tensor.scalar<bool>()(); });
200 }
201 
Compare(const GraphDef & g1,const GraphDef & g2)202 bool Compare(const GraphDef& g1, const GraphDef& g2) {
203   if (g1.node_size() != g2.node_size()) {
204     return false;
205   }
206   std::vector<int> name_index1 = CreateNameIndex(g1);
207   std::vector<int> name_index2 = CreateNameIndex(g2);
208   for (int i = 0; i < g1.node_size(); ++i) {
209     int idx1 = name_index1[i];
210     int idx2 = name_index2[i];
211     if (g1.node(idx1).op() != g2.node(idx2).op()) {
212       return false;
213     }
214     if (g1.node(idx1).name() != g2.node(idx2).name()) {
215       return false;
216     }
217     if (g1.node(idx1).input_size() != g2.node(idx2).input_size()) {
218       return false;
219     }
220     std::vector<int> input_index1 = CreateInputIndex(g1.node(idx1));
221     std::vector<int> input_index2 = CreateInputIndex(g2.node(idx2));
222     for (int j = 0; j < g1.node(idx1).input_size(); ++j) {
223       if (!IsSameInput(g1.node(idx1).input(input_index1[j]),
224                        g2.node(idx2).input(input_index2[j]))) {
225         return false;
226       }
227     }
228   }
229   return true;
230 }
231 
ContainsGraphFunctionWithName(StringPiece name,const FunctionDefLibrary & library)232 bool ContainsGraphFunctionWithName(StringPiece name,
233                                    const FunctionDefLibrary& library) {
234   return FindGraphFunctionWithName(name, library) != -1;
235 }
236 
ContainsGraphNodeWithName(StringPiece name,const GraphDef & graph)237 bool ContainsGraphNodeWithName(StringPiece name, const GraphDef& graph) {
238   return FindGraphNodeWithName(name, graph) != -1;
239 }
240 
ContainsNodeWithOp(StringPiece op,const GraphDef & graph)241 bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph) {
242   return FindGraphNodeWithOp(op, graph) != -1;
243 }
244 
FindGraphFunctionWithName(StringPiece name,const FunctionDefLibrary & library)245 int FindGraphFunctionWithName(StringPiece name,
246                               const FunctionDefLibrary& library) {
247   return GetFirstElementIndexWithPredicate(
248       [&name](const FunctionDef& function) {
249         return function.signature().name() == name;
250       },
251       library.function());
252 }
253 
FindGraphNodeWithName(StringPiece name,const GraphDef & graph)254 int FindGraphNodeWithName(StringPiece name, const GraphDef& graph) {
255   return GetFirstElementIndexWithPredicate(
256       [&name](const NodeDef& node) { return node.name() == name; },
257       graph.node());
258 }
259 
FindGraphNodeWithOp(StringPiece op,const GraphDef & graph)260 int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph) {
261   return GetFirstElementIndexWithPredicate(
262       [&op](const NodeDef& node) { return node.op() == op; }, graph.node());
263 }
264 
FindAllGraphNodesWithOp(const string & op,const GraphDef & graph)265 std::vector<int> FindAllGraphNodesWithOp(const string& op,
266                                          const GraphDef& graph) {
267   return GetElementIndicesWithPredicate(
268       [&op](const NodeDef& node) { return node.op() == op; }, graph.node());
269 }
270 
// Returns the node of the regular fanin connected to input port 0 of `node`
// in `graph`, or nullptr when `node` has no inputs.
NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph) {
  const bool has_inputs = node.input_size() != 0;
  if (!has_inputs) return nullptr;

  const MutableGraphView::InputPort first_port =
      graph.GetInputPort(node.name(), 0);
  return graph.GetRegularFanin(first_port).node;
}
276 
// Returns the node of the regular fanin connected to input port `i` of `node`
// in `graph`, or nullptr when `i` is not a valid input position of `node`
// (negative, or not less than `node.input_size()`).
NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph,
                      int64 i) {
  // The upper-bound check alone would let a negative `i` through and use it
  // as a port id, so reject negative positions explicitly.
  if (i < 0 || node.input_size() <= i) return nullptr;
  MutableGraphView::InputPort input_port = graph.GetInputPort(node.name(), i);
  return graph.GetRegularFanin(input_port).node;
}
283 
GetDatasetOutputTypesAttr(const NodeDef & node,DataTypeVector * output_types)284 Status GetDatasetOutputTypesAttr(const NodeDef& node,
285                                  DataTypeVector* output_types) {
286   // We don't name the output_types attr consistently, so should check for both.
287   for (const string& attr_name : {"output_types", "Toutput_types"}) {
288     if (node.attr().contains(attr_name)) {
289       return GetNodeAttr(node, attr_name, output_types);
290     }
291   }
292   return errors::InvalidArgument("Could not find output_types attr for node: ",
293                                  node.name(), " with op: ", node.op());
294 }
295 
SetUniqueGraphNodeName(StringPiece prefix,GraphDef * graph,NodeDef * node)296 void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph,
297                             NodeDef* node) {
298   string name = string(prefix);
299   int id = graph->node_size();
300   while (ContainsGraphNodeWithName(name, *graph)) {
301     if (name.rfind("_generated") != string::npos &&
302         (name.rfind("_generated") == (name.size() - strlen("_generated")))) {
303       name.insert(name.rfind("_generated"), strings::StrCat("/_", id));
304     } else {
305       name = strings::StrCat(prefix, "/_", id);
306     }
307     ++id;
308   }
309   node->set_name(std::move(name));
310 }
311 
SetUniqueGraphFunctionName(StringPiece prefix,FunctionDefLibrary * library,FunctionDef * function)312 void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
313                                 FunctionDef* function) {
314   string name = string(prefix);
315   int id = library->function_size();
316   while (ContainsGraphFunctionWithName(name, *library)) {
317     name = strings::StrCat(prefix, "/_", id);
318     ++id;
319   }
320   function->mutable_signature()->set_name(std::move(name));
321 }
322 
CopyAttribute(const string & attribute_name,const NodeDef & from,NodeDef * to_node)323 void CopyAttribute(const string& attribute_name, const NodeDef& from,
324                    NodeDef* to_node) {
325   (*to_node->mutable_attr())[attribute_name] = from.attr().at(attribute_name);
326 }
327 
ConcatAttributeList(const string & attribute_name,const NodeDef & first,const NodeDef & second,NodeDef * to_node)328 void ConcatAttributeList(const string& attribute_name, const NodeDef& first,
329                          const NodeDef& second, NodeDef* to_node) {
330   CopyAttribute(attribute_name, first, to_node);
331   (*to_node->mutable_attr())
332       .at(attribute_name)
333       .mutable_list()
334       ->MergeFrom(second.attr().at(attribute_name).list());
335 }
336 
EnsureNodeNamesUnique(Graph * g)337 Status EnsureNodeNamesUnique(Graph* g) {
338   // Modeled after Scope::Impl::GetUniqueName
339   std::unordered_map<string, int> name_map;
340 
341   for (auto node : g->op_nodes()) {
342     const string& prefix = node->name();
343     if (auto entry = gtl::FindOrNull(name_map, prefix)) {
344       string unique_name;
345       do {
346         unique_name = strings::StrCat(prefix, "_", ++(*entry));
347       } while (name_map.find(unique_name) != name_map.end());
348       name_map.insert({unique_name, 0});
349       node->set_name(std::move(unique_name));
350     } else {
351       name_map.insert({node->name(), 0});
352     }
353   }
354 
355   return Status::OK();
356 }
357 
GetFetchNode(const MutableGraphView & graph,const GrapplerItem & item,NodeDef ** fetch_node)358 Status GetFetchNode(const MutableGraphView& graph, const GrapplerItem& item,
359                     NodeDef** fetch_node) {
360   if (item.fetch.size() != 1) {
361     return errors::InvalidArgument(
362         "Expected only one fetch node but there were ", item.fetch.size(), ": ",
363         absl::StrJoin(item.fetch, ", "));
364   }
365 
366   *fetch_node = graph.GetNode(item.fetch.at(0));
367 
368   return Status::OK();
369 }
370 
IsItemDerivedFromFunctionDef(const GrapplerItem & item,const MutableGraphView & graph_view)371 bool IsItemDerivedFromFunctionDef(const GrapplerItem& item,
372                                   const MutableGraphView& graph_view) {
373   for (const auto& fetch_name : item.fetch) {
374     auto fetch = graph_view.GetNode(fetch_name);
375     if (fetch != nullptr && fetch->op() != kRetValOp) {
376       // We found a fetch node which is not a `Retval` op.
377       return false;
378     }
379   }
380   // All fetch nodes are `Retval` ops (or we don't have any fetch nodes).
381   return true;
382 }
383 
384 }  // namespace graph_utils
385 }  // namespace grappler
386 }  // namespace tensorflow
387