/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
#include <cstddef>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
33 
34 namespace tensorflow {
35 namespace graph_transforms {
36 
37 // Declaring this here so it doesn't need to be in the public header.
38 Status ReplaceSendRecvs(const GraphDef& original_graph_def,
39                         const GraphDef& rewritten_graph_def,
40                         const std::vector<string>& inputs,
41                         const std::vector<string>& outputs,
42                         GraphDef* output_graph_def);
43 
44 class ConstantFoldingTest : public ::testing::Test {
45  protected:
TestSimpleAdd()46   void TestSimpleAdd() {
47     auto root = tensorflow::Scope::NewRootScope();
48     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
49 
50     const int width = 100;
51 
52     Tensor a_data(DT_FLOAT, TensorShape({width}));
53     test::FillIota<float>(&a_data, 1.0f);
54     Output a_const =
55         Const(root.WithOpName("a_expect_removed"), Input::Initializer(a_data));
56 
57     Tensor b_data(DT_FLOAT, TensorShape({width}));
58     test::FillIota<float>(&b_data, 1.0f);
59     Output b_const =
60         Const(root.WithOpName("b_expect_removed"), Input::Initializer(b_data));
61 
62     Output add = Add(root.WithOpName("add_expect_removed"), a_const, b_const);
63 
64     Output placeholder =
65         Placeholder(root.WithOpName("placeholder_expect_remains"), DT_FLOAT);
66 
67     Output mul =
68         Mul(root.WithOpName("output_expect_remains"), add, placeholder);
69 
70     GraphDef graph_def;
71     TF_ASSERT_OK(root.ToGraphDef(&graph_def));
72 
73     Tensor placeholder_tensor(DT_FLOAT, TensorShape({width}));
74     test::FillIota<float>(&placeholder_tensor, 1.0f);
75     TestConstantFolding(graph_def,
76                         {{"placeholder_expect_remains", placeholder_tensor}},
77                         {}, {"output_expect_remains"}, {});
78     TestConstantFolding(graph_def,
79                         {{"placeholder_expect_remains:0", placeholder_tensor}},
80                         {}, {"output_expect_remains:0"}, {});
81   }
82 
TestOpExclusionAdd()83   void TestOpExclusionAdd() {
84     auto root = tensorflow::Scope::NewRootScope();
85     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
86 
87     const int width = 100;
88 
89     Tensor a_data(DT_FLOAT, TensorShape({width}));
90     test::FillIota<float>(&a_data, 1.0f);
91     Output a_const =
92         Const(root.WithOpName("a_expect_remains"), Input::Initializer(a_data));
93 
94     Tensor b_data(DT_FLOAT, TensorShape({width}));
95     test::FillIota<float>(&b_data, 1.0f);
96     Output b_const =
97         Const(root.WithOpName("b_expect_remains"), Input::Initializer(b_data));
98 
99     Output add = Add(root.WithOpName("add_expect_remains"), a_const, b_const);
100 
101     Output placeholder =
102         Placeholder(root.WithOpName("placeholder_expect_remains"), DT_FLOAT);
103 
104     Output mul =
105         Mul(root.WithOpName("output_expect_remains"), add, placeholder);
106 
107     GraphDef graph_def;
108     TF_ASSERT_OK(root.ToGraphDef(&graph_def));
109 
110     Tensor placeholder_tensor(DT_FLOAT, TensorShape({width}));
111     test::FillIota<float>(&placeholder_tensor, 1.0f);
112     TestConstantFolding(graph_def,
113                         {{"placeholder_expect_remains", placeholder_tensor}},
114                         {"Add"}, {"output_expect_remains"}, {});
115   }
116 
TestShapePropagation()117   void TestShapePropagation() {
118     auto root = tensorflow::Scope::NewRootScope();
119     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
120 
121     Output placeholder =
122         Placeholder(root.WithOpName("placeholder_expect_remains"), DT_FLOAT);
123     Output a_const =
124         Const(root.WithOpName("a_expect_removed"),
125               Input::Initializer({1, 1, 1}, TensorShape({1, 1, 3})));
126     Output shape = Shape(root.WithOpName("shape_expect_removed"), a_const);
127     Output cast = Cast(root.WithOpName("cast_expect_removed"), shape, DT_FLOAT);
128     Output mul =
129         Mul(root.WithOpName("output_expect_remains"), cast, placeholder);
130 
131     GraphDef graph_def;
132     TF_ASSERT_OK(root.ToGraphDef(&graph_def));
133 
134     Tensor placeholder_tensor(DT_FLOAT, TensorShape({3}));
135     test::FillIota<float>(&placeholder_tensor, 1.0);
136     TestConstantFolding(graph_def,
137                         {{"placeholder_expect_remains", placeholder_tensor}},
138                         {}, {"output_expect_remains"}, {});
139   }
140 
TestPreserveOutputShapes()141   void TestPreserveOutputShapes() {
142     auto root = tensorflow::Scope::NewRootScope();
143     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
144 
145     tensorflow::AttrValue shape_attr;
146     auto* shape_proto = shape_attr.mutable_list()->add_shape();
147     shape_proto->add_dim()->set_size(1);
148     shape_proto->add_dim()->set_size(1);
149     shape_proto->add_dim()->set_size(3);
150 
151     Output placeholder =
152         Placeholder(root.WithOpName("placeholder_expect_remains"), DT_FLOAT);
153     placeholder.node()->AddAttr("_output_shapes", shape_attr);
154 
155     Output shape = Shape(root.WithOpName("shape_expect_removed"), placeholder);
156     Output cast = Cast(root.WithOpName("cast_expect_removed"), shape, DT_FLOAT);
157     Output mul =
158         Mul(root.WithOpName("output_expect_remains"), cast, placeholder);
159 
160     GraphDef graph_def;
161     TF_ASSERT_OK(root.ToGraphDef(&graph_def));
162 
163     Tensor placeholder_tensor(DT_FLOAT, TensorShape({1, 1, 3}));
164     test::FillIota<float>(&placeholder_tensor, 1.0);
165 
166     graph_transforms::TransformFuncContext context;
167     context.params["clear_output_shapes"] = {"false"};
168     TestConstantFolding(graph_def,
169                         {{"placeholder_expect_remains", placeholder_tensor}},
170                         {}, {"output_expect_remains"}, context);
171   }
172 
TestConstantFolding(const GraphDef & graph_def,std::vector<std::pair<string,Tensor>> inputs,std::vector<string> excluded_ops,const std::vector<string> & outputs,graph_transforms::TransformFuncContext context)173   void TestConstantFolding(const GraphDef& graph_def,
174                            std::vector<std::pair<string, Tensor> > inputs,
175                            std::vector<string> excluded_ops,
176                            const std::vector<string>& outputs,
177                            graph_transforms::TransformFuncContext context) {
178     std::unique_ptr<tensorflow::Session> unfolded_session(
179         tensorflow::NewSession(tensorflow::SessionOptions()));
180     TF_ASSERT_OK(unfolded_session->Create(graph_def));
181     std::vector<Tensor> unfolded_tensors;
182     TF_ASSERT_OK(unfolded_session->Run(inputs, outputs, {}, &unfolded_tensors));
183 
184     GraphDef folded_graph_def;
185     for (const std::pair<string, Tensor>& input : inputs) {
186       context.input_names.push_back(input.first);
187     }
188     context.output_names = outputs;
189     context.params["exclude_op"] = std::move(excluded_ops);
190     TF_ASSERT_OK(
191         graph_transforms::FoldConstants(graph_def, context, &folded_graph_def));
192 
193     std::unique_ptr<tensorflow::Session> folded_session(
194         tensorflow::NewSession(tensorflow::SessionOptions()));
195     TF_ASSERT_OK(folded_session->Create(folded_graph_def));
196     std::vector<Tensor> folded_tensors;
197     TF_ASSERT_OK(folded_session->Run(inputs, outputs, {}, &folded_tensors));
198 
199     EXPECT_EQ(unfolded_tensors.size(), folded_tensors.size());
200     for (int i = 0; i < unfolded_tensors.size(); ++i) {
201       test::ExpectTensorNear<float>(unfolded_tensors[i], folded_tensors[i],
202                                     1e-5);
203     }
204 
205     std::map<string, const NodeDef*> folded_node_map;
206     for (const NodeDef& node : folded_graph_def.node()) {
207       folded_node_map.insert({node.name(), &node});
208     }
209 
210     for (const NodeDef& node : graph_def.node()) {
211       const StringPiece name(node.name());
212       const int occurrence_count = folded_node_map.count(node.name());
213       if (str_util::EndsWith(name, "expect_removed")) {
214         EXPECT_EQ(0, occurrence_count) << "node.name()=" << node.name();
215       }
216       if (str_util::EndsWith(name, "expect_remains")) {
217         EXPECT_EQ(1, occurrence_count) << "node.name()=" << node.name();
218       }
219     }
220   }
221 
TestReplaceSendRecvs()222   void TestReplaceSendRecvs() {
223     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
224 
225     const int width = 100;
226     Tensor a_const_data(DT_FLOAT, TensorShape({width}));
227     test::FillIota<float>(&a_const_data, 1.0f);
228 
229     auto o_root = tensorflow::Scope::NewRootScope();
230     _Recv(o_root.WithOpName("original_recv"), DT_FLOAT, "", "", 0, "");
231     Output o_a_const =
232         Const(o_root.WithOpName("a_const"), Input::Initializer(a_const_data));
233     Placeholder(o_root.WithOpName("placeholder"), DT_FLOAT);
234     _Send(o_root.WithOpName("original_send"), o_a_const, "", "", 0, "");
235     GraphDef o_graph_def;
236     TF_ASSERT_OK(o_root.ToGraphDef(&o_graph_def));
237 
238     auto n_root = tensorflow::Scope::NewRootScope();
239     _Recv(n_root.WithOpName("original_recv"), DT_FLOAT, "", "", 0, "");
240     Output n_a_const =
241         Const(n_root.WithOpName("a_const"), Input::Initializer(a_const_data));
242     _Recv(n_root.WithOpName("_recv_placeholder_0"), DT_FLOAT, "", "", 0, "");
243     _Send(n_root.WithOpName("original_send"), n_a_const, "", "", 0, "");
244     _Send(n_root.WithOpName("new_send"), n_a_const, "", "", 0, "");
245     GraphDef n_graph_def;
246     TF_ASSERT_OK(n_root.ToGraphDef(&n_graph_def));
247 
248     GraphDef result_graph_def;
249     TF_ASSERT_OK(graph_transforms::ReplaceSendRecvs(
250         o_graph_def, n_graph_def, {"placeholder"}, {"a_const"},
251         &result_graph_def));
252 
253     std::map<string, const NodeDef*> node_map;
254     graph_transforms::MapNamesToNodes(result_graph_def, &node_map);
255     EXPECT_EQ(1, node_map.count("original_recv"));
256     EXPECT_EQ(1, node_map.count("a_const"));
257     EXPECT_EQ(1, node_map.count("placeholder"));
258     EXPECT_EQ(1, node_map.count("original_send"));
259     EXPECT_EQ(0, node_map.count("_recv_placeholder_0"));
260     EXPECT_EQ(0, node_map.count("new_send"));
261   }
262 
TestReplaceSendRecvsPrefixNames()263   void TestReplaceSendRecvsPrefixNames() {
264     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
265 
266     auto o_root = tensorflow::Scope::NewRootScope();
267     auto a = Placeholder(o_root.WithOpName("placeholder"), DT_FLOAT);
268     auto b = Placeholder(o_root.WithOpName("placeholder_1"), DT_FLOAT);
269     auto add_o = Add(o_root.WithOpName("add"), a, b);
270     GraphDef o_graph_def;
271     TF_ASSERT_OK(o_root.ToGraphDef(&o_graph_def));
272 
273     auto n_root = tensorflow::Scope::NewRootScope();
274     auto c = _Recv(n_root.WithOpName("_recv_placeholder_0"), DT_FLOAT, "", "",
275                    0, "");
276     auto d = _Recv(n_root.WithOpName("_recv_placeholder_1_0"), DT_FLOAT, "", "",
277                    0, "");
278     auto add_n = Add(n_root.WithOpName("add"), c, d);
279     GraphDef n_graph_def;
280     TF_ASSERT_OK(n_root.ToGraphDef(&n_graph_def));
281 
282     GraphDef result_graph_def;
283     TF_ASSERT_OK(graph_transforms::ReplaceSendRecvs(
284         o_graph_def, n_graph_def, {"placeholder", "placeholder_1"}, {"add"},
285         &result_graph_def));
286 
287     std::map<string, const NodeDef*> node_map;
288     graph_transforms::MapNamesToNodes(result_graph_def, &node_map);
289     EXPECT_EQ(1, node_map.count("placeholder"));
290     EXPECT_EQ(1, node_map.count("placeholder_1"));
291     EXPECT_EQ(1, node_map.count("add"));
292   }
293 
TestRemoveUnusedNodes()294   void TestRemoveUnusedNodes() {
295     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
296     auto root = tensorflow::Scope::NewRootScope();
297 
298     const int width = 100;
299 
300     Tensor a_data(DT_FLOAT, TensorShape({width}));
301     test::FillIota<float>(&a_data, 1.0f);
302     Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
303 
304     Tensor b_data(DT_FLOAT, TensorShape({width}));
305     test::FillIota<float>(&b_data, 1.0f);
306     Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
307 
308     Output add = Add(root.WithOpName("add"), a_const, b_const);
309     Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
310     Output mul = Mul(root.WithOpName("output"), add, placeholder);
311 
312     Tensor unused_data(DT_FLOAT, TensorShape({width}));
313     test::FillIota<float>(&unused_data, 1.0f);
314     Output unused_const =
315         Const(root.WithOpName("unused"), Input::Initializer(unused_data));
316 
317     GraphDef graph_def;
318     TF_ASSERT_OK(root.ToGraphDef(&graph_def));
319     GraphDef result_graph_def;
320     TF_ASSERT_OK(graph_transforms::RemoveUnusedNodes(
321         graph_def, {{"placeholder"}, {"output"}}, &result_graph_def));
322 
323     std::map<string, const NodeDef*> node_map;
324     graph_transforms::MapNamesToNodes(result_graph_def, &node_map);
325     EXPECT_EQ(1, node_map.count("a"));
326     EXPECT_EQ(1, node_map.count("b"));
327     EXPECT_EQ(1, node_map.count("add"));
328     EXPECT_EQ(1, node_map.count("placeholder"));
329     EXPECT_EQ(1, node_map.count("output"));
330     EXPECT_EQ(0, node_map.count("unused"));
331   }
332 
TestMaxConstantSizeInBytes()333   void TestMaxConstantSizeInBytes() {
334     auto root = tensorflow::Scope::NewRootScope();
335 
336     const int width = 100;
337 
338     Tensor a_data(DT_FLOAT, TensorShape({width}));
339     test::FillIota<float>(&a_data, 1.0f);
340     Output a_const = ::tensorflow::ops::Const(
341         root.WithOpName("a_expect_remains"), Input::Initializer(a_data));
342 
343     Tensor b_data(DT_FLOAT, TensorShape({width}));
344     test::FillIota<float>(&b_data, 1.0f);
345     Output b_const = ::tensorflow::ops::Const(
346         root.WithOpName("b_expect_remains"), Input::Initializer(b_data));
347 
348     Output add = ::tensorflow::ops::Add(root.WithOpName("add_expect_remains"),
349                                         a_const, b_const);
350 
351     Output placeholder = ::tensorflow::ops::Placeholder(
352         root.WithOpName("placeholder_expect_remains"), DT_FLOAT);
353 
354     Output mul = ::tensorflow::ops::Mul(
355         root.WithOpName("output_expect_remains"), add, placeholder);
356 
357     GraphDef graph_def;
358     TF_ASSERT_OK(root.ToGraphDef(&graph_def));
359 
360     Tensor placeholder_tensor(DT_FLOAT, TensorShape({width}));
361     test::FillIota<float>(&placeholder_tensor, 1.0f);
362 
363     // Setting the maximum constant size to 10 bytes should stop the constant
364     // folding at add(a, b) that would have yielded a constant of
365     // 100*sizeof(float) bytes.
366     graph_transforms::TransformFuncContext context;
367     context.params["max_constant_size_in_bytes"] = {"10"};
368     TestConstantFolding(graph_def,
369                         {{"placeholder_expect_remains", placeholder_tensor}},
370                         {}, {"output_expect_remains"}, context);
371   }
372 };
373 
// Test registrations. Each case simply forwards to the fixture method of the
// same name on ConstantFoldingTest.

TEST_F(ConstantFoldingTest, TestSimpleAdd) {
  TestSimpleAdd();
}

TEST_F(ConstantFoldingTest, TestOpExclusionAdd) {
  TestOpExclusionAdd();
}

TEST_F(ConstantFoldingTest, TestShapePropagation) {
  TestShapePropagation();
}

TEST_F(ConstantFoldingTest, TestPreserveOutputShapes) {
  TestPreserveOutputShapes();
}

TEST_F(ConstantFoldingTest, TestReplaceSendRecvs) {
  TestReplaceSendRecvs();
}

TEST_F(ConstantFoldingTest, TestReplaceSendRecvsPrefixNames) {
  TestReplaceSendRecvsPrefixNames();
}

TEST_F(ConstantFoldingTest, TestRemoveUnusedNodes) {
  TestRemoveUnusedNodes();
}

TEST_F(ConstantFoldingTest, TestMaxConstantSizeInBytes) {
  TestMaxConstantSizeInBytes();
}
395 
396 }  // namespace graph_transforms
397 }  // namespace tensorflow
398