/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "tensorflow/cc/ops/const_op.h"
17 #include "tensorflow/cc/ops/image_ops.h"
18 #include "tensorflow/cc/ops/nn_ops.h"
19 #include "tensorflow/cc/ops/sendrecv_ops.h"
20 #include "tensorflow/cc/ops/standard_ops.h"
21 #include "tensorflow/core/framework/tensor_testutil.h"
22 #include "tensorflow/core/lib/core/status_test_util.h"
23 #include "tensorflow/core/platform/test.h"
24 #include "tensorflow/core/platform/test_benchmark.h"
25 #include "tensorflow/core/public/session.h"
26 #include "tensorflow/tools/graph_transforms/transform_utils.h"
27
28 namespace tensorflow {
29 namespace graph_transforms {
30
31 // Declare here, so we don't need a public header.
32 Status RemoveNodes(const GraphDef& input_graph_def,
33 const TransformFuncContext& context,
34 GraphDef* output_graph_def);
35
36 class RemoveNodesTest : public ::testing::Test {
37 protected:
TestRemoveNodes()38 void TestRemoveNodes() {
39 GraphDef graph_def;
40
41 NodeDef* add_node1 = graph_def.add_node();
42 add_node1->set_name("add_node1");
43 add_node1->set_op("Add");
44 add_node1->add_input("add_node2");
45 add_node1->add_input("add_node3");
46
47 NodeDef* add_node2 = graph_def.add_node();
48 add_node2->set_name("add_node2");
49 add_node2->set_op("Add");
50 add_node2->add_input("identity_node1");
51 add_node2->add_input("identity_node2");
52
53 NodeDef* add_node3 = graph_def.add_node();
54 add_node3->set_name("add_node3");
55 add_node3->set_op("Add");
56 add_node3->add_input("identity_node1");
57 add_node3->add_input("const_node3");
58
59 NodeDef* identity_node1 = graph_def.add_node();
60 identity_node1->set_name("identity_node1");
61 identity_node1->set_op("Identity");
62 identity_node1->add_input("const_node1");
63
64 NodeDef* identity_node2 = graph_def.add_node();
65 identity_node2->set_name("identity_node2");
66 identity_node2->set_op("Identity");
67 identity_node2->add_input("const_node2");
68
69 NodeDef* identity_node3 = graph_def.add_node();
70 identity_node3->set_name("identity_node3");
71 identity_node3->set_op("Identity");
72 identity_node3->add_input("const_node3");
73
74 NodeDef* const_node1 = graph_def.add_node();
75 const_node1->set_name("const_node1");
76 const_node1->set_op("Const");
77
78 NodeDef* const_node2 = graph_def.add_node();
79 const_node2->set_name("const_node2");
80 const_node2->set_op("Const");
81
82 NodeDef* const_node3 = graph_def.add_node();
83 const_node3->set_name("const_node3");
84 const_node3->set_op("Const");
85
86 NodeDef* add_node4 = graph_def.add_node();
87 add_node4->set_name("add_node4");
88 add_node4->set_op("Add");
89 add_node4->add_input("add_node2");
90 add_node4->add_input("add_node3");
91
92 GraphDef result;
93 TransformFuncContext context;
94 context.input_names = {};
95 context.output_names = {"add_node1"};
96 context.params.insert(
97 std::pair<string, std::vector<string>>({"op", {string("Identity")}}));
98 TF_ASSERT_OK(RemoveNodes(graph_def, context, &result));
99
100 std::map<string, const NodeDef*> node_lookup;
101 MapNamesToNodes(result, &node_lookup);
102 EXPECT_EQ(1, node_lookup.count("add_node1"));
103 EXPECT_EQ("add_node2", node_lookup.at("add_node1")->input(0));
104 EXPECT_EQ("add_node3", node_lookup.at("add_node1")->input(1));
105 EXPECT_EQ(1, node_lookup.count("add_node2"));
106 EXPECT_EQ("const_node1", node_lookup.at("add_node2")->input(0));
107 EXPECT_EQ("const_node2", node_lookup.at("add_node2")->input(1));
108 EXPECT_EQ(1, node_lookup.count("add_node3"));
109 EXPECT_EQ("const_node1", node_lookup.at("add_node3")->input(0));
110 EXPECT_EQ("const_node3", node_lookup.at("add_node3")->input(1));
111 EXPECT_EQ(1, node_lookup.count("add_node4"));
112 EXPECT_EQ("add_node2", node_lookup.at("add_node4")->input(0));
113 EXPECT_EQ("add_node3", node_lookup.at("add_node4")->input(1));
114 EXPECT_EQ(0, node_lookup.count("identity_node1"));
115 EXPECT_EQ(0, node_lookup.count("identity_node2"));
116 EXPECT_EQ(0, node_lookup.count("identity_node3"));
117 EXPECT_EQ(1, node_lookup.count("const_node1"));
118 EXPECT_EQ("Const", node_lookup.at("const_node1")->op());
119 EXPECT_EQ(1, node_lookup.count("const_node2"));
120 EXPECT_EQ("Const", node_lookup.at("const_node2")->op());
121 EXPECT_EQ(1, node_lookup.count("const_node3"));
122 EXPECT_EQ("Const", node_lookup.at("const_node3")->op());
123 }
124
TestRemoveOutputNodes()125 void TestRemoveOutputNodes() {
126 GraphDef graph_def;
127
128 NodeDef* const_node1 = graph_def.add_node();
129 const_node1->set_name("const_node1");
130 const_node1->set_op("Const");
131
132 NodeDef* const_node2 = graph_def.add_node();
133 const_node2->set_name("const_node2");
134 const_node2->set_op("Const");
135
136 NodeDef* add_node = graph_def.add_node();
137 add_node->set_name("add_node");
138 add_node->set_op("Add");
139 add_node->add_input("const_node1");
140 add_node->add_input("const_node2");
141
142 NodeDef* identity_node = graph_def.add_node();
143 identity_node->set_name("identity_node");
144 identity_node->set_op("Identity");
145 identity_node->add_input("add_node");
146
147 GraphDef result;
148 TransformFuncContext context;
149 context.input_names = {};
150 context.output_names = {"identity_node"};
151 context.params.insert(
152 std::pair<string, std::vector<string>>({"op", {string("Identity")}}));
153 TF_ASSERT_OK(RemoveNodes(graph_def, context, &result));
154
155 std::map<string, const NodeDef*> node_lookup;
156 MapNamesToNodes(result, &node_lookup);
157 EXPECT_EQ(1, node_lookup.count("add_node"));
158 EXPECT_EQ("const_node1", node_lookup.at("add_node")->input(0));
159 EXPECT_EQ("const_node2", node_lookup.at("add_node")->input(1));
160 EXPECT_EQ(1, node_lookup.count("identity_node"));
161 EXPECT_EQ("add_node", node_lookup.at("identity_node")->input(0));
162 }
163
TestRemoveChainedNodes()164 void TestRemoveChainedNodes() {
165 GraphDef graph_def;
166
167 NodeDef* const_node1 = graph_def.add_node();
168 const_node1->set_name("const_node1");
169 const_node1->set_op("Const");
170
171 NodeDef* identity_node1 = graph_def.add_node();
172 identity_node1->set_name("identity_node1");
173 identity_node1->set_op("Identity");
174 identity_node1->add_input("const_node1");
175
176 NodeDef* identity_node2 = graph_def.add_node();
177 identity_node2->set_name("identity_node2");
178 identity_node2->set_op("Identity");
179 identity_node2->add_input("identity_node1");
180
181 NodeDef* identity_node3 = graph_def.add_node();
182 identity_node3->set_name("identity_node3");
183 identity_node3->set_op("Identity");
184 identity_node3->add_input("identity_node2");
185
186 NodeDef* const_node2 = graph_def.add_node();
187 const_node2->set_name("const_node2");
188 const_node2->set_op("Const");
189
190 NodeDef* add_node = graph_def.add_node();
191 add_node->set_name("add_node");
192 add_node->set_op("Add");
193 add_node->add_input("identity_node3");
194 add_node->add_input("const_node2");
195
196 GraphDef result;
197 TransformFuncContext context;
198 context.input_names = {};
199 context.output_names = {"identity_node"};
200 context.params.insert(
201 std::pair<string, std::vector<string>>({"op", {string("Identity")}}));
202 TF_ASSERT_OK(RemoveNodes(graph_def, context, &result));
203
204 std::map<string, const NodeDef*> node_lookup;
205 MapNamesToNodes(result, &node_lookup);
206 EXPECT_EQ(1, node_lookup.count("add_node"));
207 EXPECT_EQ("const_node1", node_lookup.at("add_node")->input(0));
208 EXPECT_EQ("const_node2", node_lookup.at("add_node")->input(1));
209 EXPECT_EQ(0, node_lookup.count("identity_node1"));
210 EXPECT_EQ(0, node_lookup.count("identity_node2"));
211 EXPECT_EQ(0, node_lookup.count("identity_node3"));
212 }
213
TestRemoveMultipleInputs()214 void TestRemoveMultipleInputs() {
215 GraphDef graph_def;
216
217 NodeDef* const_node1 = graph_def.add_node();
218 const_node1->set_name("const_node1");
219 const_node1->set_op("Const");
220
221 NodeDef* const_node2 = graph_def.add_node();
222 const_node2->set_name("const_node2");
223 const_node2->set_op("Const");
224
225 NodeDef* const_node3 = graph_def.add_node();
226 const_node3->set_name("const_node3");
227 const_node3->set_op("Const");
228
229 NodeDef* const_node4 = graph_def.add_node();
230 const_node4->set_name("const_node4");
231 const_node4->set_op("Const");
232
233 NodeDef* fake_quant_node = graph_def.add_node();
234 fake_quant_node->set_name("fake_quant_node");
235 fake_quant_node->set_op("FakeQuantWithMinMaxVars");
236 fake_quant_node->add_input("const_node1");
237 fake_quant_node->add_input("const_node2");
238 fake_quant_node->add_input("const_node3");
239
240 NodeDef* add_node = graph_def.add_node();
241 add_node->set_name("add_node");
242 add_node->set_op("Add");
243 add_node->add_input("fake_quant_node");
244 add_node->add_input("const_node4");
245
246 GraphDef result;
247 TransformFuncContext context;
248 context.input_names = {};
249 context.output_names = {"add_node"};
250 context.params.insert(std::pair<string, std::vector<string>>(
251 {"op", {string("FakeQuantWithMinMaxVars")}}));
252 context.params.insert(
253 std::pair<string, std::vector<string>>({"max_inputs", {string("3")}}));
254 TF_ASSERT_OK(RemoveNodes(graph_def, context, &result));
255
256 std::map<string, const NodeDef*> node_lookup;
257 MapNamesToNodes(result, &node_lookup);
258 ASSERT_EQ(1, node_lookup.count("const_node1"));
259 ASSERT_EQ(1, node_lookup.count("const_node4"));
260 ASSERT_EQ(0, node_lookup.count("fake_quant_node"));
261 ASSERT_EQ(1, node_lookup.count("add_node"));
262 EXPECT_EQ("const_node1", node_lookup.at("add_node")->input(0));
263 EXPECT_EQ("const_node4", node_lookup.at("add_node")->input(1));
264 }
265 };
266
// Identity nodes between Consts and Adds are stripped and consumers rewired.
TEST_F(RemoveNodesTest, TestRemoveNodes) {
  TestRemoveNodes();
}
268
// An Identity node named as a graph output must survive the transform.
TEST_F(RemoveNodesTest, TestRemoveOutputNodes) {
  TestRemoveOutputNodes();
}
270
// A chain of consecutive Identity nodes is removed in its entirety.
TEST_F(RemoveNodesTest, TestRemoveChainedNodes) {
  TestRemoveChainedNodes();
}
272
// A node with several inputs is removable once "max_inputs" allows it.
TEST_F(RemoveNodesTest, TestRemoveMultipleInputs) {
  this->TestRemoveMultipleInputs();
}
276
277 } // namespace graph_transforms
278 } // namespace tensorflow
279