1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/cc/ops/const_op.h"
17 #include "tensorflow/cc/ops/image_ops.h"
18 #include "tensorflow/cc/ops/nn_ops.h"
19 #include "tensorflow/cc/ops/sendrecv_ops.h"
20 #include "tensorflow/cc/ops/standard_ops.h"
21 #include "tensorflow/core/framework/tensor_testutil.h"
22 #include "tensorflow/core/lib/core/status_test_util.h"
23 #include "tensorflow/core/platform/test.h"
24 #include "tensorflow/core/platform/test_benchmark.h"
25 #include "tensorflow/core/public/session.h"
26 #include "tensorflow/tools/graph_transforms/transform_utils.h"
27
28 namespace tensorflow {
29 namespace graph_transforms {
30
// Declare here, so we don't need a public header.
// The fusing transforms exercised by this test are presumably defined in the
// transform's own source file. All three take the original graph in
// `input_graph_def`, a `context` (the tests below pass output_names =
// {"output"}), and write the rewritten graph into `output_graph_def`,
// reporting failure through the returned Status.

// Tested below on a ResizeBilinear -> MirrorPad -> Conv2D chain; the test
// expects none of those three op types to remain in the output graph and the
// "output" tensor to match the unfused graph within 1e-5.
Status FuseResizePadAndConv(const GraphDef& input_graph_def,
                            const TransformFuncContext& context,
                            GraphDef* output_graph_def);
// Tested below on a ResizeBilinear -> Conv2D chain; the test expects neither
// op type to remain and the "output" tensor to be numerically unchanged.
Status FuseResizeAndConv(const GraphDef& input_graph_def,
                         const TransformFuncContext& context,
                         GraphDef* output_graph_def);
// Tested below on a MirrorPad -> Conv2D chain; the test expects neither op
// type to remain and the "output" tensor to be numerically unchanged.
Status FusePadAndConv(const GraphDef& input_graph_def,
                      const TransformFuncContext& context,
                      GraphDef* output_graph_def);
41
42 class FuseConvolutionsTest : public ::testing::Test {
43 protected:
TestFuseResizePadAndConv()44 void TestFuseResizePadAndConv() {
45 auto root = tensorflow::Scope::NewRootScope();
46 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
47
48 Tensor input_data(DT_FLOAT, TensorShape({1, 2, 3, 2}));
49 test::FillValues<float>(
50 &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
51 -5.0f, -3.0f, -6.0f});
52 Output input_op =
53 Const(root.WithOpName("input_op"), Input::Initializer(input_data));
54
55 Output resize_op = ResizeBilinear(root.WithOpName("resize_op"), input_op,
56 Const(root.WithOpName("size"), {12, 4}),
57 ResizeBilinear::AlignCorners(false));
58
59 Tensor pad_dims_data(DT_INT32, TensorShape({4, 2}));
60 test::FillValues<int32>(&pad_dims_data, {0, 0, 1, 1, 2, 2, 0, 0});
61 Output pad_dims_op = Const(root.WithOpName("pad_dims_op"),
62 Input::Initializer(pad_dims_data));
63 Output pad_op =
64 MirrorPad(root.WithOpName("pad_op"), resize_op, pad_dims_op, "REFLECT");
65
66 Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
67 test::FillValues<float>(&weights_data,
68 {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
69 Output weights_op =
70 Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
71
72 Output conv_op = Conv2D(root.WithOpName("output"), pad_op, weights_op,
73 {1, 1, 1, 1}, "VALID");
74
75 GraphDef original_graph_def;
76 TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
77
78 std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
79 TF_ASSERT_OK(original_session->Create(original_graph_def));
80 std::vector<Tensor> original_outputs;
81 TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
82
83 GraphDef fused_graph_def;
84 TF_ASSERT_OK(FuseResizePadAndConv(original_graph_def, {{}, {"output"}},
85 &fused_graph_def));
86
87 std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
88 TF_ASSERT_OK(fused_session->Create(fused_graph_def));
89 std::vector<Tensor> fused_outputs;
90 TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
91
92 test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
93
94 for (const NodeDef& node : fused_graph_def.node()) {
95 EXPECT_NE("Conv2D", node.op());
96 EXPECT_NE("MirrorPad", node.op());
97 EXPECT_NE("ResizeBilinear", node.op());
98 }
99 }
100
TestFuseResizeAndConv()101 void TestFuseResizeAndConv() {
102 auto root = tensorflow::Scope::NewRootScope();
103 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
104
105 Tensor input_data(DT_FLOAT, TensorShape({1, 2, 3, 2}));
106 test::FillValues<float>(
107 &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
108 -5.0f, -3.0f, -6.0f});
109 Output input_op =
110 Const(root.WithOpName("input_op"), Input::Initializer(input_data));
111
112 Output resize_op = ResizeBilinear(root.WithOpName("resize_op"), input_op,
113 Const(root.WithOpName("size"), {12, 4}),
114 ResizeBilinear::AlignCorners(false));
115
116 Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
117 test::FillValues<float>(&weights_data,
118 {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
119 Output weights_op =
120 Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
121
122 Output conv_op = Conv2D(root.WithOpName("output"), resize_op, weights_op,
123 {1, 1, 1, 1}, "VALID");
124
125 GraphDef original_graph_def;
126 TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
127
128 std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
129 TF_ASSERT_OK(original_session->Create(original_graph_def));
130 std::vector<Tensor> original_outputs;
131 TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
132
133 GraphDef fused_graph_def;
134 TF_ASSERT_OK(FuseResizeAndConv(original_graph_def, {{}, {"output"}},
135 &fused_graph_def));
136
137 std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
138 TF_ASSERT_OK(fused_session->Create(fused_graph_def));
139 std::vector<Tensor> fused_outputs;
140 TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
141
142 test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
143
144 for (const NodeDef& node : fused_graph_def.node()) {
145 EXPECT_NE("Conv2D", node.op());
146 EXPECT_NE("ResizeBilinear", node.op());
147 }
148 }
149
TestFusePadAndConv()150 void TestFusePadAndConv() {
151 auto root = tensorflow::Scope::NewRootScope();
152 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
153
154 Tensor input_data(DT_FLOAT, TensorShape({1, 2, 3, 2}));
155 test::FillValues<float>(
156 &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
157 -5.0f, -3.0f, -6.0f});
158 Output input_op =
159 Const(root.WithOpName("input_op"), Input::Initializer(input_data));
160
161 Tensor pad_dims_data(DT_INT32, TensorShape({4, 2}));
162 test::FillValues<int32>(&pad_dims_data, {0, 0, 1, 1, 2, 2, 0, 0});
163 Output pad_dims_op = Const(root.WithOpName("pad_dims_op"),
164 Input::Initializer(pad_dims_data));
165 Output pad_op =
166 MirrorPad(root.WithOpName("pad_op"), input_op, pad_dims_op, "REFLECT");
167
168 Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
169 test::FillValues<float>(&weights_data,
170 {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
171 Output weights_op =
172 Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
173
174 Output conv_op = Conv2D(root.WithOpName("output"), pad_op, weights_op,
175 {1, 1, 1, 1}, "VALID");
176
177 GraphDef original_graph_def;
178 TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
179
180 std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
181 TF_ASSERT_OK(original_session->Create(original_graph_def));
182 std::vector<Tensor> original_outputs;
183 TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
184
185 GraphDef fused_graph_def;
186 TF_ASSERT_OK(
187 FusePadAndConv(original_graph_def, {{}, {"output"}}, &fused_graph_def));
188
189 std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
190 TF_ASSERT_OK(fused_session->Create(fused_graph_def));
191 std::vector<Tensor> fused_outputs;
192 TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
193
194 test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
195
196 for (const NodeDef& node : fused_graph_def.node()) {
197 EXPECT_NE("Conv2D", node.op());
198 EXPECT_NE("MirrorPad", node.op());
199 }
200 }
201 };
202
// Test registrations: each case delegates to the fixture method of the same
// name, which builds the graph, applies the transform and checks the result.
TEST_F(FuseConvolutionsTest, TestFuseResizePadAndConv) {
  TestFuseResizePadAndConv();
}

TEST_F(FuseConvolutionsTest, TestFuseResizeAndConv) {
  TestFuseResizeAndConv();
}

TEST_F(FuseConvolutionsTest, TestFusePadAndConv) {
  TestFusePadAndConv();
}
210
211 } // namespace graph_transforms
212 } // namespace tensorflow
213