• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/cc/ops/const_op.h"
17 #include "tensorflow/cc/ops/image_ops.h"
18 #include "tensorflow/cc/ops/nn_ops.h"
19 #include "tensorflow/cc/ops/sendrecv_ops.h"
20 #include "tensorflow/cc/ops/standard_ops.h"
21 #include "tensorflow/core/framework/tensor_testutil.h"
22 #include "tensorflow/core/lib/core/status_test_util.h"
23 #include "tensorflow/core/platform/test.h"
24 #include "tensorflow/core/platform/test_benchmark.h"
25 #include "tensorflow/core/public/session.h"
26 #include "tensorflow/tools/graph_transforms/transform_utils.h"
27 
28 namespace tensorflow {
29 namespace graph_transforms {
30 
31 // Declare here, so we don't need a public header.
32 Status RemoveNodes(const GraphDef& input_graph_def,
33                    const TransformFuncContext& context,
34                    GraphDef* output_graph_def);
35 
36 class RemoveNodesTest : public ::testing::Test {
37  protected:
TestRemoveNodes()38   void TestRemoveNodes() {
39     GraphDef graph_def;
40 
41     NodeDef* add_node1 = graph_def.add_node();
42     add_node1->set_name("add_node1");
43     add_node1->set_op("Add");
44     add_node1->add_input("add_node2");
45     add_node1->add_input("add_node3");
46 
47     NodeDef* add_node2 = graph_def.add_node();
48     add_node2->set_name("add_node2");
49     add_node2->set_op("Add");
50     add_node2->add_input("identity_node1");
51     add_node2->add_input("identity_node2");
52 
53     NodeDef* add_node3 = graph_def.add_node();
54     add_node3->set_name("add_node3");
55     add_node3->set_op("Add");
56     add_node3->add_input("identity_node1");
57     add_node3->add_input("const_node3");
58 
59     NodeDef* identity_node1 = graph_def.add_node();
60     identity_node1->set_name("identity_node1");
61     identity_node1->set_op("Identity");
62     identity_node1->add_input("const_node1");
63 
64     NodeDef* identity_node2 = graph_def.add_node();
65     identity_node2->set_name("identity_node2");
66     identity_node2->set_op("Identity");
67     identity_node2->add_input("const_node2");
68 
69     NodeDef* identity_node3 = graph_def.add_node();
70     identity_node3->set_name("identity_node3");
71     identity_node3->set_op("Identity");
72     identity_node3->add_input("const_node3");
73 
74     NodeDef* const_node1 = graph_def.add_node();
75     const_node1->set_name("const_node1");
76     const_node1->set_op("Const");
77 
78     NodeDef* const_node2 = graph_def.add_node();
79     const_node2->set_name("const_node2");
80     const_node2->set_op("Const");
81 
82     NodeDef* const_node3 = graph_def.add_node();
83     const_node3->set_name("const_node3");
84     const_node3->set_op("Const");
85 
86     NodeDef* add_node4 = graph_def.add_node();
87     add_node4->set_name("add_node4");
88     add_node4->set_op("Add");
89     add_node4->add_input("add_node2");
90     add_node4->add_input("add_node3");
91 
92     GraphDef result;
93     TransformFuncContext context;
94     context.input_names = {};
95     context.output_names = {"add_node1"};
96     context.params.insert(
97         std::pair<string, std::vector<string>>({"op", {string("Identity")}}));
98     TF_ASSERT_OK(RemoveNodes(graph_def, context, &result));
99 
100     std::map<string, const NodeDef*> node_lookup;
101     MapNamesToNodes(result, &node_lookup);
102     EXPECT_EQ(1, node_lookup.count("add_node1"));
103     EXPECT_EQ("add_node2", node_lookup.at("add_node1")->input(0));
104     EXPECT_EQ("add_node3", node_lookup.at("add_node1")->input(1));
105     EXPECT_EQ(1, node_lookup.count("add_node2"));
106     EXPECT_EQ("const_node1", node_lookup.at("add_node2")->input(0));
107     EXPECT_EQ("const_node2", node_lookup.at("add_node2")->input(1));
108     EXPECT_EQ(1, node_lookup.count("add_node3"));
109     EXPECT_EQ("const_node1", node_lookup.at("add_node3")->input(0));
110     EXPECT_EQ("const_node3", node_lookup.at("add_node3")->input(1));
111     EXPECT_EQ(1, node_lookup.count("add_node4"));
112     EXPECT_EQ("add_node2", node_lookup.at("add_node4")->input(0));
113     EXPECT_EQ("add_node3", node_lookup.at("add_node4")->input(1));
114     EXPECT_EQ(0, node_lookup.count("identity_node1"));
115     EXPECT_EQ(0, node_lookup.count("identity_node2"));
116     EXPECT_EQ(0, node_lookup.count("identity_node3"));
117     EXPECT_EQ(1, node_lookup.count("const_node1"));
118     EXPECT_EQ("Const", node_lookup.at("const_node1")->op());
119     EXPECT_EQ(1, node_lookup.count("const_node2"));
120     EXPECT_EQ("Const", node_lookup.at("const_node2")->op());
121     EXPECT_EQ(1, node_lookup.count("const_node3"));
122     EXPECT_EQ("Const", node_lookup.at("const_node3")->op());
123   }
124 
TestRemoveOutputNodes()125   void TestRemoveOutputNodes() {
126     GraphDef graph_def;
127 
128     NodeDef* const_node1 = graph_def.add_node();
129     const_node1->set_name("const_node1");
130     const_node1->set_op("Const");
131 
132     NodeDef* const_node2 = graph_def.add_node();
133     const_node2->set_name("const_node2");
134     const_node2->set_op("Const");
135 
136     NodeDef* add_node = graph_def.add_node();
137     add_node->set_name("add_node");
138     add_node->set_op("Add");
139     add_node->add_input("const_node1");
140     add_node->add_input("const_node2");
141 
142     NodeDef* identity_node = graph_def.add_node();
143     identity_node->set_name("identity_node");
144     identity_node->set_op("Identity");
145     identity_node->add_input("add_node");
146 
147     GraphDef result;
148     TransformFuncContext context;
149     context.input_names = {};
150     context.output_names = {"identity_node"};
151     context.params.insert(
152         std::pair<string, std::vector<string>>({"op", {string("Identity")}}));
153     TF_ASSERT_OK(RemoveNodes(graph_def, context, &result));
154 
155     std::map<string, const NodeDef*> node_lookup;
156     MapNamesToNodes(result, &node_lookup);
157     EXPECT_EQ(1, node_lookup.count("add_node"));
158     EXPECT_EQ("const_node1", node_lookup.at("add_node")->input(0));
159     EXPECT_EQ("const_node2", node_lookup.at("add_node")->input(1));
160     EXPECT_EQ(1, node_lookup.count("identity_node"));
161     EXPECT_EQ("add_node", node_lookup.at("identity_node")->input(0));
162   }
163 
TestRemoveChainedNodes()164   void TestRemoveChainedNodes() {
165     GraphDef graph_def;
166 
167     NodeDef* const_node1 = graph_def.add_node();
168     const_node1->set_name("const_node1");
169     const_node1->set_op("Const");
170 
171     NodeDef* identity_node1 = graph_def.add_node();
172     identity_node1->set_name("identity_node1");
173     identity_node1->set_op("Identity");
174     identity_node1->add_input("const_node1");
175 
176     NodeDef* identity_node2 = graph_def.add_node();
177     identity_node2->set_name("identity_node2");
178     identity_node2->set_op("Identity");
179     identity_node2->add_input("identity_node1");
180 
181     NodeDef* identity_node3 = graph_def.add_node();
182     identity_node3->set_name("identity_node3");
183     identity_node3->set_op("Identity");
184     identity_node3->add_input("identity_node2");
185 
186     NodeDef* const_node2 = graph_def.add_node();
187     const_node2->set_name("const_node2");
188     const_node2->set_op("Const");
189 
190     NodeDef* add_node = graph_def.add_node();
191     add_node->set_name("add_node");
192     add_node->set_op("Add");
193     add_node->add_input("identity_node3");
194     add_node->add_input("const_node2");
195 
196     GraphDef result;
197     TransformFuncContext context;
198     context.input_names = {};
199     context.output_names = {"identity_node"};
200     context.params.insert(
201         std::pair<string, std::vector<string>>({"op", {string("Identity")}}));
202     TF_ASSERT_OK(RemoveNodes(graph_def, context, &result));
203 
204     std::map<string, const NodeDef*> node_lookup;
205     MapNamesToNodes(result, &node_lookup);
206     EXPECT_EQ(1, node_lookup.count("add_node"));
207     EXPECT_EQ("const_node1", node_lookup.at("add_node")->input(0));
208     EXPECT_EQ("const_node2", node_lookup.at("add_node")->input(1));
209     EXPECT_EQ(0, node_lookup.count("identity_node1"));
210     EXPECT_EQ(0, node_lookup.count("identity_node2"));
211     EXPECT_EQ(0, node_lookup.count("identity_node3"));
212   }
213 
TestRemoveMultipleInputs()214   void TestRemoveMultipleInputs() {
215     GraphDef graph_def;
216 
217     NodeDef* const_node1 = graph_def.add_node();
218     const_node1->set_name("const_node1");
219     const_node1->set_op("Const");
220 
221     NodeDef* const_node2 = graph_def.add_node();
222     const_node2->set_name("const_node2");
223     const_node2->set_op("Const");
224 
225     NodeDef* const_node3 = graph_def.add_node();
226     const_node3->set_name("const_node3");
227     const_node3->set_op("Const");
228 
229     NodeDef* const_node4 = graph_def.add_node();
230     const_node4->set_name("const_node4");
231     const_node4->set_op("Const");
232 
233     NodeDef* fake_quant_node = graph_def.add_node();
234     fake_quant_node->set_name("fake_quant_node");
235     fake_quant_node->set_op("FakeQuantWithMinMaxVars");
236     fake_quant_node->add_input("const_node1");
237     fake_quant_node->add_input("const_node2");
238     fake_quant_node->add_input("const_node3");
239 
240     NodeDef* add_node = graph_def.add_node();
241     add_node->set_name("add_node");
242     add_node->set_op("Add");
243     add_node->add_input("fake_quant_node");
244     add_node->add_input("const_node4");
245 
246     GraphDef result;
247     TransformFuncContext context;
248     context.input_names = {};
249     context.output_names = {"add_node"};
250     context.params.insert(std::pair<string, std::vector<string>>(
251         {"op", {string("FakeQuantWithMinMaxVars")}}));
252     context.params.insert(
253         std::pair<string, std::vector<string>>({"max_inputs", {string("3")}}));
254     TF_ASSERT_OK(RemoveNodes(graph_def, context, &result));
255 
256     std::map<string, const NodeDef*> node_lookup;
257     MapNamesToNodes(result, &node_lookup);
258     ASSERT_EQ(1, node_lookup.count("const_node1"));
259     ASSERT_EQ(1, node_lookup.count("const_node4"));
260     ASSERT_EQ(0, node_lookup.count("fake_quant_node"));
261     ASSERT_EQ(1, node_lookup.count("add_node"));
262     EXPECT_EQ("const_node1", node_lookup.at("add_node")->input(0));
263     EXPECT_EQ("const_node4", node_lookup.at("add_node")->input(1));
264   }
265 };
266 
// gtest entry points; each delegates to the matching fixture helper above.

TEST_F(RemoveNodesTest, TestRemoveNodes) {
  TestRemoveNodes();
}

TEST_F(RemoveNodesTest, TestRemoveOutputNodes) {
  TestRemoveOutputNodes();
}

TEST_F(RemoveNodesTest, TestRemoveChainedNodes) {
  TestRemoveChainedNodes();
}

TEST_F(RemoveNodesTest, TestRemoveMultipleInputs) {
  TestRemoveMultipleInputs();
}
276 
277 }  // namespace graph_transforms
278 }  // namespace tensorflow
279