• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/graph_def_util.h"
17 
18 #include "tensorflow/core/framework/function.h"
19 #include "tensorflow/core/framework/graph.pb.h"
20 #include "tensorflow/core/framework/node_def_builder.h"
21 #include "tensorflow/core/framework/op.h"
22 #include "tensorflow/core/framework/op_def.pb.h"
23 #include "tensorflow/core/framework/op_def_builder.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/platform/test.h"
26 #include "tensorflow/core/util/equal_graph_def.h"
27 
28 namespace tensorflow {
29 namespace {
30 
FinalizeOpDef(const OpDefBuilder & b,OpDef * op_def)31 Status FinalizeOpDef(const OpDefBuilder& b, OpDef* op_def) {
32   OpRegistrationData op_reg_data;
33   const Status s = b.Finalize(&op_reg_data);
34   *op_def = op_reg_data.op_def;
35   return s;
36 }
37 
38 // We can create a Graph containing a namespaced Op
TEST(AddToGraphTest,MakeGraphDefWithNamespacedOpName)39 TEST(AddToGraphTest, MakeGraphDefWithNamespacedOpName) {
40   OpList op_list;
41   TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("Project>SomeOp"), op_list.add_op()));
42   OpListOpRegistry registry(&op_list);
43 
44   GraphDef graph_def;
45   TF_ASSERT_OK(NodeDefBuilder("node", "Project>SomeOp", &registry)
46                    .Finalize(graph_def.add_node()));
47 }
48 
49 // Producer and consumer have default for an attr -> graph unchanged.
TEST(RemoveNewDefaultAttrsFromGraphDefTest,NoChangeWithDefault)50 TEST(RemoveNewDefaultAttrsFromGraphDefTest, NoChangeWithDefault) {
51   OpList op_list;
52   TF_ASSERT_OK(
53       FinalizeOpDef(OpDefBuilder("NoChangeWithDefault").Attr("a: int = 12"),
54                     op_list.add_op()));
55   OpListOpRegistry registry(&op_list);
56 
57   GraphDef graph_def;
58   TF_ASSERT_OK(NodeDefBuilder("ncwd", "NoChangeWithDefault", &registry)
59                    .Finalize(graph_def.add_node()));
60   GraphDef expected_graph_def = graph_def;
61 
62   std::set<std::pair<string, string>> op_attr_removed;
63   TF_ASSERT_OK(RemoveNewDefaultAttrsFromGraphDef(&graph_def, registry, registry,
64                                                  &op_attr_removed));
65 
66   TF_EXPECT_GRAPH_EQ(expected_graph_def, graph_def);
67   EXPECT_TRUE(op_attr_removed.empty());
68 }
69 
70 // Producer and consumer both have an attr -> graph unchanged.
TEST(RemoveNewDefaultAttrsFromGraphDefTest,NoChangeNoDefault)71 TEST(RemoveNewDefaultAttrsFromGraphDefTest, NoChangeNoDefault) {
72   OpList op_list;
73   TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("NoChangeNoDefault").Attr("a: int"),
74                              op_list.add_op()));
75   OpListOpRegistry registry(&op_list);
76 
77   GraphDef graph_def;
78   TF_ASSERT_OK(NodeDefBuilder("ncnd", "NoChangeNoDefault", &registry)
79                    .Attr("a", 42)
80                    .Finalize(graph_def.add_node()));
81   GraphDef expected_graph_def = graph_def;
82 
83   std::set<std::pair<string, string>> op_attr_removed;
84   TF_ASSERT_OK(RemoveNewDefaultAttrsFromGraphDef(&graph_def, registry, registry,
85                                                  &op_attr_removed));
86 
87   TF_EXPECT_GRAPH_EQ(expected_graph_def, graph_def);
88   EXPECT_TRUE(op_attr_removed.empty());
89 }
90 
91 // Producer has default for an attr that the consumer does not know
92 // about, and the produced graph has the default value for the attr ->
93 // attr removed from graph (and so able to be consumed).
TEST(RemoveNewDefaultAttrsFromGraphDefTest,UsesDefault)94 TEST(RemoveNewDefaultAttrsFromGraphDefTest, UsesDefault) {
95   OpList consumer_op_list;
96   TF_ASSERT_OK(
97       FinalizeOpDef(OpDefBuilder("UsesDefault"), consumer_op_list.add_op()));
98   OpListOpRegistry consumer_registry(&consumer_op_list);
99 
100   OpList producer_op_list;
101   TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("UsesDefault").Attr("a: int = 17"),
102                              producer_op_list.add_op()));
103   OpListOpRegistry producer_registry(&producer_op_list);
104 
105   GraphDef produced_graph_def;
106   TF_ASSERT_OK(NodeDefBuilder("uses_default", "UsesDefault", &producer_registry)
107                    .Finalize(produced_graph_def.add_node()));
108 
109   std::set<std::pair<string, string>> op_attr_removed;
110   TF_ASSERT_OK(
111       RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry,
112                                         producer_registry, &op_attr_removed));
113 
114   GraphDef expected_graph_def;
115   TF_ASSERT_OK(NodeDefBuilder("uses_default", "UsesDefault", &consumer_registry)
116                    .Finalize(expected_graph_def.add_node()));
117   TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def);
118 
119   std::set<std::pair<string, string>> expected_removed({{"UsesDefault", "a"}});
120   EXPECT_EQ(expected_removed, op_attr_removed);
121 }
122 
123 // Producer has default for an attr that the consumer does not know
124 // about, graph sets the attr to a value different from the default ->
125 // graph unchanged (but not able to be consumed by consumer).
TEST(RemoveNewDefaultAttrsFromGraphDefTest,ChangedFromDefault)126 TEST(RemoveNewDefaultAttrsFromGraphDefTest, ChangedFromDefault) {
127   OpList consumer_op_list;
128   TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("ChangedFromDefault"),
129                              consumer_op_list.add_op()));
130   OpListOpRegistry consumer_registry(&consumer_op_list);
131 
132   OpList producer_op_list;
133   TF_ASSERT_OK(
134       FinalizeOpDef(OpDefBuilder("ChangedFromDefault").Attr("a: int = 17"),
135                     producer_op_list.add_op()));
136   OpListOpRegistry producer_registry(&producer_op_list);
137 
138   GraphDef produced_graph_def;
139   TF_ASSERT_OK(NodeDefBuilder("changed_from_default", "ChangedFromDefault",
140                               &producer_registry)
141                    .Attr("a", 9)
142                    .Finalize(produced_graph_def.add_node()));
143   GraphDef expected_graph_def = produced_graph_def;
144 
145   std::set<std::pair<string, string>> op_attr_removed;
146   TF_ASSERT_OK(
147       RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry,
148                                         producer_registry, &op_attr_removed));
149 
150   TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def);
151   EXPECT_TRUE(op_attr_removed.empty());
152 }
153 
154 // Attrs starting with underscores should not be removed.
TEST(RemoveNewDefaultAttrsFromGraphDefTest,UnderscoreAttrs)155 TEST(RemoveNewDefaultAttrsFromGraphDefTest, UnderscoreAttrs) {
156   OpList consumer_op_list;
157   TF_ASSERT_OK(
158       FinalizeOpDef(OpDefBuilder("Underscore"), consumer_op_list.add_op()));
159   OpListOpRegistry consumer_registry(&consumer_op_list);
160 
161   OpList producer_op_list;
162   TF_ASSERT_OK(
163       FinalizeOpDef(OpDefBuilder("Underscore"), producer_op_list.add_op()));
164   // Add the _underscore attr manually since OpDefBuilder would complain
165   OpDef::AttrDef* attr = producer_op_list.mutable_op(0)->add_attr();
166   attr->set_name("_underscore");
167   attr->set_type("int");
168   attr->mutable_default_value()->set_i(17);
169   OpListOpRegistry producer_registry(&producer_op_list);
170 
171   GraphDef produced_graph_def;
172   TF_ASSERT_OK(NodeDefBuilder("node", "Underscore", &producer_registry)
173                    .Attr("_underscore", 17)
174                    .Finalize(produced_graph_def.add_node()));
175   GraphDef expected_graph_def = produced_graph_def;
176 
177   std::set<std::pair<string, string>> op_attr_removed;
178   TF_ASSERT_OK(
179       RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry,
180                                         producer_registry, &op_attr_removed));
181 
182   TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def);
183   EXPECT_EQ(op_attr_removed.size(), 0);
184 }
185 
// Node defs inside a function in the graph's library are processed as well:
// the function-body node that carries the producer's default for "a" loses the
// attr, while the one with a non-default value keeps it.
TEST(RemoveNewDefaultAttrsFromGraphDefTest, HasFunction) {
  // Consumer side: both ops without any attrs.
  OpList consumer_ops;
  TF_ASSERT_OK(
      FinalizeOpDef(OpDefBuilder("UsesDefault"), consumer_ops.add_op()));
  TF_ASSERT_OK(
      FinalizeOpDef(OpDefBuilder("ChangedFromDefault"), consumer_ops.add_op()));
  OpListOpRegistry consumer(&consumer_ops);

  // Producer side: both ops with a defaulted attr "a".
  OpList producer_ops;
  TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("UsesDefault").Attr("a: int = 17"),
                             producer_ops.add_op()));
  TF_ASSERT_OK(
      FinalizeOpDef(OpDefBuilder("ChangedFromDefault").Attr("a: int = 17"),
                    producer_ops.add_op()));
  OpListOpRegistry producer(&producer_ops);

  // Produced graph: a library function with one node per op, plus a node that
  // calls the function.
  const FunctionDef produced_fn = FunctionDefHelper::Create(
      "my_func", {}, {}, {},
      {{{"x"}, "UsesDefault", {}, {{"a", 17}}},
       {{"y"}, "ChangedFromDefault", {}, {{"a", 99}}}},
      {});
  GraphDef actual;
  *actual.mutable_library()->add_function() = produced_fn;

  // Registry holding only the function's signature, so that NodeDefBuilder can
  // resolve "my_func".
  OpList function_ops;
  *function_ops.add_op() = actual.library().function(0).signature();
  OpListOpRegistry function_registry(&function_ops);
  TF_ASSERT_OK(NodeDefBuilder("call_func", "my_func", &function_registry)
                   .Finalize(actual.add_node()));

  std::set<std::pair<string, string>> removed;
  TF_ASSERT_OK(
      RemoveNewDefaultAttrsFromGraphDef(&actual, consumer, producer, &removed));

  // Expected graph: identical, except that node "x" no longer sets "a".
  const FunctionDef expected_fn = FunctionDefHelper::Create(
      "my_func", {}, {}, {},
      {{{"x"}, "UsesDefault", {}, {}},
       {{"y"}, "ChangedFromDefault", {}, {{"a", 99}}}},
      {});
  GraphDef expected;
  *expected.mutable_library()->add_function() = expected_fn;
  TF_ASSERT_OK(NodeDefBuilder("call_func", "my_func", &function_registry)
                   .Finalize(expected.add_node()));

  TF_EXPECT_GRAPH_EQ(expected, actual);
  // The function libraries are additionally compared through their debug
  // strings.
  EXPECT_EQ(expected.library().DebugString(), actual.library().DebugString());

  const std::set<std::pair<string, string>> expected_removed = {
      {"UsesDefault", "a"}};
  EXPECT_EQ(expected_removed, removed);
}
237 
// An attr whose value equals the op's declared default is removed from the
// node by StripDefaultAttributes().
TEST(StripDefaultAttributesTest, DefaultStripped) {
  OpList ops;
  OpDef* const op_with_default = ops.add_op();
  TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("OpName1").Attr("a: int = 12"),
                             op_with_default));
  OpListOpRegistry op_registry(&ops);

  GraphDef graph;
  // Finalizing without setting "a" makes the builder fill in the default.
  TF_ASSERT_OK(
      NodeDefBuilder("op1", "OpName1", &op_registry).Finalize(graph.add_node()));
  {
    const NodeDef& built = graph.node(0);
    ASSERT_EQ(1, built.attr_size());
    ASSERT_EQ(12, built.attr().at("a").i());
  }

  StripDefaultAttributes(op_registry, graph.mutable_node());

  // The node itself survives; only its defaulted attr is gone.
  ASSERT_EQ(1, graph.node_size());
  ASSERT_EQ(0, graph.node(0).attr_size());
}
255 
// An attr set to a value other than the op's declared default is left alone by
// StripDefaultAttributes().
TEST(StripDefaultAttributesTest, NonDefaultNotStripped) {
  OpList ops;
  OpDef* const op_with_default = ops.add_op();
  TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("OpName1").Attr("a: int = 12"),
                             op_with_default));
  OpListOpRegistry op_registry(&ops);

  GraphDef actual;
  TF_ASSERT_OK(NodeDefBuilder("op1", "OpName1", &op_registry)
                   .Attr("a", 9)
                   .Finalize(actual.add_node()));

  // Snapshot taken before stripping; the graph must still equal it afterwards.
  const GraphDef expected = actual;
  StripDefaultAttributes(op_registry, actual.mutable_node());
  TF_EXPECT_GRAPH_EQ(expected, actual);
}
271 
// StrippedOpListForGraph() and OpsUsedByGraph() report exactly the ops a graph
// uses, whether the ops appear directly as graph nodes or inside a library
// function that the graph calls, and regardless of the order of use.
TEST(StrippedOpListForGraphTest, FlatTest) {
  // Register four ops, A through D, each with a summary and a description;
  // only B is marked commutative.
  OpList all_ops;
  for (const string& op_name : {"A", "B", "C", "D"}) {
    OpDef* def = all_ops.add_op();
    def->set_name(op_name);
    def->set_summary("summary");
    def->set_description("description");
    def->set_is_commutative(op_name == "B");
  }

  // Each row uses only B and C (one of them once, the other twice) in a
  // different order. The result should be independent of the ordering.
  const string graph_ops[4][3] = {
      {"C", "B", "B"}, {"B", "C", "B"}, {"B", "B", "C"}, {"C", "C", "B"}};
  for (const bool use_function : {false, true}) {
    for (const auto& ops_in_order : graph_ops) {
      GraphDef graph;
      if (!use_function) {
        // Ops appear directly as graph nodes named "name0", "name1", ...
        for (const string& op_name : ops_in_order) {
          const string node_name = strings::StrCat("name", graph.node_size());
          NodeDef* node = graph.add_node();
          node->set_name(node_name);
          node->set_op(op_name);
        }
      } else {
        // Ops appear only in the body of function F, which the graph calls.
        FunctionDef* fn = graph.mutable_library()->add_function();
        fn->mutable_signature()->set_name("F");
        for (const string& op_name : ops_in_order) {
          fn->add_node_def()->set_op(op_name);
        }
        graph.add_node()->set_op("F");
      }

      // Strip the op list
      OpList stripped;
      TF_ASSERT_OK(
          StrippedOpListForGraph(graph, OpListOpRegistry(&all_ops), &stripped));

      // Exactly two ops are expected, B first and then C, with summary and
      // description empty and the commutative flag intact.
      ASSERT_EQ(stripped.op_size(), 2);
      for (int i = 0; i < 2; ++i) {
        const bool is_b = (i == 0);
        const char* const expected_name = is_b ? "B" : "C";
        const OpDef& op = stripped.op(i);
        EXPECT_EQ(op.name(), expected_name);
        EXPECT_EQ(op.summary(), "");
        EXPECT_EQ(op.description(), "");
        EXPECT_EQ(op.is_commutative(), is_b);
      }

      // Should get the same result using OpsUsedByGraph().
      std::set<string> used_ops;
      OpsUsedByGraph(graph, &used_ops);
      ASSERT_EQ(std::set<string>({"B", "C"}), used_ops);
    }
  }
}
328 
// Ops reached only through nested (and possibly self-referencing) function
// calls are still found, and the functions themselves are not reported as ops.
TEST(StrippedOpListForGraphTest, NestedFunctionTest) {
  // The only primitive op is A.
  OpList primitive_ops;
  primitive_ops.add_op()->set_name("A");

  for (const bool recursive : {false, true}) {
    GraphDef graph;
    FunctionDefLibrary* library = graph.mutable_library();

    // Function B calls A; function C calls B.
    FunctionDef* fn_b = library->add_function();
    FunctionDef* fn_c = library->add_function();
    fn_b->mutable_signature()->set_name("B");
    fn_c->mutable_signature()->set_name("C");
    fn_b->add_node_def()->set_op("A");
    fn_c->add_node_def()->set_op("B");
    if (recursive) {
      // Additionally make each function call itself.
      fn_b->add_node_def()->set_op("B");
      fn_c->add_node_def()->set_op("C");
    }

    // The graph itself only calls C.
    graph.add_node()->set_op("C");

    // The stripped op list should contain just A.
    OpList stripped;
    TF_ASSERT_OK(StrippedOpListForGraph(
        graph, OpListOpRegistry(&primitive_ops), &stripped));
    ASSERT_EQ(stripped.op_size(), 1);
    ASSERT_EQ(stripped.op(0).name(), "A");

    // Should get the same result using OpsUsedByGraph().
    std::set<string> used_ops;
    OpsUsedByGraph(graph, &used_ops);
    ASSERT_EQ(std::set<string>({"A"}), used_ops);
  }
}
364 
365 }  // namespace
366 }  // namespace tensorflow
367