• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/cc/ops/const_op.h"
17 #include "tensorflow/cc/ops/image_ops.h"
18 #include "tensorflow/cc/ops/nn_ops.h"
19 #include "tensorflow/cc/ops/sendrecv_ops.h"
20 #include "tensorflow/cc/ops/standard_ops.h"
21 #include "tensorflow/core/framework/tensor_shape.pb.h"
22 #include "tensorflow/core/framework/tensor_testutil.h"
23 #include "tensorflow/core/lib/core/status_test_util.h"
24 #include "tensorflow/core/platform/test.h"
25 #include "tensorflow/core/platform/test_benchmark.h"
26 #include "tensorflow/core/public/session.h"
27 #include "tensorflow/tools/graph_transforms/transform_utils.h"
28 
29 namespace tensorflow {
30 namespace graph_transforms {
31 
32 // Declare here, so we don't need a public header.
33 Status StripUnusedNodes(const GraphDef& input_graph_def,
34                         const TransformFuncContext& context,
35                         GraphDef* output_graph_def);
36 
37 class StripUnusedNodesTest : public ::testing::Test {
38  protected:
TestSimpleAdd()39   void TestSimpleAdd() {
40     GraphDef graph_def;
41     NodeDef* add_node = graph_def.add_node();
42     add_node->set_name("add_node");
43     add_node->set_op("Add");
44     add_node->add_input("a_node");
45     add_node->add_input("b_node");
46 
47     NodeDef* a_node = graph_def.add_node();
48     a_node->set_name("a_node");
49     a_node->set_op("Const");
50 
51     NodeDef* b_node = graph_def.add_node();
52     b_node->set_name("b_node");
53     b_node->set_op("Const");
54 
55     NodeDef* c_node = graph_def.add_node();
56     c_node->set_name("c_node");
57     c_node->set_op("Const");
58 
59     GraphDef result;
60     TF_ASSERT_OK(StripUnusedNodes(graph_def, {{}, {"add_node"}}, &result));
61 
62     std::map<string, const NodeDef*> node_lookup;
63     MapNamesToNodes(result, &node_lookup);
64     EXPECT_EQ(1, node_lookup.count("add_node"));
65     EXPECT_EQ(1, node_lookup.count("a_node"));
66     EXPECT_EQ(1, node_lookup.count("b_node"));
67     EXPECT_EQ(0, node_lookup.count("c_node"));
68   }
69 
TestCommonAncestor()70   void TestCommonAncestor() {
71     GraphDef graph_def;
72 
73     NodeDef* add_node1 = graph_def.add_node();
74     add_node1->set_name("add_node1");
75     add_node1->set_op("Add");
76     add_node1->add_input("add_node2");
77     add_node1->add_input("add_node3");
78 
79     NodeDef* add_node2 = graph_def.add_node();
80     add_node2->set_name("add_node2");
81     add_node2->set_op("Add");
82     add_node2->add_input("const_node1");
83     add_node2->add_input("const_node2");
84 
85     NodeDef* add_node3 = graph_def.add_node();
86     add_node3->set_name("add_node3");
87     add_node3->set_op("Add");
88     add_node3->add_input("const_node1");
89     add_node3->add_input("const_node3");
90 
91     NodeDef* const_node1 = graph_def.add_node();
92     const_node1->set_name("const_node1");
93     const_node1->set_op("Const");
94 
95     NodeDef* const_node2 = graph_def.add_node();
96     const_node2->set_name("const_node2");
97     const_node2->set_op("Const");
98 
99     NodeDef* const_node3 = graph_def.add_node();
100     const_node3->set_name("const_node3");
101     const_node3->set_op("Const");
102 
103     NodeDef* dangling_input = graph_def.add_node();
104     dangling_input->set_name("dangling_input");
105     dangling_input->set_op("Const");
106 
107     NodeDef* add_node4 = graph_def.add_node();
108     add_node4->set_name("add_node4");
109     add_node4->set_op("Add");
110     add_node4->add_input("add_node2");
111     add_node4->add_input("add_node3");
112 
113     GraphDef result;
114     TF_ASSERT_OK(StripUnusedNodes(
115         graph_def, {{"dangling_input"}, {"add_node1"}}, &result));
116 
117     std::map<string, const NodeDef*> node_lookup;
118     MapNamesToNodes(result, &node_lookup);
119     EXPECT_EQ(1, node_lookup.count("add_node1"));
120     EXPECT_EQ(1, node_lookup.count("add_node2"));
121     EXPECT_EQ(1, node_lookup.count("add_node3"));
122     EXPECT_EQ(0, node_lookup.count("add_node4"));
123     EXPECT_EQ(1, node_lookup.count("const_node1"));
124     EXPECT_EQ(1, node_lookup.count("const_node2"));
125     EXPECT_EQ(1, node_lookup.count("const_node3"));
126     EXPECT_EQ(0, node_lookup.count("const_node4"));
127     EXPECT_EQ(1, node_lookup.count("dangling_input"));
128   }
129 
TestSimplePlaceholder()130   void TestSimplePlaceholder() {
131     GraphDef graph_def;
132     NodeDef* add_node = graph_def.add_node();
133     add_node->set_name("add_node");
134     add_node->set_op("Add");
135     add_node->add_input("mul_node");
136     add_node->add_input("a_node");
137 
138     NodeDef* mul_node = graph_def.add_node();
139     mul_node->set_name("mul_node");
140     mul_node->set_op("Mul");
141     mul_node->add_input("b_node");
142     mul_node->add_input("c_node");
143 
144     NodeDef* a_node = graph_def.add_node();
145     a_node->set_name("a_node");
146     a_node->set_op("Const");
147 
148     NodeDef* b_node = graph_def.add_node();
149     b_node->set_name("b_node");
150     b_node->set_op("Const");
151 
152     NodeDef* c_node = graph_def.add_node();
153     c_node->set_name("c_node");
154     c_node->set_op("Const");
155 
156     GraphDef result;
157     TF_ASSERT_OK(
158         StripUnusedNodes(graph_def, {{"mul_node"}, {"add_node"}}, &result));
159 
160     std::map<string, const NodeDef*> node_lookup;
161     MapNamesToNodes(result, &node_lookup);
162     EXPECT_EQ(1, node_lookup.count("add_node"));
163     EXPECT_EQ(1, node_lookup.count("mul_node"));
164     EXPECT_EQ("Placeholder", node_lookup["mul_node"]->op());
165     EXPECT_EQ(DT_FLOAT, node_lookup["mul_node"]->attr().at("dtype").type());
166     EXPECT_EQ(TensorShape({}),
167               TensorShape(node_lookup["mul_node"]->attr().at("shape").shape()));
168     EXPECT_EQ(1, node_lookup.count("a_node"));
169     EXPECT_EQ(0, node_lookup.count("b_node"));
170     EXPECT_EQ(0, node_lookup.count("c_node"));
171   }
172 
TestPlaceholderDefaultArgs()173   void TestPlaceholderDefaultArgs() {
174     GraphDef graph_def;
175     NodeDef* add_node = graph_def.add_node();
176     add_node->set_name("add_node");
177     add_node->set_op("Add");
178     add_node->add_input("mul_node");
179     add_node->add_input("a_node");
180 
181     NodeDef* mul_node = graph_def.add_node();
182     mul_node->set_name("mul_node");
183     mul_node->set_op("Mul");
184     mul_node->add_input("b_node");
185     mul_node->add_input("c_node");
186 
187     NodeDef* a_node = graph_def.add_node();
188     a_node->set_name("a_node");
189     a_node->set_op("Const");
190 
191     NodeDef* b_node = graph_def.add_node();
192     b_node->set_name("b_node");
193     b_node->set_op("Const");
194 
195     NodeDef* c_node = graph_def.add_node();
196     c_node->set_name("c_node");
197     c_node->set_op("Const");
198 
199     GraphDef result;
200     TF_ASSERT_OK(StripUnusedNodes(graph_def,
201                                   {{"mul_node"},
202                                    {"add_node"},
203                                    {{"type", {"int32"}}, {"shape", {"1,2,3"}}}},
204                                   &result));
205 
206     std::map<string, const NodeDef*> node_lookup;
207     MapNamesToNodes(result, &node_lookup);
208     EXPECT_EQ(1, node_lookup.count("add_node"));
209     EXPECT_EQ(1, node_lookup.count("mul_node"));
210     EXPECT_EQ("Placeholder", node_lookup["mul_node"]->op());
211     EXPECT_EQ(DT_INT32, node_lookup["mul_node"]->attr().at("dtype").type());
212     EXPECT_EQ(TensorShape({1, 2, 3}),
213               TensorShape(node_lookup["mul_node"]->attr().at("shape").shape()));
214     EXPECT_EQ(1, node_lookup.count("a_node"));
215     EXPECT_EQ(0, node_lookup.count("b_node"));
216     EXPECT_EQ(0, node_lookup.count("c_node"));
217   }
218 
TestPlaceholderNamedArgs()219   void TestPlaceholderNamedArgs() {
220     GraphDef graph_def;
221     NodeDef* add_node = graph_def.add_node();
222     add_node->set_name("add_node");
223     add_node->set_op("Add");
224     add_node->add_input("mul_node");
225     add_node->add_input("a_node");
226 
227     NodeDef* mul_node = graph_def.add_node();
228     mul_node->set_name("mul_node");
229     mul_node->set_op("Mul");
230     mul_node->add_input("b_node");
231     mul_node->add_input("c_node");
232 
233     NodeDef* a_node = graph_def.add_node();
234     a_node->set_name("a_node");
235     a_node->set_op("Const");
236 
237     NodeDef* b_node = graph_def.add_node();
238     b_node->set_name("b_node");
239     b_node->set_op("Const");
240 
241     NodeDef* c_node = graph_def.add_node();
242     c_node->set_name("c_node");
243     c_node->set_op("Const");
244 
245     GraphDef result;
246     TF_ASSERT_OK(StripUnusedNodes(graph_def,
247                                   {{"mul_node", "a_node"},
248                                    {"add_node"},
249                                    {{"name", {"a_node", "mul_node"}},
250                                     {"type_for_name", {"int64", "quint8"}},
251                                     {"shape_for_name", {"1,2", "1, 2, 3"}}}},
252                                   &result));
253 
254     std::map<string, const NodeDef*> node_lookup;
255     MapNamesToNodes(result, &node_lookup);
256     EXPECT_EQ(1, node_lookup.count("add_node"));
257     EXPECT_EQ(1, node_lookup.count("mul_node"));
258     EXPECT_EQ("Placeholder", node_lookup["mul_node"]->op());
259     EXPECT_EQ(DT_QUINT8, node_lookup["mul_node"]->attr().at("dtype").type());
260     EXPECT_EQ(TensorShape({1, 2, 3}),
261               TensorShape(node_lookup["mul_node"]->attr().at("shape").shape()));
262     EXPECT_EQ(1, node_lookup.count("a_node"));
263     EXPECT_EQ("Placeholder", node_lookup["a_node"]->op());
264     EXPECT_EQ(DT_INT64, node_lookup["a_node"]->attr().at("dtype").type());
265     EXPECT_EQ(TensorShape({1, 2}),
266               TensorShape(node_lookup["a_node"]->attr().at("shape").shape()));
267     EXPECT_EQ(0, node_lookup.count("b_node"));
268     EXPECT_EQ(0, node_lookup.count("c_node"));
269   }
270 };
271 
// gtest registrations. Each one simply runs the fixture method of the same
// name, which holds the actual scenario.
TEST_F(StripUnusedNodesTest, TestSimpleAdd) {
  TestSimpleAdd();
}

TEST_F(StripUnusedNodesTest, TestCommonAncestor) {
  TestCommonAncestor();
}

TEST_F(StripUnusedNodesTest, TestSimplePlaceholder) {
  TestSimplePlaceholder();
}

TEST_F(StripUnusedNodesTest, TestPlaceholderDefaultArgs) {
  TestPlaceholderDefaultArgs();
}

TEST_F(StripUnusedNodesTest, TestPlaceholderNamedArgs) {
  TestPlaceholderNamedArgs();
}
285 
286 }  // namespace graph_transforms
287 }  // namespace tensorflow
288