1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
17
18 #include "tensorflow/core/framework/device_base.h"
19 #include "tensorflow/core/framework/op_def.pb.h"
20 #include "tensorflow/core/lib/core/errors.h"
21 #include "tensorflow/core/lib/gtl/map_util.h"
22 #include "tensorflow/core/util/ptr_util.h"
23
24 namespace tensorflow {
25 namespace grappler {
26 namespace graph_utils {
27 namespace {
28
// Op type used for constant nodes: it is the op set on nodes built by
// AddScalarConstNodeHelper and the op that GetScalarConstNodeValueHelper
// requires of the node it reads.
29 constexpr char kConstOpName[] = "Const";
// Op type that IsItemDerivedFromFunctionDef expects every resolvable fetch
// node to have.
30 constexpr char kRetValOp[] = "_Retval";
31
// Returns the zero-based positions, in iteration order, of all elements of
// `collection` for which `predicate(element)` is true. The result is empty
// when no element matches (including when `collection` is empty).
//
// `Collection` only needs to be usable in a range-based for loop; positions
// are counted by iteration, not by any index the container may expose.
template <typename Predicate, typename Collection>
std::vector<int> GetElementIndicesWithPredicate(const Predicate& predicate,
                                                const Collection& collection) {
  std::vector<int> indices;
  // The counter has the same type as the stored indices, so nothing is
  // implicitly converted from unsigned to int when it is appended.
  int position = 0;
  for (const auto& element : collection) {
    if (predicate(element)) {
      indices.push_back(position);
    }
    ++position;
  }
  return indices;
}
45
CreateNameIndex(const GraphDef & graph)46 std::vector<int> CreateNameIndex(const GraphDef& graph) {
47 std::map<string, int> names;
48 for (int i = 0; i < graph.node_size(); ++i) {
49 names[graph.node(i).name()] = i;
50 }
51 std::vector<int> index(graph.node_size());
52 int i = 0;
53 for (const auto& pair : names) {
54 index[i++] = pair.second;
55 }
56 return index;
57 }
58
CreateInputIndex(const NodeDef & node)59 std::vector<int> CreateInputIndex(const NodeDef& node) {
60 std::map<string, int> inputs;
61 for (int i = 0; i < node.input_size(); ++i) {
62 inputs[node.input(i)] = i;
63 }
64 std::vector<int> index(node.input_size());
65 int i = 0;
66 for (const auto& pair : inputs) {
67 index[i++] = pair.second;
68 }
69 return index;
70 }
71
AddScalarConstNodeHelper(DataType dtype,const std::function<void (TensorProto *)> & add_value,MutableGraphView * graph)72 NodeDef* AddScalarConstNodeHelper(
73 DataType dtype, const std::function<void(TensorProto*)>& add_value,
74 MutableGraphView* graph) {
75 NodeDef node;
76 node.set_op(kConstOpName);
77 SetUniqueGraphNodeName(kConstOpName, graph->graph(), &node);
78
79 (*node.mutable_attr())["dtype"].set_type(dtype);
80 std::unique_ptr<tensorflow::TensorProto> tensor =
81 tensorflow::MakeUnique<tensorflow::TensorProto>();
82 std::unique_ptr<tensorflow::TensorShapeProto> tensor_shape =
83 tensorflow::MakeUnique<tensorflow::TensorShapeProto>();
84 tensor->set_allocated_tensor_shape(tensor_shape.release());
85 tensor->set_dtype(dtype);
86 add_value(tensor.get());
87 (*node.mutable_attr())["value"].set_allocated_tensor(tensor.release());
88
89 return graph->AddNode(std::move(node));
90 }
91
92 } // namespace
93
AddScalarPlaceholder(DataType dtype,MutableGraphView * graph)94 NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph) {
95 NodeDef node;
96 node.set_op("Placeholder");
97 SetUniqueGraphNodeName(node.op(), graph->graph(), &node);
98 (*node.mutable_attr())["dtype"].set_type(dtype);
99 TensorShapeProto* shape = (*node.mutable_attr())["shape"].mutable_shape();
100 shape->set_unknown_rank(false);
101 return graph->AddNode(std::move(node));
102 }
103
AddNode(StringPiece name,StringPiece op,const std::vector<string> & inputs,const std::vector<std::pair<string,AttrValue>> & attributes,MutableGraphView * graph)104 NodeDef* AddNode(StringPiece name, StringPiece op,
105 const std::vector<string>& inputs,
106 const std::vector<std::pair<string, AttrValue>>& attributes,
107 MutableGraphView* graph) {
108 NodeDef node;
109 if (!name.empty()) {
110 node.set_name(string(name));
111 } else {
112 SetUniqueGraphNodeName(op, graph->graph(), &node);
113 }
114 node.set_op(string(op));
115 for (const string& input : inputs) {
116 node.add_input(input);
117 }
118 for (const auto& attr : attributes) {
119 (*node.mutable_attr())[attr.first] = attr.second;
120 }
121 return graph->AddNode(std::move(node));
122 }
123
124 template <>
AddScalarConstNode(bool v,MutableGraphView * graph)125 NodeDef* AddScalarConstNode(bool v, MutableGraphView* graph) {
126 return AddScalarConstNodeHelper(
127 DT_BOOL, [v](TensorProto* proto) { proto->add_bool_val(v); }, graph);
128 }
129
130 template <>
AddScalarConstNode(double v,MutableGraphView * graph)131 NodeDef* AddScalarConstNode(double v, MutableGraphView* graph) {
132 return AddScalarConstNodeHelper(
133 DT_DOUBLE, [v](TensorProto* proto) { proto->add_double_val(v); }, graph);
134 }
135
136 template <>
AddScalarConstNode(float v,MutableGraphView * graph)137 NodeDef* AddScalarConstNode(float v, MutableGraphView* graph) {
138 return AddScalarConstNodeHelper(
139 DT_FLOAT, [v](TensorProto* proto) { proto->add_float_val(v); }, graph);
140 }
141
142 template <>
AddScalarConstNode(int v,MutableGraphView * graph)143 NodeDef* AddScalarConstNode(int v, MutableGraphView* graph) {
144 return AddScalarConstNodeHelper(
145 DT_INT32, [v](TensorProto* proto) { proto->add_int_val(v); }, graph);
146 }
147
148 template <>
AddScalarConstNode(int64 v,MutableGraphView * graph)149 NodeDef* AddScalarConstNode(int64 v, MutableGraphView* graph) {
150 return AddScalarConstNodeHelper(
151 DT_INT64, [v](TensorProto* proto) { proto->add_int64_val(v); }, graph);
152 }
153
154 template <>
AddScalarConstNode(StringPiece v,MutableGraphView * graph)155 NodeDef* AddScalarConstNode(StringPiece v, MutableGraphView* graph) {
156 return AddScalarConstNodeHelper(
157 DT_STRING,
158 [v](TensorProto* proto) { proto->add_string_val(v.data(), v.size()); },
159 graph);
160 }
161
GetScalarConstNodeValueHelper(const NodeDef & node,DataType dtype,const std::function<void (const Tensor &)> & get_value)162 Status GetScalarConstNodeValueHelper(
163 const NodeDef& node, DataType dtype,
164 const std::function<void(const Tensor&)>& get_value) {
165 if (node.op() != kConstOpName)
166 return errors::InvalidArgument("Node ", node.name(),
167 " is not a Const node. Op: ", node.op());
168
169 Tensor tensor;
170 TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &tensor));
171 if (!TensorShapeUtils::IsScalar(tensor.shape())) {
172 return errors::InvalidArgument(
173 "Node ", node.name(),
174 " should be a scalar but has shape: ", tensor.shape());
175 }
176
177 if (tensor.dtype() != dtype) {
178 return errors::InvalidArgument(
179 "Node ", node.name(), " should have type ", DataTypeString(dtype),
180 " but has type: ", DataTypeString(tensor.dtype()));
181 }
182
183 get_value(tensor);
184
185 return Status::OK();
186 }
187
188 template <>
GetScalarConstNodeValue(const NodeDef & node,int64 * value)189 Status GetScalarConstNodeValue(const NodeDef& node, int64* value) {
190 return GetScalarConstNodeValueHelper(
191 node, DT_INT64,
192 [value](const Tensor& tensor) { *value = tensor.scalar<int64>()(); });
193 }
194
195 template <>
GetScalarConstNodeValue(const NodeDef & node,bool * value)196 Status GetScalarConstNodeValue(const NodeDef& node, bool* value) {
197 return GetScalarConstNodeValueHelper(
198 node, DT_BOOL,
199 [value](const Tensor& tensor) { *value = tensor.scalar<bool>()(); });
200 }
201
Compare(const GraphDef & g1,const GraphDef & g2)202 bool Compare(const GraphDef& g1, const GraphDef& g2) {
203 if (g1.node_size() != g2.node_size()) {
204 return false;
205 }
206 std::vector<int> name_index1 = CreateNameIndex(g1);
207 std::vector<int> name_index2 = CreateNameIndex(g2);
208 for (int i = 0; i < g1.node_size(); ++i) {
209 int idx1 = name_index1[i];
210 int idx2 = name_index2[i];
211 if (g1.node(idx1).op() != g2.node(idx2).op()) {
212 return false;
213 }
214 if (g1.node(idx1).name() != g2.node(idx2).name()) {
215 return false;
216 }
217 if (g1.node(idx1).input_size() != g2.node(idx2).input_size()) {
218 return false;
219 }
220 std::vector<int> input_index1 = CreateInputIndex(g1.node(idx1));
221 std::vector<int> input_index2 = CreateInputIndex(g2.node(idx2));
222 for (int j = 0; j < g1.node(idx1).input_size(); ++j) {
223 if (!IsSameInput(g1.node(idx1).input(input_index1[j]),
224 g2.node(idx2).input(input_index2[j]))) {
225 return false;
226 }
227 }
228 }
229 return true;
230 }
231
ContainsGraphFunctionWithName(StringPiece name,const FunctionDefLibrary & library)232 bool ContainsGraphFunctionWithName(StringPiece name,
233 const FunctionDefLibrary& library) {
234 return FindGraphFunctionWithName(name, library) != -1;
235 }
236
ContainsGraphNodeWithName(StringPiece name,const GraphDef & graph)237 bool ContainsGraphNodeWithName(StringPiece name, const GraphDef& graph) {
238 return FindGraphNodeWithName(name, graph) != -1;
239 }
240
ContainsNodeWithOp(StringPiece op,const GraphDef & graph)241 bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph) {
242 return FindGraphNodeWithOp(op, graph) != -1;
243 }
244
FindGraphFunctionWithName(StringPiece name,const FunctionDefLibrary & library)245 int FindGraphFunctionWithName(StringPiece name,
246 const FunctionDefLibrary& library) {
247 return GetFirstElementIndexWithPredicate(
248 [&name](const FunctionDef& function) {
249 return function.signature().name() == name;
250 },
251 library.function());
252 }
253
FindGraphNodeWithName(StringPiece name,const GraphDef & graph)254 int FindGraphNodeWithName(StringPiece name, const GraphDef& graph) {
255 return GetFirstElementIndexWithPredicate(
256 [&name](const NodeDef& node) { return node.name() == name; },
257 graph.node());
258 }
259
FindGraphNodeWithOp(StringPiece op,const GraphDef & graph)260 int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph) {
261 return GetFirstElementIndexWithPredicate(
262 [&op](const NodeDef& node) { return node.op() == op; }, graph.node());
263 }
264
FindAllGraphNodesWithOp(const string & op,const GraphDef & graph)265 std::vector<int> FindAllGraphNodesWithOp(const string& op,
266 const GraphDef& graph) {
267 return GetElementIndicesWithPredicate(
268 [&op](const NodeDef& node) { return node.op() == op; }, graph.node());
269 }
270
// Returns the node feeding the first input (port 0) of `node`, as reported by
// `graph.GetRegularFanin`, or nullptr if `node` has no inputs. The port is
// looked up in `graph` by the name of `node`.
NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph) {
  const bool has_inputs = node.input_size() != 0;
  if (!has_inputs) {
    return nullptr;
  }
  const MutableGraphView::InputPort first_port =
      graph.GetInputPort(node.name(), 0);
  return graph.GetRegularFanin(first_port).node;
}
276
// Returns the node feeding input `i` of `node`, as reported by
// `graph.GetRegularFanin`, or nullptr if `i` is not a valid input position
// (negative, or not less than `node.input_size()`). The port is looked up in
// `graph` by the name of `node`.
NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph,
                      int64 i) {
  // Reject negative positions explicitly instead of forwarding them as a
  // port id; the upper bound alone does not exclude them.
  if (i < 0 || i >= node.input_size()) return nullptr;
  MutableGraphView::InputPort input_port = graph.GetInputPort(node.name(), i);
  return graph.GetRegularFanin(input_port).node;
}
283
GetDatasetOutputTypesAttr(const NodeDef & node,DataTypeVector * output_types)284 Status GetDatasetOutputTypesAttr(const NodeDef& node,
285 DataTypeVector* output_types) {
286 // We don't name the output_types attr consistently, so should check for both.
287 for (const string& attr_name : {"output_types", "Toutput_types"}) {
288 if (node.attr().contains(attr_name)) {
289 return GetNodeAttr(node, attr_name, output_types);
290 }
291 }
292 return errors::InvalidArgument("Could not find output_types attr for node: ",
293 node.name(), " with op: ", node.op());
294 }
295
SetUniqueGraphNodeName(StringPiece prefix,GraphDef * graph,NodeDef * node)296 void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph,
297 NodeDef* node) {
298 string name = string(prefix);
299 int id = graph->node_size();
300 while (ContainsGraphNodeWithName(name, *graph)) {
301 if (name.rfind("_generated") != string::npos &&
302 (name.rfind("_generated") == (name.size() - strlen("_generated")))) {
303 name.insert(name.rfind("_generated"), strings::StrCat("/_", id));
304 } else {
305 name = strings::StrCat(prefix, "/_", id);
306 }
307 ++id;
308 }
309 node->set_name(std::move(name));
310 }
311
SetUniqueGraphFunctionName(StringPiece prefix,FunctionDefLibrary * library,FunctionDef * function)312 void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
313 FunctionDef* function) {
314 string name = string(prefix);
315 int id = library->function_size();
316 while (ContainsGraphFunctionWithName(name, *library)) {
317 name = strings::StrCat(prefix, "/_", id);
318 ++id;
319 }
320 function->mutable_signature()->set_name(std::move(name));
321 }
322
CopyAttribute(const string & attribute_name,const NodeDef & from,NodeDef * to_node)323 void CopyAttribute(const string& attribute_name, const NodeDef& from,
324 NodeDef* to_node) {
325 (*to_node->mutable_attr())[attribute_name] = from.attr().at(attribute_name);
326 }
327
ConcatAttributeList(const string & attribute_name,const NodeDef & first,const NodeDef & second,NodeDef * to_node)328 void ConcatAttributeList(const string& attribute_name, const NodeDef& first,
329 const NodeDef& second, NodeDef* to_node) {
330 CopyAttribute(attribute_name, first, to_node);
331 (*to_node->mutable_attr())
332 .at(attribute_name)
333 .mutable_list()
334 ->MergeFrom(second.attr().at(attribute_name).list());
335 }
336
EnsureNodeNamesUnique(Graph * g)337 Status EnsureNodeNamesUnique(Graph* g) {
338 // Modeled after Scope::Impl::GetUniqueName
339 std::unordered_map<string, int> name_map;
340
341 for (auto node : g->op_nodes()) {
342 const string& prefix = node->name();
343 if (auto entry = gtl::FindOrNull(name_map, prefix)) {
344 string unique_name;
345 do {
346 unique_name = strings::StrCat(prefix, "_", ++(*entry));
347 } while (name_map.find(unique_name) != name_map.end());
348 name_map.insert({unique_name, 0});
349 node->set_name(std::move(unique_name));
350 } else {
351 name_map.insert({node->name(), 0});
352 }
353 }
354
355 return Status::OK();
356 }
357
GetFetchNode(const MutableGraphView & graph,const GrapplerItem & item,NodeDef ** fetch_node)358 Status GetFetchNode(const MutableGraphView& graph, const GrapplerItem& item,
359 NodeDef** fetch_node) {
360 if (item.fetch.size() != 1) {
361 return errors::InvalidArgument(
362 "Expected only one fetch node but there were ", item.fetch.size(), ": ",
363 absl::StrJoin(item.fetch, ", "));
364 }
365
366 *fetch_node = graph.GetNode(item.fetch.at(0));
367
368 return Status::OK();
369 }
370
IsItemDerivedFromFunctionDef(const GrapplerItem & item,const MutableGraphView & graph_view)371 bool IsItemDerivedFromFunctionDef(const GrapplerItem& item,
372 const MutableGraphView& graph_view) {
373 for (const auto& fetch_name : item.fetch) {
374 auto fetch = graph_view.GetNode(fetch_name);
375 if (fetch != nullptr && fetch->op() != kRetValOp) {
376 // We found a fetch node which is not a `Retval` op.
377 return false;
378 }
379 }
380 // All fetch nodes are `Retval` ops (or we don't have any fetch nodes).
381 return true;
382 }
383
384 } // namespace graph_utils
385 } // namespace grappler
386 } // namespace tensorflow
387