1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/cc/ops/const_op.h"
17 #include "tensorflow/cc/ops/image_ops.h"
18 #include "tensorflow/cc/ops/math_ops.h"
19 #include "tensorflow/cc/ops/nn_ops.h"
20 #include "tensorflow/cc/ops/sendrecv_ops.h"
21 #include "tensorflow/cc/ops/standard_ops.h"
22 #include "tensorflow/core/framework/tensor_testutil.h"
23 #include "tensorflow/core/lib/core/status_test_util.h"
24 #include "tensorflow/core/platform/test.h"
25 #include "tensorflow/core/platform/test_benchmark.h"
26 #include "tensorflow/core/public/session.h"
27 #include "tensorflow/tools/graph_transforms/transform_utils.h"
28
29 namespace tensorflow {
30 namespace graph_transforms {
31
32 // Declare here, so we don't need a public header.
33 Status FoldBatchNorms(const GraphDef& input_graph_def,
34 const TransformFuncContext& context,
35 GraphDef* output_graph_def);
36
37 class FoldBatchNormsTest : public ::testing::Test {
38 protected:
TestFoldBatchNormsConv2D()39 void TestFoldBatchNormsConv2D() {
40 auto root = tensorflow::Scope::NewRootScope();
41 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
42
43 Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2}));
44 test::FillValues<float>(
45 &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
46 -5.0f, -3.0f, -6.0f});
47 Output input_op =
48 Const(root.WithOpName("input_op"), Input::Initializer(input_data));
49
50 Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
51 test::FillValues<float>(&weights_data,
52 {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
53 Output weights_op =
54 Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
55
56 Output conv_op = Conv2D(root.WithOpName("conv_op"), input_op, weights_op,
57 {1, 1, 1, 1}, "VALID");
58
59 Tensor mul_values_data(DT_FLOAT, TensorShape({2}));
60 test::FillValues<float>(&mul_values_data, {2.0f, 3.0f});
61 Output mul_values_op = Const(root.WithOpName("mul_values"),
62 Input::Initializer(mul_values_data));
63
64 Output mul_op = Mul(root.WithOpName("output"), conv_op, mul_values_op);
65
66 GraphDef original_graph_def;
67 TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
68
69 std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
70 TF_ASSERT_OK(original_session->Create(original_graph_def));
71 std::vector<Tensor> original_outputs;
72 TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
73
74 GraphDef fused_graph_def;
75 TF_ASSERT_OK(
76 FoldBatchNorms(original_graph_def, {{}, {"output"}}, &fused_graph_def));
77
78 std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
79 TF_ASSERT_OK(fused_session->Create(fused_graph_def));
80 std::vector<Tensor> fused_outputs;
81 TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
82
83 test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
84
85 for (const NodeDef& node : fused_graph_def.node()) {
86 EXPECT_NE("Mul", node.op());
87 }
88 }
89
TestFoldBatchNormsDepthwiseConv2dNative()90 void TestFoldBatchNormsDepthwiseConv2dNative() {
91 auto root = tensorflow::Scope::NewRootScope();
92 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
93
94 Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2}));
95 test::FillValues<float>(
96 &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
97 -5.0f, -3.0f, -6.0f});
98 Output input_op =
99 Const(root.WithOpName("input_op"), Input::Initializer(input_data));
100
101 Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
102 test::FillValues<float>(&weights_data,
103 {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
104 Output weights_op =
105 Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
106
107 Output conv_op = DepthwiseConv2dNative(root.WithOpName("conv_op"), input_op,
108 weights_op, {1, 1, 1, 1}, "VALID");
109
110 Tensor mul_values_data(DT_FLOAT, TensorShape({4}));
111 test::FillValues<float>(&mul_values_data, {2.0f, 3.0f, 4.0f, 5.0f});
112 Output mul_values_op = Const(root.WithOpName("mul_values"),
113 Input::Initializer(mul_values_data));
114
115 Output mul_op = Mul(root.WithOpName("output"), conv_op, mul_values_op);
116
117 GraphDef original_graph_def;
118 TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
119
120 std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
121 TF_ASSERT_OK(original_session->Create(original_graph_def));
122 std::vector<Tensor> original_outputs;
123 TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
124
125 GraphDef fused_graph_def;
126 TF_ASSERT_OK(
127 FoldBatchNorms(original_graph_def, {{}, {"output"}}, &fused_graph_def));
128
129 std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
130 TF_ASSERT_OK(fused_session->Create(fused_graph_def));
131 std::vector<Tensor> fused_outputs;
132 TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
133
134 test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
135
136 for (const NodeDef& node : fused_graph_def.node()) {
137 EXPECT_NE("Mul", node.op());
138 }
139 }
140
TestFoldBatchNormsConv2DShared()141 void TestFoldBatchNormsConv2DShared() {
142 auto root = tensorflow::Scope::NewRootScope();
143 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
144
145 Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2}));
146 test::FillValues<float>(
147 &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
148 -5.0f, -3.0f, -6.0f});
149 Output input_op =
150 Const(root.WithOpName("input_op"), Input::Initializer(input_data));
151
152 Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
153 test::FillValues<float>(&weights_data,
154 {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
155 Output weights_op =
156 Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
157
158 Output conv_op = Conv2D(root.WithOpName("conv_op"), input_op, weights_op,
159 {1, 1, 1, 1}, "VALID");
160
161 Tensor mul_values_data(DT_FLOAT, TensorShape({2}));
162 test::FillValues<float>(&mul_values_data, {2.0f, 3.0f});
163 Output mul_values_op = Const(root.WithOpName("mul_values"),
164 Input::Initializer(mul_values_data));
165
166 Output mul_op = Mul(root.WithOpName("output"), conv_op, mul_values_op);
167
168 Tensor mul_values_data_2(DT_FLOAT, TensorShape({2}));
169 test::FillValues<float>(&mul_values_data_2, {1.0f, 2.0f});
170 Output mul_values_op_2 = Const(root.WithOpName("mul_values_2"),
171 Input::Initializer(mul_values_data));
172
173 Output mul_op_2 =
174 Mul(root.WithOpName("output_2"), conv_op, mul_values_op_2);
175
176 GraphDef original_graph_def;
177 TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
178
179 std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
180 TF_ASSERT_OK(original_session->Create(original_graph_def));
181 std::vector<Tensor> original_outputs;
182 TF_ASSERT_OK(original_session->Run({}, {"output", "output_2"}, {},
183 &original_outputs));
184
185 GraphDef fused_graph_def;
186 TF_ASSERT_OK(FoldBatchNorms(
187 original_graph_def, {{}, {"output", "output_2"}}, &fused_graph_def));
188
189 std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
190 TF_ASSERT_OK(fused_session->Create(fused_graph_def));
191 std::vector<Tensor> fused_outputs;
192 TF_ASSERT_OK(
193 fused_session->Run({}, {"output", "output_2"}, {}, &fused_outputs));
194
195 test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
196 test::ExpectTensorNear<float>(original_outputs[1], fused_outputs[1], 1e-5);
197 }
198
TestFoldBatchNormsMatMul()199 void TestFoldBatchNormsMatMul() {
200 auto root = tensorflow::Scope::NewRootScope();
201 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
202
203 Tensor input_data(DT_FLOAT, TensorShape({6, 2}));
204 test::FillValues<float>(
205 &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
206 -5.0f, -3.0f, -6.0f});
207 Output input_op =
208 Const(root.WithOpName("input_op"), Input::Initializer(input_data));
209
210 Tensor weights_data(DT_FLOAT, TensorShape({2, 2}));
211 test::FillValues<float>(&weights_data, {1.0f, 2.0f, 0.3f, 0.4f});
212 Output weights_op =
213 Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
214
215 Output matmul_op =
216 MatMul(root.WithOpName("matmul_op"), input_op, weights_op);
217
218 Tensor mul_values_data(DT_FLOAT, TensorShape({2}));
219 test::FillValues<float>(&mul_values_data, {2.0f, 3.0f});
220 Output mul_values_op = Const(root.WithOpName("mul_values"),
221 Input::Initializer(mul_values_data));
222
223 Output mul_op = Mul(root.WithOpName("output"), matmul_op, mul_values_op);
224
225 GraphDef original_graph_def;
226 TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
227
228 std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
229 TF_ASSERT_OK(original_session->Create(original_graph_def));
230 std::vector<Tensor> original_outputs;
231 TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
232
233 GraphDef fused_graph_def;
234 TF_ASSERT_OK(
235 FoldBatchNorms(original_graph_def, {{}, {"output"}}, &fused_graph_def));
236
237 std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
238 TF_ASSERT_OK(fused_session->Create(fused_graph_def));
239 std::vector<Tensor> fused_outputs;
240 TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
241
242 test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
243
244 for (const NodeDef& node : fused_graph_def.node()) {
245 EXPECT_NE("Mul", node.op());
246 }
247 }
248 };
249
// Registers each fixture scenario as its own gtest case.
TEST_F(FoldBatchNormsTest, TestFoldBatchNormsConv2D) {
  TestFoldBatchNormsConv2D();
}
TEST_F(FoldBatchNormsTest, TestFoldBatchNormsMatMul) {
  TestFoldBatchNormsMatMul();
}
TEST_F(FoldBatchNormsTest, TestFoldBatchNormsDepthwiseConv2dNative) {
  TestFoldBatchNormsDepthwiseConv2dNative();
}
// A Conv2D whose output feeds two Mul nodes; without this registration the
// fixture method would never be executed.
TEST_F(FoldBatchNormsTest, TestFoldBatchNormsConv2DShared) {
  TestFoldBatchNormsConv2DShared();
}
259
260 } // namespace graph_transforms
261 } // namespace tensorflow
262