1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"

#include <algorithm>
#include <numeric>

#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/util/ptr_util.h"
22
23 namespace tensorflow {
24 namespace grappler {
25 namespace graph_utils {
26 namespace {
27
// Op type (and default name prefix) of the scalar constant nodes created below.
constexpr char kConstOpName[]{"Const"};
29
// Returns, in ascending order, the positions of every element of `collection`
// for which `predicate` returns true. Returns an empty vector if none match.
template <typename Predicate, typename Collection>
std::vector<int> GetElementIndicesWithPredicate(const Predicate& predicate,
                                                const Collection& collection) {
  std::vector<int> matches;
  int position = 0;
  for (const auto& item : collection) {
    if (predicate(item)) matches.push_back(position);
    ++position;
  }
  return matches;
}
43
CreateNameIndex(const GraphDef & graph)44 std::vector<int> CreateNameIndex(const GraphDef& graph) {
45 std::map<string, int> names;
46 for (int i = 0; i < graph.node_size(); ++i) {
47 names[graph.node(i).name()] = i;
48 }
49 std::vector<int> index(graph.node_size());
50 int i = 0;
51 for (const auto& pair : names) {
52 index[i++] = pair.second;
53 }
54 return index;
55 }
56
CreateInputIndex(const NodeDef & node)57 std::vector<int> CreateInputIndex(const NodeDef& node) {
58 std::map<string, int> inputs;
59 for (int i = 0; i < node.input_size(); ++i) {
60 inputs[node.input(i)] = i;
61 }
62 std::vector<int> index(node.input_size());
63 int i = 0;
64 for (const auto& pair : inputs) {
65 index[i++] = pair.second;
66 }
67 return index;
68 }
69
AddScalarConstNodeHelper(DataType dtype,const std::function<void (TensorProto *)> & add_value,MutableGraphView * graph)70 NodeDef* AddScalarConstNodeHelper(
71 DataType dtype, const std::function<void(TensorProto*)>& add_value,
72 MutableGraphView* graph) {
73 NodeDef node;
74 node.set_op(kConstOpName);
75 SetUniqueGraphNodeName(kConstOpName, graph->graph(), &node);
76
77 (*node.mutable_attr())["dtype"].set_type(dtype);
78 std::unique_ptr<tensorflow::TensorProto> tensor =
79 tensorflow::MakeUnique<tensorflow::TensorProto>();
80 std::unique_ptr<tensorflow::TensorShapeProto> tensor_shape =
81 tensorflow::MakeUnique<tensorflow::TensorShapeProto>();
82 tensor->set_allocated_tensor_shape(tensor_shape.release());
83 tensor->set_dtype(dtype);
84 add_value(tensor.get());
85 (*node.mutable_attr())["value"].set_allocated_tensor(tensor.release());
86
87 return graph->AddNode(std::move(node));
88 }
89
90 } // namespace
91
AddScalarPlaceholder(DataType dtype,MutableGraphView * graph)92 NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph) {
93 NodeDef node;
94 node.set_op("Placeholder");
95 SetUniqueGraphNodeName(node.op(), graph->graph(), &node);
96 (*node.mutable_attr())["dtype"].set_type(dtype);
97 TensorShapeProto* shape = (*node.mutable_attr())["shape"].mutable_shape();
98 shape->set_unknown_rank(false);
99 return graph->AddNode(std::move(node));
100 }
101
AddNode(StringPiece name,StringPiece op,const std::vector<string> & inputs,const std::vector<std::pair<string,AttrValue>> & attributes,MutableGraphView * graph)102 NodeDef* AddNode(StringPiece name, StringPiece op,
103 const std::vector<string>& inputs,
104 const std::vector<std::pair<string, AttrValue>>& attributes,
105 MutableGraphView* graph) {
106 NodeDef node;
107 if (!name.empty()) {
108 node.set_name(string(name));
109 } else {
110 SetUniqueGraphNodeName(op, graph->graph(), &node);
111 }
112 node.set_op(string(op));
113 for (const string& input : inputs) {
114 node.add_input(input);
115 }
116 for (auto attr : attributes) {
117 (*node.mutable_attr())[attr.first] = attr.second;
118 }
119 return graph->AddNode(std::move(node));
120 }
121
122 template <>
AddScalarConstNode(bool v,MutableGraphView * graph)123 NodeDef* AddScalarConstNode(bool v, MutableGraphView* graph) {
124 return AddScalarConstNodeHelper(
125 DT_BOOL, [v](TensorProto* proto) { proto->add_bool_val(v); }, graph);
126 }
127
128 template <>
AddScalarConstNode(double v,MutableGraphView * graph)129 NodeDef* AddScalarConstNode(double v, MutableGraphView* graph) {
130 return AddScalarConstNodeHelper(
131 DT_DOUBLE, [v](TensorProto* proto) { proto->add_double_val(v); }, graph);
132 }
133
134 template <>
AddScalarConstNode(float v,MutableGraphView * graph)135 NodeDef* AddScalarConstNode(float v, MutableGraphView* graph) {
136 return AddScalarConstNodeHelper(
137 DT_FLOAT, [v](TensorProto* proto) { proto->add_float_val(v); }, graph);
138 }
139
140 template <>
AddScalarConstNode(int v,MutableGraphView * graph)141 NodeDef* AddScalarConstNode(int v, MutableGraphView* graph) {
142 return AddScalarConstNodeHelper(
143 DT_INT32, [v](TensorProto* proto) { proto->add_int_val(v); }, graph);
144 }
145
146 template <>
AddScalarConstNode(int64 v,MutableGraphView * graph)147 NodeDef* AddScalarConstNode(int64 v, MutableGraphView* graph) {
148 return AddScalarConstNodeHelper(
149 DT_INT64, [v](TensorProto* proto) { proto->add_int64_val(v); }, graph);
150 }
151
152 template <>
AddScalarConstNode(StringPiece v,MutableGraphView * graph)153 NodeDef* AddScalarConstNode(StringPiece v, MutableGraphView* graph) {
154 return AddScalarConstNodeHelper(
155 DT_STRING,
156 [v](TensorProto* proto) { proto->add_string_val(v.data(), v.size()); },
157 graph);
158 }
159
Compare(const GraphDef & g1,const GraphDef & g2)160 bool Compare(const GraphDef& g1, const GraphDef& g2) {
161 if (g1.node_size() != g2.node_size()) {
162 return false;
163 }
164 std::vector<int> name_index1 = CreateNameIndex(g1);
165 std::vector<int> name_index2 = CreateNameIndex(g2);
166 for (int i = 0; i < g1.node_size(); ++i) {
167 int idx1 = name_index1[i];
168 int idx2 = name_index2[i];
169 if (g1.node(idx1).op() != g2.node(idx2).op()) {
170 return false;
171 }
172 if (g1.node(idx1).name() != g2.node(idx2).name()) {
173 return false;
174 }
175 if (g1.node(idx1).input_size() != g2.node(idx2).input_size()) {
176 return false;
177 }
178 std::vector<int> input_index1 = CreateInputIndex(g1.node(idx1));
179 std::vector<int> input_index2 = CreateInputIndex(g2.node(idx2));
180 for (int j = 0; j < g1.node(idx1).input_size(); ++j) {
181 if (!IsSameInput(g1.node(idx1).input(input_index1[j]),
182 g2.node(idx2).input(input_index2[j]))) {
183 return false;
184 }
185 }
186 }
187 return true;
188 }
189
ContainsGraphFunctionWithName(StringPiece name,const FunctionDefLibrary & library)190 bool ContainsGraphFunctionWithName(StringPiece name,
191 const FunctionDefLibrary& library) {
192 return FindGraphFunctionWithName(name, library) != -1;
193 }
194
ContainsGraphNodeWithName(StringPiece name,const GraphDef & graph)195 bool ContainsGraphNodeWithName(StringPiece name, const GraphDef& graph) {
196 return FindGraphNodeWithName(name, graph) != -1;
197 }
198
ContainsNodeWithOp(StringPiece op,const GraphDef & graph)199 bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph) {
200 return FindGraphNodeWithOp(op, graph) != -1;
201 }
202
FindGraphFunctionWithName(StringPiece name,const FunctionDefLibrary & library)203 int FindGraphFunctionWithName(StringPiece name,
204 const FunctionDefLibrary& library) {
205 return GetFirstElementIndexWithPredicate(
206 [&name](const FunctionDef& function) {
207 return function.signature().name() == name;
208 },
209 library.function());
210 }
211
FindGraphNodeWithName(StringPiece name,const GraphDef & graph)212 int FindGraphNodeWithName(StringPiece name, const GraphDef& graph) {
213 return GetFirstElementIndexWithPredicate(
214 [&name](const NodeDef& node) { return node.name() == name; },
215 graph.node());
216 }
217
FindGraphNodeWithOp(StringPiece op,const GraphDef & graph)218 int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph) {
219 return GetFirstElementIndexWithPredicate(
220 [&op](const NodeDef& node) { return node.op() == op; }, graph.node());
221 }
222
FindAllGraphNodesWithOp(const string & op,const GraphDef & graph)223 std::vector<int> FindAllGraphNodesWithOp(const string& op,
224 const GraphDef& graph) {
225 return GetElementIndicesWithPredicate(
226 [&op](const NodeDef& node) { return node.op() == op; }, graph.node());
227 }
228
// Returns the node feeding `node`'s first input (port 0) as seen by `graph`,
// or nullptr when `node` has no inputs.
NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph) {
  if (node.input_size() == 0) {
    return nullptr;
  }
  MutableGraphView::InputPort first_port = graph.GetInputPort(node.name(), 0);
  return graph.GetRegularFanin(first_port).node;
}
234
// Returns the node feeding input port `i` of `node` as seen by `graph`, or
// nullptr when `node` has `i` or fewer inputs.
NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph,
                      int64 i) {
  if (i >= node.input_size()) {
    return nullptr;
  }
  MutableGraphView::InputPort port = graph.GetInputPort(node.name(), i);
  return graph.GetRegularFanin(port).node;
}
241
SetUniqueGraphNodeName(StringPiece prefix,GraphDef * graph,NodeDef * node)242 void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph,
243 NodeDef* node) {
244 string name = string(prefix);
245 int id = graph->node_size();
246 while (ContainsGraphNodeWithName(name, *graph)) {
247 if (name.rfind("_generated") != string::npos &&
248 (name.rfind("_generated") == (name.size() - strlen("_generated")))) {
249 name.insert(name.rfind("_generated"), strings::StrCat("/_", id));
250 } else {
251 name = strings::StrCat(prefix, "/_", id);
252 }
253 ++id;
254 }
255 node->set_name(std::move(name));
256 }
257
SetUniqueGraphFunctionName(StringPiece prefix,FunctionDefLibrary * library,FunctionDef * function)258 void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
259 FunctionDef* function) {
260 string name = string(prefix);
261 int id = library->function_size();
262 while (ContainsGraphFunctionWithName(name, *library)) {
263 name = strings::StrCat(prefix, "/_", id);
264 ++id;
265 }
266 function->mutable_signature()->set_name(std::move(name));
267 }
268
CopyAttribute(const string & attribute_name,const NodeDef & from,NodeDef * to_node)269 void CopyAttribute(const string& attribute_name, const NodeDef& from,
270 NodeDef* to_node) {
271 (*to_node->mutable_attr())[attribute_name] = from.attr().at(attribute_name);
272 }
273
ConcatAttributeList(const string & attribute_name,const NodeDef & first,const NodeDef & second,NodeDef * to_node)274 void ConcatAttributeList(const string& attribute_name, const NodeDef& first,
275 const NodeDef& second, NodeDef* to_node) {
276 CopyAttribute(attribute_name, first, to_node);
277 (*to_node->mutable_attr())
278 .at(attribute_name)
279 .mutable_list()
280 ->MergeFrom(second.attr().at(attribute_name).list());
281 }
282
EnsureNodeNamesUnique(Graph * g)283 Status EnsureNodeNamesUnique(Graph* g) {
284 // Modeled after Scope::Impl::GetUniqueName
285 std::unordered_map<string, int> name_map;
286
287 for (auto node : g->op_nodes()) {
288 const string& prefix = node->name();
289 if (auto entry = gtl::FindOrNull(name_map, prefix)) {
290 string unique_name;
291 do {
292 unique_name = strings::StrCat(prefix, "_", ++(*entry));
293 } while (name_map.find(unique_name) != name_map.end());
294 name_map.insert({unique_name, 0});
295 node->set_name(std::move(unique_name));
296 } else {
297 name_map.insert({node->name(), 0});
298 }
299 }
300
301 return Status::OK();
302 }
303
304 // Tries to find a Sink node in the graph. A sink node is defined as a node
305 // that has at least one input and no outputs. If there are multiple of these,
306 // this might return any one of them. This is useful to identify the final
307 // Dataset op in the graph but in some cases there might be multiple Identity
308 // ops added to the end and this would return the last Identity op in that case.
309
FindSinkNode(const GraphDef & graph_def,NodeDef * sink_node)310 Status FindSinkNode(const GraphDef& graph_def, NodeDef* sink_node) {
311 absl::flat_hash_map<string, int> all_node_names;
312 absl::flat_hash_map<string, int> node_input_map;
313 for (int i = 0; i < graph_def.node_size(); ++i) {
314 all_node_names.insert_or_assign(graph_def.node(i).name(), i);
315 node_input_map.insert_or_assign(graph_def.node(i).name(), 0);
316 }
317 // Counts how many graph nodes for each input name. Candidate sink
318 // nodes are ones which are inputs into zero nodes.
319 for (const NodeDef& node : graph_def.node()) {
320 for (const string& input_name : node.input()) {
321 node_input_map[input_name]++;
322 }
323 }
324 for (const auto& it : node_input_map) {
325 if (it.second == 0) {
326 const NodeDef& sink_graph_node = graph_def.node(all_node_names[it.first]);
327 if (sink_graph_node.input_size() == 0) {
328 continue;
329 }
330 *sink_node = sink_graph_node;
331 return Status::OK();
332 }
333 }
334 return errors::InvalidArgument("Failed to find a sink node");
335 }
336
337 } // namespace graph_utils
338 } // namespace grappler
339 } // namespace tensorflow
340