1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <utility>
17
18 #include "tensorflow/cc/ops/const_op.h"
19 #include "tensorflow/cc/ops/image_ops.h"
20 #include "tensorflow/cc/ops/nn_ops.h"
21 #include "tensorflow/cc/ops/sendrecv_ops.h"
22 #include "tensorflow/cc/ops/standard_ops.h"
23 #include "tensorflow/core/framework/attr_value.pb.h"
24 #include "tensorflow/core/framework/tensor_shape.pb.h"
25 #include "tensorflow/core/framework/tensor_testutil.h"
26 #include "tensorflow/core/lib/core/status_test_util.h"
27 #include "tensorflow/core/lib/strings/str_util.h"
28 #include "tensorflow/core/platform/test.h"
29 #include "tensorflow/core/platform/test_benchmark.h"
30 #include "tensorflow/core/public/session.h"
31 #include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
32 #include "tensorflow/tools/graph_transforms/transform_utils.h"
33
34 namespace tensorflow {
35 namespace graph_transforms {
36
// ReplaceSendRecvs is declared here, rather than in the public header, so that
// the tests below can call it directly.
38 Status ReplaceSendRecvs(const GraphDef& original_graph_def,
39 const GraphDef& rewritten_graph_def,
40 const std::vector<string>& inputs,
41 const std::vector<string>& outputs,
42 GraphDef* output_graph_def);
43
44 class ConstantFoldingTest : public ::testing::Test {
45 protected:
TestSimpleAdd()46 void TestSimpleAdd() {
47 auto root = tensorflow::Scope::NewRootScope();
48 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
49
50 const int width = 100;
51
52 Tensor a_data(DT_FLOAT, TensorShape({width}));
53 test::FillIota<float>(&a_data, 1.0f);
54 Output a_const =
55 Const(root.WithOpName("a_expect_removed"), Input::Initializer(a_data));
56
57 Tensor b_data(DT_FLOAT, TensorShape({width}));
58 test::FillIota<float>(&b_data, 1.0f);
59 Output b_const =
60 Const(root.WithOpName("b_expect_removed"), Input::Initializer(b_data));
61
62 Output add = Add(root.WithOpName("add_expect_removed"), a_const, b_const);
63
64 Output placeholder =
65 Placeholder(root.WithOpName("placeholder_expect_remains"), DT_FLOAT);
66
67 Output mul =
68 Mul(root.WithOpName("output_expect_remains"), add, placeholder);
69
70 GraphDef graph_def;
71 TF_ASSERT_OK(root.ToGraphDef(&graph_def));
72
73 Tensor placeholder_tensor(DT_FLOAT, TensorShape({width}));
74 test::FillIota<float>(&placeholder_tensor, 1.0f);
75 TestConstantFolding(graph_def,
76 {{"placeholder_expect_remains", placeholder_tensor}},
77 {}, {"output_expect_remains"}, {});
78 TestConstantFolding(graph_def,
79 {{"placeholder_expect_remains:0", placeholder_tensor}},
80 {}, {"output_expect_remains:0"}, {});
81 }
82
TestOpExclusionAdd()83 void TestOpExclusionAdd() {
84 auto root = tensorflow::Scope::NewRootScope();
85 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
86
87 const int width = 100;
88
89 Tensor a_data(DT_FLOAT, TensorShape({width}));
90 test::FillIota<float>(&a_data, 1.0f);
91 Output a_const =
92 Const(root.WithOpName("a_expect_remains"), Input::Initializer(a_data));
93
94 Tensor b_data(DT_FLOAT, TensorShape({width}));
95 test::FillIota<float>(&b_data, 1.0f);
96 Output b_const =
97 Const(root.WithOpName("b_expect_remains"), Input::Initializer(b_data));
98
99 Output add = Add(root.WithOpName("add_expect_remains"), a_const, b_const);
100
101 Output placeholder =
102 Placeholder(root.WithOpName("placeholder_expect_remains"), DT_FLOAT);
103
104 Output mul =
105 Mul(root.WithOpName("output_expect_remains"), add, placeholder);
106
107 GraphDef graph_def;
108 TF_ASSERT_OK(root.ToGraphDef(&graph_def));
109
110 Tensor placeholder_tensor(DT_FLOAT, TensorShape({width}));
111 test::FillIota<float>(&placeholder_tensor, 1.0f);
112 TestConstantFolding(graph_def,
113 {{"placeholder_expect_remains", placeholder_tensor}},
114 {"Add"}, {"output_expect_remains"}, {});
115 }
116
TestShapePropagation()117 void TestShapePropagation() {
118 auto root = tensorflow::Scope::NewRootScope();
119 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
120
121 Output placeholder =
122 Placeholder(root.WithOpName("placeholder_expect_remains"), DT_FLOAT);
123 Output a_const =
124 Const(root.WithOpName("a_expect_removed"),
125 Input::Initializer({1, 1, 1}, TensorShape({1, 1, 3})));
126 Output shape = Shape(root.WithOpName("shape_expect_removed"), a_const);
127 Output cast = Cast(root.WithOpName("cast_expect_removed"), shape, DT_FLOAT);
128 Output mul =
129 Mul(root.WithOpName("output_expect_remains"), cast, placeholder);
130
131 GraphDef graph_def;
132 TF_ASSERT_OK(root.ToGraphDef(&graph_def));
133
134 Tensor placeholder_tensor(DT_FLOAT, TensorShape({3}));
135 test::FillIota<float>(&placeholder_tensor, 1.0);
136 TestConstantFolding(graph_def,
137 {{"placeholder_expect_remains", placeholder_tensor}},
138 {}, {"output_expect_remains"}, {});
139 }
140
TestPreserveOutputShapes()141 void TestPreserveOutputShapes() {
142 auto root = tensorflow::Scope::NewRootScope();
143 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
144
145 tensorflow::AttrValue shape_attr;
146 auto* shape_proto = shape_attr.mutable_list()->add_shape();
147 shape_proto->add_dim()->set_size(1);
148 shape_proto->add_dim()->set_size(1);
149 shape_proto->add_dim()->set_size(3);
150
151 Output placeholder =
152 Placeholder(root.WithOpName("placeholder_expect_remains"), DT_FLOAT);
153 placeholder.node()->AddAttr("_output_shapes", shape_attr);
154
155 Output shape = Shape(root.WithOpName("shape_expect_removed"), placeholder);
156 Output cast = Cast(root.WithOpName("cast_expect_removed"), shape, DT_FLOAT);
157 Output mul =
158 Mul(root.WithOpName("output_expect_remains"), cast, placeholder);
159
160 GraphDef graph_def;
161 TF_ASSERT_OK(root.ToGraphDef(&graph_def));
162
163 Tensor placeholder_tensor(DT_FLOAT, TensorShape({1, 1, 3}));
164 test::FillIota<float>(&placeholder_tensor, 1.0);
165
166 graph_transforms::TransformFuncContext context;
167 context.params["clear_output_shapes"] = {"false"};
168 TestConstantFolding(graph_def,
169 {{"placeholder_expect_remains", placeholder_tensor}},
170 {}, {"output_expect_remains"}, context);
171 }
172
TestConstantFolding(const GraphDef & graph_def,std::vector<std::pair<string,Tensor>> inputs,std::vector<string> excluded_ops,const std::vector<string> & outputs,graph_transforms::TransformFuncContext context)173 void TestConstantFolding(const GraphDef& graph_def,
174 std::vector<std::pair<string, Tensor> > inputs,
175 std::vector<string> excluded_ops,
176 const std::vector<string>& outputs,
177 graph_transforms::TransformFuncContext context) {
178 std::unique_ptr<tensorflow::Session> unfolded_session(
179 tensorflow::NewSession(tensorflow::SessionOptions()));
180 TF_ASSERT_OK(unfolded_session->Create(graph_def));
181 std::vector<Tensor> unfolded_tensors;
182 TF_ASSERT_OK(unfolded_session->Run(inputs, outputs, {}, &unfolded_tensors));
183
184 GraphDef folded_graph_def;
185 for (const std::pair<string, Tensor>& input : inputs) {
186 context.input_names.push_back(input.first);
187 }
188 context.output_names = outputs;
189 context.params["exclude_op"] = std::move(excluded_ops);
190 TF_ASSERT_OK(
191 graph_transforms::FoldConstants(graph_def, context, &folded_graph_def));
192
193 std::unique_ptr<tensorflow::Session> folded_session(
194 tensorflow::NewSession(tensorflow::SessionOptions()));
195 TF_ASSERT_OK(folded_session->Create(folded_graph_def));
196 std::vector<Tensor> folded_tensors;
197 TF_ASSERT_OK(folded_session->Run(inputs, outputs, {}, &folded_tensors));
198
199 EXPECT_EQ(unfolded_tensors.size(), folded_tensors.size());
200 for (int i = 0; i < unfolded_tensors.size(); ++i) {
201 test::ExpectTensorNear<float>(unfolded_tensors[i], folded_tensors[i],
202 1e-5);
203 }
204
205 std::map<string, const NodeDef*> folded_node_map;
206 for (const NodeDef& node : folded_graph_def.node()) {
207 folded_node_map.insert({node.name(), &node});
208 }
209
210 for (const NodeDef& node : graph_def.node()) {
211 const StringPiece name(node.name());
212 const int occurrence_count = folded_node_map.count(node.name());
213 if (str_util::EndsWith(name, "expect_removed")) {
214 EXPECT_EQ(0, occurrence_count) << "node.name()=" << node.name();
215 }
216 if (str_util::EndsWith(name, "expect_remains")) {
217 EXPECT_EQ(1, occurrence_count) << "node.name()=" << node.name();
218 }
219 }
220 }
221
TestReplaceSendRecvs()222 void TestReplaceSendRecvs() {
223 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
224
225 const int width = 100;
226 Tensor a_const_data(DT_FLOAT, TensorShape({width}));
227 test::FillIota<float>(&a_const_data, 1.0f);
228
229 auto o_root = tensorflow::Scope::NewRootScope();
230 _Recv(o_root.WithOpName("original_recv"), DT_FLOAT, "", "", 0, "");
231 Output o_a_const =
232 Const(o_root.WithOpName("a_const"), Input::Initializer(a_const_data));
233 Placeholder(o_root.WithOpName("placeholder"), DT_FLOAT);
234 _Send(o_root.WithOpName("original_send"), o_a_const, "", "", 0, "");
235 GraphDef o_graph_def;
236 TF_ASSERT_OK(o_root.ToGraphDef(&o_graph_def));
237
238 auto n_root = tensorflow::Scope::NewRootScope();
239 _Recv(n_root.WithOpName("original_recv"), DT_FLOAT, "", "", 0, "");
240 Output n_a_const =
241 Const(n_root.WithOpName("a_const"), Input::Initializer(a_const_data));
242 _Recv(n_root.WithOpName("_recv_placeholder_0"), DT_FLOAT, "", "", 0, "");
243 _Send(n_root.WithOpName("original_send"), n_a_const, "", "", 0, "");
244 _Send(n_root.WithOpName("new_send"), n_a_const, "", "", 0, "");
245 GraphDef n_graph_def;
246 TF_ASSERT_OK(n_root.ToGraphDef(&n_graph_def));
247
248 GraphDef result_graph_def;
249 TF_ASSERT_OK(graph_transforms::ReplaceSendRecvs(
250 o_graph_def, n_graph_def, {"placeholder"}, {"a_const"},
251 &result_graph_def));
252
253 std::map<string, const NodeDef*> node_map;
254 graph_transforms::MapNamesToNodes(result_graph_def, &node_map);
255 EXPECT_EQ(1, node_map.count("original_recv"));
256 EXPECT_EQ(1, node_map.count("a_const"));
257 EXPECT_EQ(1, node_map.count("placeholder"));
258 EXPECT_EQ(1, node_map.count("original_send"));
259 EXPECT_EQ(0, node_map.count("_recv_placeholder_0"));
260 EXPECT_EQ(0, node_map.count("new_send"));
261 }
262
TestReplaceSendRecvsPrefixNames()263 void TestReplaceSendRecvsPrefixNames() {
264 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
265
266 auto o_root = tensorflow::Scope::NewRootScope();
267 auto a = Placeholder(o_root.WithOpName("placeholder"), DT_FLOAT);
268 auto b = Placeholder(o_root.WithOpName("placeholder_1"), DT_FLOAT);
269 auto add_o = Add(o_root.WithOpName("add"), a, b);
270 GraphDef o_graph_def;
271 TF_ASSERT_OK(o_root.ToGraphDef(&o_graph_def));
272
273 auto n_root = tensorflow::Scope::NewRootScope();
274 auto c = _Recv(n_root.WithOpName("_recv_placeholder_0"), DT_FLOAT, "", "",
275 0, "");
276 auto d = _Recv(n_root.WithOpName("_recv_placeholder_1_0"), DT_FLOAT, "", "",
277 0, "");
278 auto add_n = Add(n_root.WithOpName("add"), c, d);
279 GraphDef n_graph_def;
280 TF_ASSERT_OK(n_root.ToGraphDef(&n_graph_def));
281
282 GraphDef result_graph_def;
283 TF_ASSERT_OK(graph_transforms::ReplaceSendRecvs(
284 o_graph_def, n_graph_def, {"placeholder", "placeholder_1"}, {"add"},
285 &result_graph_def));
286
287 std::map<string, const NodeDef*> node_map;
288 graph_transforms::MapNamesToNodes(result_graph_def, &node_map);
289 EXPECT_EQ(1, node_map.count("placeholder"));
290 EXPECT_EQ(1, node_map.count("placeholder_1"));
291 EXPECT_EQ(1, node_map.count("add"));
292 }
293
TestRemoveUnusedNodes()294 void TestRemoveUnusedNodes() {
295 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
296 auto root = tensorflow::Scope::NewRootScope();
297
298 const int width = 100;
299
300 Tensor a_data(DT_FLOAT, TensorShape({width}));
301 test::FillIota<float>(&a_data, 1.0f);
302 Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
303
304 Tensor b_data(DT_FLOAT, TensorShape({width}));
305 test::FillIota<float>(&b_data, 1.0f);
306 Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
307
308 Output add = Add(root.WithOpName("add"), a_const, b_const);
309 Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
310 Output mul = Mul(root.WithOpName("output"), add, placeholder);
311
312 Tensor unused_data(DT_FLOAT, TensorShape({width}));
313 test::FillIota<float>(&unused_data, 1.0f);
314 Output unused_const =
315 Const(root.WithOpName("unused"), Input::Initializer(unused_data));
316
317 GraphDef graph_def;
318 TF_ASSERT_OK(root.ToGraphDef(&graph_def));
319 GraphDef result_graph_def;
320 TF_ASSERT_OK(graph_transforms::RemoveUnusedNodes(
321 graph_def, {{"placeholder"}, {"output"}}, &result_graph_def));
322
323 std::map<string, const NodeDef*> node_map;
324 graph_transforms::MapNamesToNodes(result_graph_def, &node_map);
325 EXPECT_EQ(1, node_map.count("a"));
326 EXPECT_EQ(1, node_map.count("b"));
327 EXPECT_EQ(1, node_map.count("add"));
328 EXPECT_EQ(1, node_map.count("placeholder"));
329 EXPECT_EQ(1, node_map.count("output"));
330 EXPECT_EQ(0, node_map.count("unused"));
331 }
332
TestMaxConstantSizeInBytes()333 void TestMaxConstantSizeInBytes() {
334 auto root = tensorflow::Scope::NewRootScope();
335
336 const int width = 100;
337
338 Tensor a_data(DT_FLOAT, TensorShape({width}));
339 test::FillIota<float>(&a_data, 1.0f);
340 Output a_const = ::tensorflow::ops::Const(
341 root.WithOpName("a_expect_remains"), Input::Initializer(a_data));
342
343 Tensor b_data(DT_FLOAT, TensorShape({width}));
344 test::FillIota<float>(&b_data, 1.0f);
345 Output b_const = ::tensorflow::ops::Const(
346 root.WithOpName("b_expect_remains"), Input::Initializer(b_data));
347
348 Output add = ::tensorflow::ops::Add(root.WithOpName("add_expect_remains"),
349 a_const, b_const);
350
351 Output placeholder = ::tensorflow::ops::Placeholder(
352 root.WithOpName("placeholder_expect_remains"), DT_FLOAT);
353
354 Output mul = ::tensorflow::ops::Mul(
355 root.WithOpName("output_expect_remains"), add, placeholder);
356
357 GraphDef graph_def;
358 TF_ASSERT_OK(root.ToGraphDef(&graph_def));
359
360 Tensor placeholder_tensor(DT_FLOAT, TensorShape({width}));
361 test::FillIota<float>(&placeholder_tensor, 1.0f);
362
363 // Setting the maximum constant size to 10 bytes should stop the constant
364 // folding at add(a, b) that would have yielded a constant of
365 // 100*sizeof(float) bytes.
366 graph_transforms::TransformFuncContext context;
367 context.params["max_constant_size_in_bytes"] = {"10"};
368 TestConstantFolding(graph_def,
369 {{"placeholder_expect_remains", placeholder_tensor}},
370 {}, {"output_expect_remains"}, context);
371 }
372 };
373
// Registers the fixture's TestSimpleAdd scenario as a gtest case.
TEST_F(ConstantFoldingTest, TestSimpleAdd) {
  this->TestSimpleAdd();
}
375
// Registers the fixture's TestOpExclusionAdd scenario as a gtest case.
TEST_F(ConstantFoldingTest, TestOpExclusionAdd) {
  this->TestOpExclusionAdd();
}
377
// Registers the fixture's TestShapePropagation scenario as a gtest case.
TEST_F(ConstantFoldingTest, TestShapePropagation) {
  this->TestShapePropagation();
}
379
// Registers the fixture's TestPreserveOutputShapes scenario as a gtest case.
TEST_F(ConstantFoldingTest, TestPreserveOutputShapes) {
  this->TestPreserveOutputShapes();
}
383
// Registers the fixture's TestReplaceSendRecvs scenario as a gtest case.
TEST_F(ConstantFoldingTest, TestReplaceSendRecvs) {
  this->TestReplaceSendRecvs();
}
385
// Registers the fixture's TestReplaceSendRecvsPrefixNames scenario as a gtest
// case.
TEST_F(ConstantFoldingTest, TestReplaceSendRecvsPrefixNames) {
  this->TestReplaceSendRecvsPrefixNames();
}
389
// Registers the fixture's TestRemoveUnusedNodes scenario as a gtest case.
TEST_F(ConstantFoldingTest, TestRemoveUnusedNodes) {
  this->TestRemoveUnusedNodes();
}
391
// Registers the fixture's TestMaxConstantSizeInBytes scenario as a gtest case.
TEST_F(ConstantFoldingTest, TestMaxConstantSizeInBytes) {
  this->TestMaxConstantSizeInBytes();
}
395
396 } // namespace graph_transforms
397 } // namespace tensorflow
398