1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/graph_def_util.h"
17
18 #include "tensorflow/core/framework/function.h"
19 #include "tensorflow/core/framework/graph.pb.h"
20 #include "tensorflow/core/framework/node_def_builder.h"
21 #include "tensorflow/core/framework/op.h"
22 #include "tensorflow/core/framework/op_def.pb.h"
23 #include "tensorflow/core/framework/op_def_builder.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/platform/test.h"
26 #include "tensorflow/core/util/equal_graph_def.h"
27
28 namespace tensorflow {
29 namespace {
30
FinalizeOpDef(const OpDefBuilder & b,OpDef * op_def)31 Status FinalizeOpDef(const OpDefBuilder& b, OpDef* op_def) {
32 OpRegistrationData op_reg_data;
33 const Status s = b.Finalize(&op_reg_data);
34 *op_def = op_reg_data.op_def;
35 return s;
36 }
37
38 // Producer and consumer have default for an attr -> graph unchanged.
TEST(RemoveNewDefaultAttrsFromGraphDefTest,NoChangeWithDefault)39 TEST(RemoveNewDefaultAttrsFromGraphDefTest, NoChangeWithDefault) {
40 OpList op_list;
41 TF_ASSERT_OK(
42 FinalizeOpDef(OpDefBuilder("NoChangeWithDefault").Attr("a: int = 12"),
43 op_list.add_op()));
44 OpListOpRegistry registry(&op_list);
45
46 GraphDef graph_def;
47 TF_ASSERT_OK(NodeDefBuilder("ncwd", "NoChangeWithDefault", ®istry)
48 .Finalize(graph_def.add_node()));
49 GraphDef expected_graph_def = graph_def;
50
51 std::set<std::pair<string, string>> op_attr_removed;
52 TF_ASSERT_OK(RemoveNewDefaultAttrsFromGraphDef(&graph_def, registry, registry,
53 &op_attr_removed));
54
55 TF_EXPECT_GRAPH_EQ(expected_graph_def, graph_def);
56 EXPECT_TRUE(op_attr_removed.empty());
57 }
58
59 // Producer and consumer both have an attr -> graph unchanged.
TEST(RemoveNewDefaultAttrsFromGraphDefTest,NoChangeNoDefault)60 TEST(RemoveNewDefaultAttrsFromGraphDefTest, NoChangeNoDefault) {
61 OpList op_list;
62 TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("NoChangeNoDefault").Attr("a: int"),
63 op_list.add_op()));
64 OpListOpRegistry registry(&op_list);
65
66 GraphDef graph_def;
67 TF_ASSERT_OK(NodeDefBuilder("ncnd", "NoChangeNoDefault", ®istry)
68 .Attr("a", 42)
69 .Finalize(graph_def.add_node()));
70 GraphDef expected_graph_def = graph_def;
71
72 std::set<std::pair<string, string>> op_attr_removed;
73 TF_ASSERT_OK(RemoveNewDefaultAttrsFromGraphDef(&graph_def, registry, registry,
74 &op_attr_removed));
75
76 TF_EXPECT_GRAPH_EQ(expected_graph_def, graph_def);
77 EXPECT_TRUE(op_attr_removed.empty());
78 }
79
80 // Producer has default for an attr that the consumer does not know
81 // about, and the produced graph has the default value for the attr ->
82 // attr removed from graph (and so able to be consumed).
TEST(RemoveNewDefaultAttrsFromGraphDefTest,UsesDefault)83 TEST(RemoveNewDefaultAttrsFromGraphDefTest, UsesDefault) {
84 OpList consumer_op_list;
85 TF_ASSERT_OK(
86 FinalizeOpDef(OpDefBuilder("UsesDefault"), consumer_op_list.add_op()));
87 OpListOpRegistry consumer_registry(&consumer_op_list);
88
89 OpList producer_op_list;
90 TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("UsesDefault").Attr("a: int = 17"),
91 producer_op_list.add_op()));
92 OpListOpRegistry producer_registry(&producer_op_list);
93
94 GraphDef produced_graph_def;
95 TF_ASSERT_OK(NodeDefBuilder("uses_default", "UsesDefault", &producer_registry)
96 .Finalize(produced_graph_def.add_node()));
97
98 std::set<std::pair<string, string>> op_attr_removed;
99 TF_ASSERT_OK(
100 RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry,
101 producer_registry, &op_attr_removed));
102
103 GraphDef expected_graph_def;
104 TF_ASSERT_OK(NodeDefBuilder("uses_default", "UsesDefault", &consumer_registry)
105 .Finalize(expected_graph_def.add_node()));
106 TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def);
107
108 std::set<std::pair<string, string>> expected_removed({{"UsesDefault", "a"}});
109 EXPECT_EQ(expected_removed, op_attr_removed);
110 }
111
112 // Producer has default for an attr that the consumer does not know
113 // about, graph sets the attr to a value different from the default ->
114 // graph unchanged (but not able to be consumed by consumer).
TEST(RemoveNewDefaultAttrsFromGraphDefTest,ChangedFromDefault)115 TEST(RemoveNewDefaultAttrsFromGraphDefTest, ChangedFromDefault) {
116 OpList consumer_op_list;
117 TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("ChangedFromDefault"),
118 consumer_op_list.add_op()));
119 OpListOpRegistry consumer_registry(&consumer_op_list);
120
121 OpList producer_op_list;
122 TF_ASSERT_OK(
123 FinalizeOpDef(OpDefBuilder("ChangedFromDefault").Attr("a: int = 17"),
124 producer_op_list.add_op()));
125 OpListOpRegistry producer_registry(&producer_op_list);
126
127 GraphDef produced_graph_def;
128 TF_ASSERT_OK(NodeDefBuilder("changed_from_default", "ChangedFromDefault",
129 &producer_registry)
130 .Attr("a", 9)
131 .Finalize(produced_graph_def.add_node()));
132 GraphDef expected_graph_def = produced_graph_def;
133
134 std::set<std::pair<string, string>> op_attr_removed;
135 TF_ASSERT_OK(
136 RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry,
137 producer_registry, &op_attr_removed));
138
139 TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def);
140 EXPECT_TRUE(op_attr_removed.empty());
141 }
142
143 // Attrs starting with underscores should not be removed.
TEST(RemoveNewDefaultAttrsFromGraphDefTest,UnderscoreAttrs)144 TEST(RemoveNewDefaultAttrsFromGraphDefTest, UnderscoreAttrs) {
145 OpList consumer_op_list;
146 TF_ASSERT_OK(
147 FinalizeOpDef(OpDefBuilder("Underscore"), consumer_op_list.add_op()));
148 OpListOpRegistry consumer_registry(&consumer_op_list);
149
150 OpList producer_op_list;
151 TF_ASSERT_OK(
152 FinalizeOpDef(OpDefBuilder("Underscore"), producer_op_list.add_op()));
153 // Add the _underscore attr manually since OpDefBuilder would complain
154 OpDef::AttrDef* attr = producer_op_list.mutable_op(0)->add_attr();
155 attr->set_name("_underscore");
156 attr->set_type("int");
157 attr->mutable_default_value()->set_i(17);
158 OpListOpRegistry producer_registry(&producer_op_list);
159
160 GraphDef produced_graph_def;
161 TF_ASSERT_OK(NodeDefBuilder("node", "Underscore", &producer_registry)
162 .Attr("_underscore", 17)
163 .Finalize(produced_graph_def.add_node()));
164 GraphDef expected_graph_def = produced_graph_def;
165
166 std::set<std::pair<string, string>> op_attr_removed;
167 TF_ASSERT_OK(
168 RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry,
169 producer_registry, &op_attr_removed));
170
171 TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def);
172 EXPECT_EQ(op_attr_removed.size(), 0);
173 }
174
// Default-valued attrs are also stripped from nodes inside functions in the
// graph's library. "UsesDefault" sets `a` to the producer default (17) and
// loses it; "ChangedFromDefault" sets a non-default value (99) and keeps it.
TEST(RemoveNewDefaultAttrsFromGraphDefTest, HasFunction) {
  // Consumer: both ops are known, neither declares `a`.
  OpList consumer_ops;
  for (const char* op_name : {"UsesDefault", "ChangedFromDefault"}) {
    TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder(op_name), consumer_ops.add_op()));
  }
  OpListOpRegistry consumer_registry(&consumer_ops);

  // Producer: both ops declare `a: int = 17`.
  OpList producer_ops;
  for (const char* op_name : {"UsesDefault", "ChangedFromDefault"}) {
    TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder(op_name).Attr("a: int = 17"),
                               producer_ops.add_op()));
  }
  OpListOpRegistry producer_registry(&producer_ops);

  // Graph whose library holds `my_func`, plus a single node calling it.
  GraphDef graph;
  FunctionDef* my_func = graph.mutable_library()->add_function();
  *my_func = FunctionDefHelper::Create(
      "my_func", {}, {}, {},
      {{{"x"}, "UsesDefault", {}, {{"a", 17}}},
       {{"y"}, "ChangedFromDefault", {}, {{"a", 99}}}},
      {});
  OpList func_ops;
  *func_ops.add_op() = my_func->signature();
  OpListOpRegistry func_registry(&func_ops);
  TF_ASSERT_OK(NodeDefBuilder("call_func", "my_func", &func_registry)
                   .Finalize(graph.add_node()));

  std::set<std::pair<string, string>> removed;
  TF_ASSERT_OK(RemoveNewDefaultAttrsFromGraphDef(&graph, consumer_registry,
                                                 producer_registry, &removed));

  // Expected: `a` gone from the UsesDefault node only.
  GraphDef want;
  *want.mutable_library()->add_function() = FunctionDefHelper::Create(
      "my_func", {}, {}, {},
      {{{"x"}, "UsesDefault", {}, {}},
       {{"y"}, "ChangedFromDefault", {}, {{"a", 99}}}},
      {});
  TF_ASSERT_OK(NodeDefBuilder("call_func", "my_func", &func_registry)
                   .Finalize(want.add_node()));
  TF_EXPECT_GRAPH_EQ(want, graph);
  // The function library is compared separately, textually.
  EXPECT_EQ(want.library().DebugString(), graph.library().DebugString());

  const std::set<std::pair<string, string>> want_removed = {
      {"UsesDefault", "a"}};
  EXPECT_EQ(want_removed, removed);
}
226
// StrippedOpListForGraph must return exactly the ops the graph uses (B and C),
// ordered B then C regardless of the order they appear in the graph. Summary
// and description are cleared, and is_commutative is preserved. The same must
// hold when the ops are reached only through a library function. In every
// case, OpsUsedByGraph must report the same set of op names.
TEST(StrippedOpListForGraphTest, FlatTest) {
  // Registry of four ops A..D; only B is commutative.
  OpList op_list;
  for (const string& name : {"A", "B", "C", "D"}) {
    OpDef* def = op_list.add_op();
    def->set_name(name);
    def->set_summary("summary");
    def->set_description("description");
    def->set_is_commutative(name == "B");
  }

  // Each row uses one op once and the other twice, in a different order.
  const string orderings[4][3] = {
      {"C", "B", "B"}, {"B", "C", "B"}, {"B", "B", "C"}, {"C", "C", "B"}};

  // Expected stripped entries, in output order.
  struct ExpectedOp {
    const char* name;
    bool commutative;
  };
  const ExpectedOp expected_ops[2] = {{"B", true}, {"C", false}};

  for (const bool use_function : {false, true}) {
    for (const auto& ops_in_order : orderings) {
      GraphDef graph_def;
      if (!use_function) {
        // Plain graph: one named node per op.
        for (const string& op : ops_in_order) {
          const int index = graph_def.node_size();
          NodeDef* node = graph_def.add_node();
          node->set_name(strings::StrCat("name", index));
          node->set_op(op);
        }
      } else {
        // Ops live inside function F; the graph only calls F.
        FunctionDef* fn = graph_def.mutable_library()->add_function();
        fn->mutable_signature()->set_name("F");
        for (const string& op : ops_in_order) {
          fn->add_node_def()->set_op(op);
        }
        graph_def.add_node()->set_op("F");
      }

      OpList stripped_op_list;
      TF_ASSERT_OK(StrippedOpListForGraph(graph_def, OpListOpRegistry(&op_list),
                                          &stripped_op_list));

      ASSERT_EQ(stripped_op_list.op_size(), 2);
      for (int k = 0; k < 2; ++k) {
        const OpDef& got = stripped_op_list.op(k);
        EXPECT_EQ(got.name(), expected_ops[k].name);
        EXPECT_EQ(got.summary(), "");
        EXPECT_EQ(got.description(), "");
        EXPECT_EQ(got.is_commutative(), expected_ops[k].commutative);
      }

      // OpsUsedByGraph() must agree.
      std::set<string> used_ops;
      OpsUsedByGraph(graph_def, &used_ops);
      ASSERT_EQ(std::set<string>({"B", "C"}), used_ops);
    }
  }
}
283
// Primitive ops reached only through nested function calls (graph -> C -> B
// -> A) must still appear in the stripped op list. Function names themselves
// must not appear, with or without self-recursive calls inside the functions.
TEST(StrippedOpListForGraphTest, NestedFunctionTest) {
  // The only primitive op is A.
  OpList op_list;
  op_list.add_op()->set_name("A");

  for (const bool recursive : {false, true}) {
    GraphDef graph_def;
    auto* library = graph_def.mutable_library();

    // Function B calls A (and itself when recursive).
    FunctionDef* b = library->add_function();
    b->mutable_signature()->set_name("B");
    b->add_node_def()->set_op("A");
    if (recursive) b->add_node_def()->set_op("B");

    // Function C calls B (and itself when recursive).
    FunctionDef* c = library->add_function();
    c->mutable_signature()->set_name("C");
    c->add_node_def()->set_op("B");
    if (recursive) c->add_node_def()->set_op("C");

    // The top-level graph only invokes C.
    graph_def.add_node()->set_op("C");

    OpList stripped_op_list;
    TF_ASSERT_OK(StrippedOpListForGraph(graph_def, OpListOpRegistry(&op_list),
                                        &stripped_op_list));
    ASSERT_EQ(stripped_op_list.op_size(), 1);
    ASSERT_EQ(stripped_op_list.op(0).name(), "A");

    // OpsUsedByGraph() must agree.
    std::set<string> used_ops;
    OpsUsedByGraph(graph_def, &used_ops);
    ASSERT_EQ(std::set<string>({"A"}), used_ops);
  }
}
319
320 } // namespace
321 } // namespace tensorflow
322