1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/graph_def_util.h"
17
18 #include "tensorflow/core/framework/function.h"
19 #include "tensorflow/core/framework/graph.pb.h"
20 #include "tensorflow/core/framework/node_def_builder.h"
21 #include "tensorflow/core/framework/op.h"
22 #include "tensorflow/core/framework/op_def.pb.h"
23 #include "tensorflow/core/framework/op_def_builder.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/platform/test.h"
26 #include "tensorflow/core/util/equal_graph_def.h"
27
28 namespace tensorflow {
29 namespace {
30
FinalizeOpDef(const OpDefBuilder & b,OpDef * op_def)31 Status FinalizeOpDef(const OpDefBuilder& b, OpDef* op_def) {
32 OpRegistrationData op_reg_data;
33 const Status s = b.Finalize(&op_reg_data);
34 *op_def = op_reg_data.op_def;
35 return s;
36 }
37
38 // We can create a Graph containing a namespaced Op
// A node whose op has a namespaced name (here "Project>SomeOp") can be built
// into a GraphDef, provided the op is present in the registry used by
// NodeDefBuilder.
TEST(AddToGraphTest, MakeGraphDefWithNamespacedOpName) {
  OpList op_list;
  TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("Project>SomeOp"), op_list.add_op()));
  OpListOpRegistry registry(&op_list);

  GraphDef graph_def;
  // NodeDefBuilder takes a pointer to the registry it resolves the op from.
  TF_ASSERT_OK(NodeDefBuilder("node", "Project>SomeOp", &registry)
                   .Finalize(graph_def.add_node()));
}
48
49 // Producer and consumer have default for an attr -> graph unchanged.
// Producer and consumer share the same registry, and the op's attr "a" has a
// default. Because the consumer already knows the attr, nothing is removed
// and the graph is left unchanged.
TEST(RemoveNewDefaultAttrsFromGraphDefTest, NoChangeWithDefault) {
  OpList op_list;
  TF_ASSERT_OK(
      FinalizeOpDef(OpDefBuilder("NoChangeWithDefault").Attr("a: int = 12"),
                    op_list.add_op()));
  OpListOpRegistry registry(&op_list);

  GraphDef graph_def;
  TF_ASSERT_OK(NodeDefBuilder("ncwd", "NoChangeWithDefault", &registry)
                   .Finalize(graph_def.add_node()));
  // Snapshot before the call so any mutation would be detected.
  GraphDef expected_graph_def = graph_def;

  std::set<std::pair<string, string>> op_attr_removed;
  // The same registry serves as both consumer and producer.
  TF_ASSERT_OK(RemoveNewDefaultAttrsFromGraphDef(&graph_def, registry, registry,
                                                 &op_attr_removed));

  TF_EXPECT_GRAPH_EQ(expected_graph_def, graph_def);
  EXPECT_TRUE(op_attr_removed.empty());
}
69
70 // Producer and consumer both have an attr -> graph unchanged.
// Producer and consumer share the same registry, and attr "a" has no default
// and is set explicitly. Nothing should be removed and the graph is left
// unchanged.
TEST(RemoveNewDefaultAttrsFromGraphDefTest, NoChangeNoDefault) {
  OpList op_list;
  TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("NoChangeNoDefault").Attr("a: int"),
                             op_list.add_op()));
  OpListOpRegistry registry(&op_list);

  GraphDef graph_def;
  TF_ASSERT_OK(NodeDefBuilder("ncnd", "NoChangeNoDefault", &registry)
                   .Attr("a", 42)
                   .Finalize(graph_def.add_node()));
  // Snapshot before the call so any mutation would be detected.
  GraphDef expected_graph_def = graph_def;

  std::set<std::pair<string, string>> op_attr_removed;
  // The same registry serves as both consumer and producer.
  TF_ASSERT_OK(RemoveNewDefaultAttrsFromGraphDef(&graph_def, registry, registry,
                                                 &op_attr_removed));

  TF_EXPECT_GRAPH_EQ(expected_graph_def, graph_def);
  EXPECT_TRUE(op_attr_removed.empty());
}
90
91 // Producer has default for an attr that the consumer does not know
92 // about, and the produced graph has the default value for the attr ->
93 // attr removed from graph (and so able to be consumed).
TEST(RemoveNewDefaultAttrsFromGraphDefTest, UsesDefault) {
  // Consumer side: "UsesDefault" declares no attrs.
  OpList consumer_ops;
  TF_ASSERT_OK(
      FinalizeOpDef(OpDefBuilder("UsesDefault"), consumer_ops.add_op()));
  OpListOpRegistry consumer(&consumer_ops);

  // Producer side: the same op gained attr "a" with default 17.
  OpList producer_ops;
  TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("UsesDefault").Attr("a: int = 17"),
                             producer_ops.add_op()));
  OpListOpRegistry producer(&producer_ops);

  // The producer's NodeDefBuilder fills in a = 17 since it is not set.
  GraphDef produced;
  TF_ASSERT_OK(NodeDefBuilder("uses_default", "UsesDefault", &producer)
                   .Finalize(produced.add_node()));

  std::set<std::pair<string, string>> removed;
  TF_ASSERT_OK(RemoveNewDefaultAttrsFromGraphDef(&produced, consumer, producer,
                                                 &removed));

  // After stripping, the node must match what the consumer would build itself.
  GraphDef expected;
  TF_ASSERT_OK(NodeDefBuilder("uses_default", "UsesDefault", &consumer)
                   .Finalize(expected.add_node()));
  TF_EXPECT_GRAPH_EQ(expected, produced);

  const std::set<std::pair<string, string>> expected_removed = {
      {"UsesDefault", "a"}};
  EXPECT_EQ(expected_removed, removed);
}
122
123 // Producer has default for an attr that the consumer does not know
124 // about, graph sets the attr to a value different from the default ->
125 // graph unchanged (but not able to be consumed by consumer).
TEST(RemoveNewDefaultAttrsFromGraphDefTest, ChangedFromDefault) {
  // Consumer side: "ChangedFromDefault" declares no attrs.
  OpList consumer_ops;
  TF_ASSERT_OK(
      FinalizeOpDef(OpDefBuilder("ChangedFromDefault"), consumer_ops.add_op()));
  OpListOpRegistry consumer(&consumer_ops);

  // Producer side: attr "a" exists with default 17.
  OpList producer_ops;
  TF_ASSERT_OK(
      FinalizeOpDef(OpDefBuilder("ChangedFromDefault").Attr("a: int = 17"),
                    producer_ops.add_op()));
  OpListOpRegistry producer(&producer_ops);

  // The node overrides the default (9 != 17), so the attr cannot be dropped.
  GraphDef produced;
  TF_ASSERT_OK(
      NodeDefBuilder("changed_from_default", "ChangedFromDefault", &producer)
          .Attr("a", 9)
          .Finalize(produced.add_node()));
  const GraphDef before = produced;

  std::set<std::pair<string, string>> removed;
  TF_ASSERT_OK(RemoveNewDefaultAttrsFromGraphDef(&produced, consumer, producer,
                                                 &removed));

  TF_EXPECT_GRAPH_EQ(before, produced);
  EXPECT_TRUE(removed.empty());
}
153
154 // Attrs starting with underscores should not be removed.
// An attr whose name starts with an underscore is left in place, even when the
// consumer does not know it and its value equals the producer's default.
TEST(RemoveNewDefaultAttrsFromGraphDefTest, UnderscoreAttrs) {
  OpList consumer_op_list;
  TF_ASSERT_OK(
      FinalizeOpDef(OpDefBuilder("Underscore"), consumer_op_list.add_op()));
  OpListOpRegistry consumer_registry(&consumer_op_list);

  OpList producer_op_list;
  TF_ASSERT_OK(
      FinalizeOpDef(OpDefBuilder("Underscore"), producer_op_list.add_op()));
  // Add the _underscore attr manually since OpDefBuilder would complain about
  // an attr name starting with an underscore.
  OpDef::AttrDef* attr = producer_op_list.mutable_op(0)->add_attr();
  attr->set_name("_underscore");
  attr->set_type("int");
  attr->mutable_default_value()->set_i(17);
  OpListOpRegistry producer_registry(&producer_op_list);

  GraphDef produced_graph_def;
  // Set the attr to exactly its default, the case that would otherwise be
  // eligible for removal.
  TF_ASSERT_OK(NodeDefBuilder("node", "Underscore", &producer_registry)
                   .Attr("_underscore", 17)
                   .Finalize(produced_graph_def.add_node()));
  GraphDef expected_graph_def = produced_graph_def;

  std::set<std::pair<string, string>> op_attr_removed;
  TF_ASSERT_OK(
      RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry,
                                        producer_registry, &op_attr_removed));

  TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def);
  // Use empty() rather than comparing size() (size_t) with an int literal:
  // avoids a signed/unsigned comparison and matches the sibling tests.
  EXPECT_TRUE(op_attr_removed.empty());
}
185
// Nodes inside the graph's function library are also processed:
// - a function node that uses the producer's default has the attr removed;
// - one that overrides the default keeps it.
TEST(RemoveNewDefaultAttrsFromGraphDefTest, HasFunction) {
  // Consumer knows both ops, but neither declares attr "a".
  OpList consumer_ops;
  TF_ASSERT_OK(
      FinalizeOpDef(OpDefBuilder("UsesDefault"), consumer_ops.add_op()));
  TF_ASSERT_OK(
      FinalizeOpDef(OpDefBuilder("ChangedFromDefault"), consumer_ops.add_op()));
  OpListOpRegistry consumer(&consumer_ops);

  // Producer gives both ops an int attr "a" with default 17.
  OpList producer_ops;
  TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("UsesDefault").Attr("a: int = 17"),
                             producer_ops.add_op()));
  TF_ASSERT_OK(
      FinalizeOpDef(OpDefBuilder("ChangedFromDefault").Attr("a: int = 17"),
                    producer_ops.add_op()));
  OpListOpRegistry producer(&producer_ops);

  // "x" sets a to its default value, "y" sets it to something else.
  GraphDef produced;
  FunctionDef* produced_fn = produced.mutable_library()->add_function();
  *produced_fn = FunctionDefHelper::Create(
      "my_func", {}, {}, {},
      {{{"x"}, "UsesDefault", {}, {{"a", 17}}},
       {{"y"}, "ChangedFromDefault", {}, {{"a", 99}}}},
      {});

  // Expose the function's signature as an op so a calling node can be built.
  OpList fn_ops;
  *fn_ops.add_op() = produced.library().function(0).signature();
  OpListOpRegistry fn_registry(&fn_ops);
  TF_ASSERT_OK(NodeDefBuilder("call_func", "my_func", &fn_registry)
                   .Finalize(produced.add_node()));

  std::set<std::pair<string, string>> removed;
  TF_ASSERT_OK(RemoveNewDefaultAttrsFromGraphDef(&produced, consumer, producer,
                                                 &removed));

  // Expected: "x" loses attr a, "y" keeps a = 99; the caller node is unchanged.
  GraphDef expected;
  FunctionDef* expected_fn = expected.mutable_library()->add_function();
  *expected_fn = FunctionDefHelper::Create(
      "my_func", {}, {}, {},
      {{{"x"}, "UsesDefault", {}, {}},
       {{"y"}, "ChangedFromDefault", {}, {{"a", 99}}}},
      {});
  TF_ASSERT_OK(NodeDefBuilder("call_func", "my_func", &fn_registry)
                   .Finalize(expected.add_node()));

  TF_EXPECT_GRAPH_EQ(expected, produced);
  EXPECT_EQ(expected.library().DebugString(),
            produced.library().DebugString());

  const std::set<std::pair<string, string>> expected_removed = {
      {"UsesDefault", "a"}};
  EXPECT_EQ(expected_removed, removed);
}
237
// An attr holding its registered default value is removed by
// StripDefaultAttributes.
TEST(StripDefaultAttributesTest, DefaultStripped) {
  OpList op_list;
  TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("OpName1").Attr("a: int = 12"),
                             op_list.add_op()));
  OpListOpRegistry registry(&op_list);

  GraphDef graph_def;
  // NodeDefBuilder fills in the default a = 12 because the attr is not set.
  TF_ASSERT_OK(NodeDefBuilder("op1", "OpName1", &registry)
                   .Finalize(graph_def.add_node()));
  ASSERT_EQ(1, graph_def.node(0).attr_size());
  ASSERT_EQ(12, graph_def.node(0).attr().at("a").i());

  StripDefaultAttributes(registry, graph_def.mutable_node());
  ASSERT_EQ(1, graph_def.node_size());
  ASSERT_EQ(0, graph_def.node(0).attr_size());
}
255
// An attr set to a non-default value (9 vs default 12) is kept by
// StripDefaultAttributes, so the graph is unchanged.
TEST(StripDefaultAttributesTest, NonDefaultNotStripped) {
  OpList op_list;
  TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("OpName1").Attr("a: int = 12"),
                             op_list.add_op()));
  OpListOpRegistry registry(&op_list);

  GraphDef graph_def;
  TF_ASSERT_OK(NodeDefBuilder("op1", "OpName1", &registry)
                   .Attr("a", 9)
                   .Finalize(graph_def.add_node()));

  GraphDef expected = graph_def;
  StripDefaultAttributes(registry, graph_def.mutable_node());
  TF_EXPECT_GRAPH_EQ(expected, graph_def);
}
271
// StrippedOpListForGraph keeps only the ops a graph actually uses, with
// summary and description cleared, independent of node order and of whether
// the ops appear directly or inside a library function.
TEST(StrippedOpListForGraphTest, FlatTest) {
  // Register four primitive ops; only "B" is marked commutative.
  OpList op_list;
  for (const string op_name : {"A", "B", "C", "D"}) {
    OpDef* def = op_list.add_op();
    def->set_name(op_name);
    def->set_summary("summary");
    def->set_description("description");
    def->set_is_commutative(op_name == "B");
  }

  // Four orderings of a three-node graph using B and C (one appears once,
  // the other twice). The stripped result must not depend on the order.
  const string graph_ops[4][3] = {
      {"C", "B", "B"}, {"B", "C", "B"}, {"B", "B", "C"}, {"C", "C", "B"}};
  // Expected stripped ops, in output order.
  const string expected_names[2] = {"B", "C"};

  for (const bool use_function : {false, true}) {
    for (const auto& ops : graph_ops) {
      GraphDef graph_def;
      if (!use_function) {
        // Put the ops directly in the graph as uniquely named nodes.
        for (const string& op : ops) {
          const string node_name = strings::StrCat("name", graph_def.node_size());
          NodeDef* node = graph_def.add_node();
          node->set_name(node_name);
          node->set_op(op);
        }
      } else {
        // Wrap the ops in library function "F" and call F from the graph.
        FunctionDef* fn = graph_def.mutable_library()->add_function();
        fn->mutable_signature()->set_name("F");
        for (const string& op : ops) {
          fn->add_node_def()->set_op(op);
        }
        graph_def.add_node()->set_op("F");
      }

      OpList stripped_op_list;
      TF_ASSERT_OK(StrippedOpListForGraph(graph_def, OpListOpRegistry(&op_list),
                                          &stripped_op_list));

      // Exactly B and C survive, stripped of their documentation fields.
      ASSERT_EQ(stripped_op_list.op_size(), 2);
      for (int idx = 0; idx < 2; ++idx) {
        const OpDef& def = stripped_op_list.op(idx);
        EXPECT_EQ(def.name(), expected_names[idx]);
        EXPECT_EQ(def.summary(), "");
        EXPECT_EQ(def.description(), "");
        EXPECT_EQ(def.is_commutative(), idx == 0);
      }

      // OpsUsedByGraph must agree with the stripped list.
      std::set<string> used_ops;
      OpsUsedByGraph(graph_def, &used_ops);
      ASSERT_EQ(std::set<string>({"B", "C"}), used_ops);
    }
  }
}
328
// Ops reached only through nested library functions are found, and
// self-recursive functions do not prevent the traversal from finishing.
TEST(StrippedOpListForGraphTest, NestedFunctionTest) {
  // The only primitive op.
  OpList op_list;
  op_list.add_op()->set_name("A");

  for (const bool recursive : {false, true}) {
    GraphDef graph_def;
    // Call chain: graph -> C -> B -> A.
    FunctionDef* fn_b = graph_def.mutable_library()->add_function();
    FunctionDef* fn_c = graph_def.mutable_library()->add_function();
    fn_b->mutable_signature()->set_name("B");
    fn_c->mutable_signature()->set_name("C");
    fn_b->add_node_def()->set_op("A");
    fn_c->add_node_def()->set_op("B");
    // Optionally make each function call itself as well.
    if (recursive) {
      fn_b->add_node_def()->set_op("B");
      fn_c->add_node_def()->set_op("C");
    }

    graph_def.add_node()->set_op("C");

    // Only the primitive A should appear in the stripped op list.
    OpList stripped_op_list;
    TF_ASSERT_OK(StrippedOpListForGraph(graph_def, OpListOpRegistry(&op_list),
                                        &stripped_op_list));
    ASSERT_EQ(stripped_op_list.op_size(), 1);
    ASSERT_EQ(stripped_op_list.op(0).name(), "A");

    // OpsUsedByGraph must agree.
    std::set<string> used_ops;
    OpsUsedByGraph(graph_def, &used_ops);
    ASSERT_EQ(std::set<string>({"A"}), used_ops);
  }
}
364
365 } // namespace
366 } // namespace tensorflow
367