1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/cc/ops/const_op.h"
17 #include "tensorflow/cc/ops/image_ops.h"
18 #include "tensorflow/cc/ops/nn_ops.h"
19 #include "tensorflow/cc/ops/sendrecv_ops.h"
20 #include "tensorflow/cc/ops/standard_ops.h"
21 #include "tensorflow/core/framework/tensor_shape.pb.h"
22 #include "tensorflow/core/framework/tensor_testutil.h"
23 #include "tensorflow/core/lib/core/status_test_util.h"
24 #include "tensorflow/core/platform/test.h"
25 #include "tensorflow/core/platform/test_benchmark.h"
26 #include "tensorflow/core/public/session.h"
27 #include "tensorflow/tools/graph_transforms/transform_utils.h"
28
29 namespace tensorflow {
30 namespace graph_transforms {
31
// Declare here, so we don't need a public header.
//
// Forward declaration of the transform under test; it is not defined in this
// file. As exercised by the tests below:
//   input_graph_def:  graph to process (never modified, passed by const ref).
//   context:          brace-initialized in the tests as
//                     {input names, output names[, string parameters]}.
//   output_graph_def: receives the transformed graph.
// The tests expect nodes that do not contribute to the requested outputs to
// be absent from the result, and nodes named as inputs to come back with op
// "Placeholder". Returns a Status that the tests require to be OK.
Status StripUnusedNodes(const GraphDef& input_graph_def,
                        const TransformFuncContext& context,
                        GraphDef* output_graph_def);
36
37 class StripUnusedNodesTest : public ::testing::Test {
38 protected:
TestSimpleAdd()39 void TestSimpleAdd() {
40 GraphDef graph_def;
41 NodeDef* add_node = graph_def.add_node();
42 add_node->set_name("add_node");
43 add_node->set_op("Add");
44 add_node->add_input("a_node");
45 add_node->add_input("b_node");
46
47 NodeDef* a_node = graph_def.add_node();
48 a_node->set_name("a_node");
49 a_node->set_op("Const");
50
51 NodeDef* b_node = graph_def.add_node();
52 b_node->set_name("b_node");
53 b_node->set_op("Const");
54
55 NodeDef* c_node = graph_def.add_node();
56 c_node->set_name("c_node");
57 c_node->set_op("Const");
58
59 GraphDef result;
60 TF_ASSERT_OK(StripUnusedNodes(graph_def, {{}, {"add_node"}}, &result));
61
62 std::map<string, const NodeDef*> node_lookup;
63 MapNamesToNodes(result, &node_lookup);
64 EXPECT_EQ(1, node_lookup.count("add_node"));
65 EXPECT_EQ(1, node_lookup.count("a_node"));
66 EXPECT_EQ(1, node_lookup.count("b_node"));
67 EXPECT_EQ(0, node_lookup.count("c_node"));
68 }
69
TestCommonAncestor()70 void TestCommonAncestor() {
71 GraphDef graph_def;
72
73 NodeDef* add_node1 = graph_def.add_node();
74 add_node1->set_name("add_node1");
75 add_node1->set_op("Add");
76 add_node1->add_input("add_node2");
77 add_node1->add_input("add_node3");
78
79 NodeDef* add_node2 = graph_def.add_node();
80 add_node2->set_name("add_node2");
81 add_node2->set_op("Add");
82 add_node2->add_input("const_node1");
83 add_node2->add_input("const_node2");
84
85 NodeDef* add_node3 = graph_def.add_node();
86 add_node3->set_name("add_node3");
87 add_node3->set_op("Add");
88 add_node3->add_input("const_node1");
89 add_node3->add_input("const_node3");
90
91 NodeDef* const_node1 = graph_def.add_node();
92 const_node1->set_name("const_node1");
93 const_node1->set_op("Const");
94
95 NodeDef* const_node2 = graph_def.add_node();
96 const_node2->set_name("const_node2");
97 const_node2->set_op("Const");
98
99 NodeDef* const_node3 = graph_def.add_node();
100 const_node3->set_name("const_node3");
101 const_node3->set_op("Const");
102
103 NodeDef* dangling_input = graph_def.add_node();
104 dangling_input->set_name("dangling_input");
105 dangling_input->set_op("Const");
106
107 NodeDef* add_node4 = graph_def.add_node();
108 add_node4->set_name("add_node4");
109 add_node4->set_op("Add");
110 add_node4->add_input("add_node2");
111 add_node4->add_input("add_node3");
112
113 GraphDef result;
114 TF_ASSERT_OK(StripUnusedNodes(
115 graph_def, {{"dangling_input"}, {"add_node1"}}, &result));
116
117 std::map<string, const NodeDef*> node_lookup;
118 MapNamesToNodes(result, &node_lookup);
119 EXPECT_EQ(1, node_lookup.count("add_node1"));
120 EXPECT_EQ(1, node_lookup.count("add_node2"));
121 EXPECT_EQ(1, node_lookup.count("add_node3"));
122 EXPECT_EQ(0, node_lookup.count("add_node4"));
123 EXPECT_EQ(1, node_lookup.count("const_node1"));
124 EXPECT_EQ(1, node_lookup.count("const_node2"));
125 EXPECT_EQ(1, node_lookup.count("const_node3"));
126 EXPECT_EQ(0, node_lookup.count("const_node4"));
127 EXPECT_EQ(1, node_lookup.count("dangling_input"));
128 }
129
TestSimplePlaceholder()130 void TestSimplePlaceholder() {
131 GraphDef graph_def;
132 NodeDef* add_node = graph_def.add_node();
133 add_node->set_name("add_node");
134 add_node->set_op("Add");
135 add_node->add_input("mul_node");
136 add_node->add_input("a_node");
137
138 NodeDef* mul_node = graph_def.add_node();
139 mul_node->set_name("mul_node");
140 mul_node->set_op("Mul");
141 mul_node->add_input("b_node");
142 mul_node->add_input("c_node");
143
144 NodeDef* a_node = graph_def.add_node();
145 a_node->set_name("a_node");
146 a_node->set_op("Const");
147
148 NodeDef* b_node = graph_def.add_node();
149 b_node->set_name("b_node");
150 b_node->set_op("Const");
151
152 NodeDef* c_node = graph_def.add_node();
153 c_node->set_name("c_node");
154 c_node->set_op("Const");
155
156 GraphDef result;
157 TF_ASSERT_OK(
158 StripUnusedNodes(graph_def, {{"mul_node"}, {"add_node"}}, &result));
159
160 std::map<string, const NodeDef*> node_lookup;
161 MapNamesToNodes(result, &node_lookup);
162 EXPECT_EQ(1, node_lookup.count("add_node"));
163 EXPECT_EQ(1, node_lookup.count("mul_node"));
164 EXPECT_EQ("Placeholder", node_lookup["mul_node"]->op());
165 EXPECT_EQ(DT_FLOAT, node_lookup["mul_node"]->attr().at("dtype").type());
166 EXPECT_EQ(TensorShape({}),
167 TensorShape(node_lookup["mul_node"]->attr().at("shape").shape()));
168 EXPECT_EQ(1, node_lookup.count("a_node"));
169 EXPECT_EQ(0, node_lookup.count("b_node"));
170 EXPECT_EQ(0, node_lookup.count("c_node"));
171 }
172
TestPlaceholderDefaultArgs()173 void TestPlaceholderDefaultArgs() {
174 GraphDef graph_def;
175 NodeDef* add_node = graph_def.add_node();
176 add_node->set_name("add_node");
177 add_node->set_op("Add");
178 add_node->add_input("mul_node");
179 add_node->add_input("a_node");
180
181 NodeDef* mul_node = graph_def.add_node();
182 mul_node->set_name("mul_node");
183 mul_node->set_op("Mul");
184 mul_node->add_input("b_node");
185 mul_node->add_input("c_node");
186
187 NodeDef* a_node = graph_def.add_node();
188 a_node->set_name("a_node");
189 a_node->set_op("Const");
190
191 NodeDef* b_node = graph_def.add_node();
192 b_node->set_name("b_node");
193 b_node->set_op("Const");
194
195 NodeDef* c_node = graph_def.add_node();
196 c_node->set_name("c_node");
197 c_node->set_op("Const");
198
199 GraphDef result;
200 TF_ASSERT_OK(StripUnusedNodes(graph_def,
201 {{"mul_node"},
202 {"add_node"},
203 {{"type", {"int32"}}, {"shape", {"1,2,3"}}}},
204 &result));
205
206 std::map<string, const NodeDef*> node_lookup;
207 MapNamesToNodes(result, &node_lookup);
208 EXPECT_EQ(1, node_lookup.count("add_node"));
209 EXPECT_EQ(1, node_lookup.count("mul_node"));
210 EXPECT_EQ("Placeholder", node_lookup["mul_node"]->op());
211 EXPECT_EQ(DT_INT32, node_lookup["mul_node"]->attr().at("dtype").type());
212 EXPECT_EQ(TensorShape({1, 2, 3}),
213 TensorShape(node_lookup["mul_node"]->attr().at("shape").shape()));
214 EXPECT_EQ(1, node_lookup.count("a_node"));
215 EXPECT_EQ(0, node_lookup.count("b_node"));
216 EXPECT_EQ(0, node_lookup.count("c_node"));
217 }
218
TestPlaceholderNamedArgs()219 void TestPlaceholderNamedArgs() {
220 GraphDef graph_def;
221 NodeDef* add_node = graph_def.add_node();
222 add_node->set_name("add_node");
223 add_node->set_op("Add");
224 add_node->add_input("mul_node");
225 add_node->add_input("a_node");
226
227 NodeDef* mul_node = graph_def.add_node();
228 mul_node->set_name("mul_node");
229 mul_node->set_op("Mul");
230 mul_node->add_input("b_node");
231 mul_node->add_input("c_node");
232
233 NodeDef* a_node = graph_def.add_node();
234 a_node->set_name("a_node");
235 a_node->set_op("Const");
236
237 NodeDef* b_node = graph_def.add_node();
238 b_node->set_name("b_node");
239 b_node->set_op("Const");
240
241 NodeDef* c_node = graph_def.add_node();
242 c_node->set_name("c_node");
243 c_node->set_op("Const");
244
245 GraphDef result;
246 TF_ASSERT_OK(StripUnusedNodes(graph_def,
247 {{"mul_node", "a_node"},
248 {"add_node"},
249 {{"name", {"a_node", "mul_node"}},
250 {"type_for_name", {"int64", "quint8"}},
251 {"shape_for_name", {"1,2", "1, 2, 3"}}}},
252 &result));
253
254 std::map<string, const NodeDef*> node_lookup;
255 MapNamesToNodes(result, &node_lookup);
256 EXPECT_EQ(1, node_lookup.count("add_node"));
257 EXPECT_EQ(1, node_lookup.count("mul_node"));
258 EXPECT_EQ("Placeholder", node_lookup["mul_node"]->op());
259 EXPECT_EQ(DT_QUINT8, node_lookup["mul_node"]->attr().at("dtype").type());
260 EXPECT_EQ(TensorShape({1, 2, 3}),
261 TensorShape(node_lookup["mul_node"]->attr().at("shape").shape()));
262 EXPECT_EQ(1, node_lookup.count("a_node"));
263 EXPECT_EQ("Placeholder", node_lookup["a_node"]->op());
264 EXPECT_EQ(DT_INT64, node_lookup["a_node"]->attr().at("dtype").type());
265 EXPECT_EQ(TensorShape({1, 2}),
266 TensorShape(node_lookup["a_node"]->attr().at("shape").shape()));
267 EXPECT_EQ(0, node_lookup.count("b_node"));
268 EXPECT_EQ(0, node_lookup.count("c_node"));
269 }
270 };
271
// Registers the fixture's TestSimpleAdd() body as a gtest case.
TEST_F(StripUnusedNodesTest, TestSimpleAdd) {
  this->TestSimpleAdd();
}
273
// Registers the fixture's TestCommonAncestor() body as a gtest case.
TEST_F(StripUnusedNodesTest, TestCommonAncestor) {
  this->TestCommonAncestor();
}
275
// Registers the fixture's TestSimplePlaceholder() body as a gtest case.
TEST_F(StripUnusedNodesTest, TestSimplePlaceholder) {
  this->TestSimplePlaceholder();
}
277
// Registers the fixture's TestPlaceholderDefaultArgs() body as a gtest case.
TEST_F(StripUnusedNodesTest, TestPlaceholderDefaultArgs) {
  this->TestPlaceholderDefaultArgs();
}
281
// Registers the fixture's TestPlaceholderNamedArgs() body as a gtest case.
TEST_F(StripUnusedNodesTest, TestPlaceholderNamedArgs) {
  this->TestPlaceholderNamedArgs();
}
285
286 } // namespace graph_transforms
287 } // namespace tensorflow
288