1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/cc/ops/const_op.h"
17 #include "tensorflow/cc/ops/image_ops.h"
18 #include "tensorflow/cc/ops/nn_ops.h"
19 #include "tensorflow/cc/ops/array_ops.h"
20 #include "tensorflow/cc/ops/sendrecv_ops.h"
21 #include "tensorflow/cc/ops/standard_ops.h"
22 #include "tensorflow/core/framework/tensor_testutil.h"
23 #include "tensorflow/core/framework/versions.pb.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/platform/test.h"
26 #include "tensorflow/core/platform/test_benchmark.h"
27 #include "tensorflow/core/public/session.h"
28 #include "tensorflow/tools/graph_transforms/transform_utils.h"
29
30 namespace tensorflow {
31 namespace graph_transforms {
32
33 // Declare here, so we don't need a public header.
34 Status FoldOldBatchNorms(const GraphDef& input_graph_def,
35 const TransformFuncContext& context,
36 GraphDef* output_graph_def);
37
38 class FoldOldBatchNormsTest : public ::testing::Test {
39 protected:
TestFoldOldBatchNorms()40 void TestFoldOldBatchNorms() {
41 auto root = tensorflow::Scope::NewRootScope();
42 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
43
44 Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2}));
45 test::FillValues<float>(
46 &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
47 -5.0f, -3.0f, -6.0f});
48 Output input_op =
49 Const(root.WithOpName("input_op"), Input::Initializer(input_data));
50
51 Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
52 test::FillValues<float>(&weights_data,
53 {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
54 Output weights_op =
55 Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
56
57 Output conv_op = Conv2D(root.WithOpName("conv_op"), input_op, weights_op,
58 {1, 1, 1, 1}, "VALID");
59
60 Tensor mean_data(DT_FLOAT, TensorShape({2}));
61 test::FillValues<float>(&mean_data, {10.0f, 20.0f});
62 Output mean_op =
63 Const(root.WithOpName("mean_op"), Input::Initializer(mean_data));
64
65 Tensor variance_data(DT_FLOAT, TensorShape({2}));
66 test::FillValues<float>(&variance_data, {0.25f, 0.5f});
67 Output variance_op = Const(root.WithOpName("variance_op"),
68 Input::Initializer(variance_data));
69
70 Tensor beta_data(DT_FLOAT, TensorShape({2}));
71 test::FillValues<float>(&beta_data, {0.1f, 0.6f});
72 Output beta_op =
73 Const(root.WithOpName("beta_op"), Input::Initializer(beta_data));
74
75 Tensor gamma_data(DT_FLOAT, TensorShape({2}));
76 test::FillValues<float>(&gamma_data, {1.0f, 2.0f});
77 Output gamma_op =
78 Const(root.WithOpName("gamma_op"), Input::Initializer(gamma_data));
79
80 GraphDef original_graph_def;
81 TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
82
83 // This is needed because we're trying to convert over a deprecated op which
84 // should only be present in older GraphDef files. Without this we see a
85 // deprecation error.
86 // This is justified because we're trying to test a tool that is expected to
87 // run on legacy files, to help users convert over to less problematic
88 // versions.
89 NodeDef batch_norm_node;
90 batch_norm_node.set_op("BatchNormWithGlobalNormalization");
91 batch_norm_node.set_name("output");
92 AddNodeInput("conv_op", &batch_norm_node);
93 AddNodeInput("mean_op", &batch_norm_node);
94 AddNodeInput("variance_op", &batch_norm_node);
95 AddNodeInput("beta_op", &batch_norm_node);
96 AddNodeInput("gamma_op", &batch_norm_node);
97 SetNodeAttr("T", DT_FLOAT, &batch_norm_node);
98 SetNodeAttr("variance_epsilon", 0.00001f, &batch_norm_node);
99 SetNodeAttr("scale_after_normalization", false, &batch_norm_node);
100 *(original_graph_def.mutable_node()->Add()) = batch_norm_node;
101 original_graph_def.mutable_versions()->set_producer(8);
102
103 std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
104 TF_ASSERT_OK(original_session->Create(original_graph_def));
105 std::vector<Tensor> original_outputs;
106 TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
107
108 GraphDef fused_graph_def;
109 TF_ASSERT_OK(FoldOldBatchNorms(original_graph_def, {{}, {"output"}},
110 &fused_graph_def));
111
112 std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
113 TF_ASSERT_OK(fused_session->Create(fused_graph_def));
114 std::vector<Tensor> fused_outputs;
115 TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
116
117 test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
118
119 for (const NodeDef& node : fused_graph_def.node()) {
120 EXPECT_NE("BatchNormWithGlobalNormalization", node.op());
121 }
122 }
123
TestFoldOldBatchNormsAfterDepthwiseConv2dNative()124 void TestFoldOldBatchNormsAfterDepthwiseConv2dNative() {
125 auto root = tensorflow::Scope::NewRootScope();
126 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
127
128 Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2}));
129 test::FillValues<float>(
130 &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
131 -5.0f, -3.0f, -6.0f});
132 Output input_op =
133 Const(root.WithOpName("input_op"), Input::Initializer(input_data));
134
135 Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
136 test::FillValues<float>(&weights_data,
137 {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
138 Output weights_op =
139 Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
140
141 Output conv_op = DepthwiseConv2dNative(root.WithOpName("conv_op"), input_op,
142 weights_op, {1, 1, 1, 1}, "VALID");
143
144 Tensor mean_data(DT_FLOAT, TensorShape({4}));
145 test::FillValues<float>(&mean_data, {10.0f, 20.0f, 30.0f, 40.0f});
146 Output mean_op =
147 Const(root.WithOpName("mean_op"), Input::Initializer(mean_data));
148
149 Tensor variance_data(DT_FLOAT, TensorShape({4}));
150 test::FillValues<float>(&variance_data, {0.25f, 0.5f, 0.75f, 1.0f});
151 Output variance_op = Const(root.WithOpName("variance_op"),
152 Input::Initializer(variance_data));
153
154 Tensor beta_data(DT_FLOAT, TensorShape({4}));
155 test::FillValues<float>(&beta_data, {0.1f, 0.6f, 1.1f, 1.6f});
156 Output beta_op =
157 Const(root.WithOpName("beta_op"), Input::Initializer(beta_data));
158
159 Tensor gamma_data(DT_FLOAT, TensorShape({4}));
160 test::FillValues<float>(&gamma_data, {1.0f, 2.0f, 3.0f, 4.0f});
161 Output gamma_op =
162 Const(root.WithOpName("gamma_op"), Input::Initializer(gamma_data));
163
164 GraphDef original_graph_def;
165 TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
166
167 NodeDef batch_norm_node;
168 batch_norm_node.set_op("BatchNormWithGlobalNormalization");
169 batch_norm_node.set_name("output");
170 AddNodeInput("conv_op", &batch_norm_node);
171 AddNodeInput("mean_op", &batch_norm_node);
172 AddNodeInput("variance_op", &batch_norm_node);
173 AddNodeInput("beta_op", &batch_norm_node);
174 AddNodeInput("gamma_op", &batch_norm_node);
175 SetNodeAttr("T", DT_FLOAT, &batch_norm_node);
176 SetNodeAttr("variance_epsilon", 0.00001f, &batch_norm_node);
177 SetNodeAttr("scale_after_normalization", false, &batch_norm_node);
178 *(original_graph_def.mutable_node()->Add()) = batch_norm_node;
179 original_graph_def.mutable_versions()->set_producer(8);
180
181 std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
182 TF_ASSERT_OK(original_session->Create(original_graph_def));
183 std::vector<Tensor> original_outputs;
184 TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
185
186 GraphDef fused_graph_def;
187 TF_ASSERT_OK(FoldOldBatchNorms(original_graph_def, {{}, {"output"}},
188 &fused_graph_def));
189
190 std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
191 TF_ASSERT_OK(fused_session->Create(fused_graph_def));
192 std::vector<Tensor> fused_outputs;
193 TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
194
195 test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
196
197 for (const NodeDef& node : fused_graph_def.node()) {
198 EXPECT_NE("BatchNormWithGlobalNormalization", node.op());
199 }
200 }
201
TestFoldFusedBatchNorms()202 void TestFoldFusedBatchNorms() {
203 auto root = tensorflow::Scope::NewRootScope();
204 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
205
206 Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2}));
207 test::FillValues<float>(
208 &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
209 -5.0f, -3.0f, -6.0f});
210 Output input_op =
211 Const(root.WithOpName("input_op"), Input::Initializer(input_data));
212
213 Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
214 test::FillValues<float>(&weights_data,
215 {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
216 Output weights_op =
217 Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
218
219 Output conv_op = Conv2D(root.WithOpName("conv_op"), input_op, weights_op,
220 {1, 1, 1, 1}, "VALID");
221
222 Tensor mean_data(DT_FLOAT, TensorShape({2}));
223 test::FillValues<float>(&mean_data, {10.0f, 20.0f});
224 Output mean_op =
225 Const(root.WithOpName("mean_op"), Input::Initializer(mean_data));
226
227 Tensor variance_data(DT_FLOAT, TensorShape({2}));
228 test::FillValues<float>(&variance_data, {0.25f, 0.5f});
229 Output variance_op = Const(root.WithOpName("variance_op"),
230 Input::Initializer(variance_data));
231
232 Tensor beta_data(DT_FLOAT, TensorShape({2}));
233 test::FillValues<float>(&beta_data, {0.1f, 0.6f});
234 Output beta_op =
235 Const(root.WithOpName("beta_op"), Input::Initializer(beta_data));
236
237 Tensor gamma_data(DT_FLOAT, TensorShape({2}));
238 test::FillValues<float>(&gamma_data, {1.0f, 2.0f});
239 Output gamma_op =
240 Const(root.WithOpName("gamma_op"), Input::Initializer(gamma_data));
241
242 GraphDef original_graph_def;
243 TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
244
245 NodeDef batch_norm_node;
246 batch_norm_node.set_op("FusedBatchNorm");
247 batch_norm_node.set_name("output");
248 AddNodeInput("conv_op", &batch_norm_node);
249 AddNodeInput("gamma_op", &batch_norm_node);
250 AddNodeInput("beta_op", &batch_norm_node);
251 AddNodeInput("mean_op", &batch_norm_node);
252 AddNodeInput("variance_op", &batch_norm_node);
253 SetNodeAttr("T", DT_FLOAT, &batch_norm_node);
254 SetNodeAttr("epsilon", 0.00001f, &batch_norm_node);
255 SetNodeAttr("is_training", false, &batch_norm_node);
256 *(original_graph_def.mutable_node()->Add()) = batch_norm_node;
257
258 std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
259 TF_ASSERT_OK(original_session->Create(original_graph_def));
260 std::vector<Tensor> original_outputs;
261 TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
262
263 GraphDef fused_graph_def;
264 TF_ASSERT_OK(FoldOldBatchNorms(original_graph_def, {{}, {"output"}},
265 &fused_graph_def));
266
267 std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
268 TF_ASSERT_OK(fused_session->Create(fused_graph_def));
269 std::vector<Tensor> fused_outputs;
270 TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
271
272 test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 2e-5);
273
274 for (const NodeDef& node : fused_graph_def.node()) {
275 EXPECT_NE("FusedBatchNorm", node.op());
276 }
277 }
278
TestFoldFusedBatchNormsAfterDepthwiseConv2dNative()279 void TestFoldFusedBatchNormsAfterDepthwiseConv2dNative() {
280 auto root = tensorflow::Scope::NewRootScope();
281 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
282
283 Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2}));
284 test::FillValues<float>(
285 &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
286 -5.0f, -3.0f, -6.0f});
287 Output input_op =
288 Const(root.WithOpName("input_op"), Input::Initializer(input_data));
289
290 Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
291 test::FillValues<float>(&weights_data,
292 {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
293 Output weights_op =
294 Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
295
296 Output conv_op = DepthwiseConv2dNative(root.WithOpName("conv_op"), input_op,
297 weights_op, {1, 1, 1, 1}, "VALID");
298
299 Tensor mean_data(DT_FLOAT, TensorShape({4}));
300 test::FillValues<float>(&mean_data, {10.0f, 20.0f, 30.0f, 40.0f});
301 Output mean_op =
302 Const(root.WithOpName("mean_op"), Input::Initializer(mean_data));
303
304 Tensor variance_data(DT_FLOAT, TensorShape({4}));
305 test::FillValues<float>(&variance_data, {0.25f, 0.5f, 0.75f, 1.0f});
306 Output variance_op = Const(root.WithOpName("variance_op"),
307 Input::Initializer(variance_data));
308
309 Tensor beta_data(DT_FLOAT, TensorShape({4}));
310 test::FillValues<float>(&beta_data, {0.1f, 0.6f, 1.1f, 1.6f});
311 Output beta_op =
312 Const(root.WithOpName("beta_op"), Input::Initializer(beta_data));
313
314 Tensor gamma_data(DT_FLOAT, TensorShape({4}));
315 test::FillValues<float>(&gamma_data, {1.0f, 2.0f, 3.0f, 4.0f});
316 Output gamma_op =
317 Const(root.WithOpName("gamma_op"), Input::Initializer(gamma_data));
318
319 GraphDef original_graph_def;
320 TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
321
322 NodeDef batch_norm_node;
323 batch_norm_node.set_op("FusedBatchNorm");
324 batch_norm_node.set_name("output");
325 AddNodeInput("conv_op", &batch_norm_node);
326 AddNodeInput("gamma_op", &batch_norm_node);
327 AddNodeInput("beta_op", &batch_norm_node);
328 AddNodeInput("mean_op", &batch_norm_node);
329 AddNodeInput("variance_op", &batch_norm_node);
330 SetNodeAttr("T", DT_FLOAT, &batch_norm_node);
331 SetNodeAttr("epsilon", 0.00001f, &batch_norm_node);
332 SetNodeAttr("is_training", false, &batch_norm_node);
333 *(original_graph_def.mutable_node()->Add()) = batch_norm_node;
334
335 std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
336 TF_ASSERT_OK(original_session->Create(original_graph_def));
337 std::vector<Tensor> original_outputs;
338 TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
339
340 GraphDef fused_graph_def;
341 TF_ASSERT_OK(FoldOldBatchNorms(original_graph_def, {{}, {"output"}},
342 &fused_graph_def));
343
344 std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
345 TF_ASSERT_OK(fused_session->Create(fused_graph_def));
346 std::vector<Tensor> fused_outputs;
347 TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
348
349 test::ExpectClose(original_outputs[0], fused_outputs[0], /*atol=*/2e-5,
350 /*rtol=*/2e-5);
351
352 for (const NodeDef& node : fused_graph_def.node()) {
353 EXPECT_NE("FusedBatchNorm", node.op());
354 }
355 }
356
TestFoldFusedBatchNormsWithConcat(const bool split)357 void TestFoldFusedBatchNormsWithConcat(const bool split) {
358 auto root = tensorflow::Scope::NewRootScope();
359 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
360
361 // If split is true, concat two inputs at dim=3; otherwise, concat at dim 2.
362 auto input_shape =
363 split ? TensorShape({1, 1, 6, 2}) : TensorShape({1, 1, 12, 1});
364 Tensor input_data(DT_FLOAT, input_shape);
365 test::FillValues<float>(
366 &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
367 -5.0f, -3.0f, -6.0f});
368
369 Output input0_op =
370 Const(root.WithOpName("input_op0"), Input::Initializer(input_data));
371 // If split is true, concat two inputs at dim=3; otherwise, concat at dim 2.
372 // The final output shape of concat is always {1, 2, 2, 2}.
373 auto weight_shape =
374 split ? TensorShape({1, 2, 2, 1}) : TensorShape({1, 2, 1, 2});
375 Tensor weights0_data(DT_FLOAT, weight_shape);
376 test::FillValues<float>(&weights0_data, {1.0f, 2.0f, 3.0f, 4.0f});
377 Output weights0_op = Const(root.WithOpName("weights1_op"),
378 Input::Initializer(weights0_data));
379 Output conv0_op = Conv2D(root.WithOpName("conv1_op"), input0_op,
380 weights0_op, {1, 1, 1, 1}, "VALID");
381
382 Output input1_op =
383 Const(root.WithOpName("input1_op"), Input::Initializer(input_data));
384 Tensor weights1_data(DT_FLOAT, weight_shape);
385 test::FillValues<float>(&weights1_data, {1.0f, 2.0f, 3.0f, 4.0f});
386 Output weights1_op = Const(root.WithOpName("weights1_op"),
387 Input::Initializer(weights1_data));
388 Output conv1_op = Conv2D(root.WithOpName("conv1_op"), input1_op,
389 weights1_op, {1, 1, 1, 1}, "VALID");
390
391 Tensor shape_tensor(DT_INT32, TensorShape({}));
392 // Concat at dim 3 if split; otherwise, concat at dim 2.
393 int32_t concat_axis = split ? 3 : 2;
394 test::FillValues<int32>(&shape_tensor, {concat_axis});
395 Output shape_op =
396 Const(root.WithOpName("shape_op"), Input::Initializer(shape_tensor));
397 Output concat_op =
398 Concat(root.WithOpName("concat_op"), {conv0_op, conv1_op}, shape_op);
399
400 Tensor mean_data(DT_FLOAT, TensorShape({2}));
401 test::FillValues<float>(&mean_data, {10.0f, 20.0f});
402 Output mean_op =
403 Const(root.WithOpName("mean_op"), Input::Initializer(mean_data));
404
405 Tensor variance_data(DT_FLOAT, TensorShape({2}));
406 test::FillValues<float>(&variance_data, {0.25f, 0.5f});
407 Output variance_op = Const(root.WithOpName("variance_op"),
408 Input::Initializer(variance_data));
409
410 Tensor beta_data(DT_FLOAT, TensorShape({2}));
411 test::FillValues<float>(&beta_data, {0.1f, 0.6f});
412 Output beta_op =
413 Const(root.WithOpName("beta_op"), Input::Initializer(beta_data));
414
415 Tensor gamma_data(DT_FLOAT, TensorShape({2}));
416 test::FillValues<float>(&gamma_data, {1.0f, 2.0f});
417 Output gamma_op =
418 Const(root.WithOpName("gamma_op"), Input::Initializer(gamma_data));
419
420 GraphDef original_graph_def;
421 TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
422
423 NodeDef batch_norm_node;
424 batch_norm_node.set_op("FusedBatchNorm");
425 batch_norm_node.set_name("output");
426 AddNodeInput("concat_op", &batch_norm_node);
427 AddNodeInput("gamma_op", &batch_norm_node);
428 AddNodeInput("beta_op", &batch_norm_node);
429 AddNodeInput("mean_op", &batch_norm_node);
430 AddNodeInput("variance_op", &batch_norm_node);
431 SetNodeAttr("T", DT_FLOAT, &batch_norm_node);
432 SetNodeAttr("epsilon", 0.00001f, &batch_norm_node);
433 SetNodeAttr("is_training", false, &batch_norm_node);
434 *(original_graph_def.mutable_node()->Add()) = batch_norm_node;
435
436 std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
437 TF_ASSERT_OK(original_session->Create(original_graph_def));
438 std::vector<Tensor> original_outputs;
439 TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
440
441 GraphDef fused_graph_def;
442 TF_ASSERT_OK(FoldOldBatchNorms(original_graph_def, {{}, {"output"}},
443 &fused_graph_def));
444
445 std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
446 TF_ASSERT_OK(fused_session->Create(fused_graph_def));
447 std::vector<Tensor> fused_outputs;
448 TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
449
450 test::ExpectClose(original_outputs[0], fused_outputs[0]);
451
452 for (const NodeDef& node : fused_graph_def.node()) {
453 EXPECT_NE("FusedBatchNorm", node.op());
454 }
455 }
456 };
457
TestFoldFusedBatchNormsWithBatchToSpace()458 void TestFoldFusedBatchNormsWithBatchToSpace() {
459 auto root = tensorflow::Scope::NewRootScope();
460 using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
461
462 Tensor input_data(DT_FLOAT, TensorShape({2, 1, 3, 2}));
463 test::FillValues<float>(
464 &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
465 -5.0f, -3.0f, -6.0f});
466 Output input_op =
467 Const(root.WithOpName("input_op"), Input::Initializer(input_data));
468
469 Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
470 test::FillValues<float>(&weights_data,
471 {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
472 Output weights_op =
473 Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
474
475 Output conv_op = Conv2D(root.WithOpName("conv_op"), input_op, weights_op,
476 {1, 1, 1, 1}, "VALID");
477
478 Tensor block_shape_data(DT_INT32, TensorShape({2}));
479 test::FillValues<int32>(&block_shape_data, {1, 2});
480 Output block_shape_op = Const(root.WithOpName("block_shape_op"),
481 Input::Initializer(block_shape_data));
482
483 Tensor crops_data(DT_INT32, TensorShape({2, 2}));
484 test::FillValues<int32>(&crops_data, {0, 0, 0, 1});
485 Output crops_op =
486 Const(root.WithOpName("crops_op"), Input::Initializer(crops_data));
487
488 Output batch_to_space_op =
489 BatchToSpaceND(root.WithOpName("batch_to_space_op"), conv_op,
490 block_shape_op, crops_data);
491
492 Tensor mean_data(DT_FLOAT, TensorShape({2}));
493 test::FillValues<float>(&mean_data, {10.0f, 20.0f});
494 Output mean_op =
495 Const(root.WithOpName("mean_op"), Input::Initializer(mean_data));
496
497 Tensor variance_data(DT_FLOAT, TensorShape({2}));
498 test::FillValues<float>(&variance_data, {0.25f, 0.5f});
499 Output variance_op =
500 Const(root.WithOpName("variance_op"), Input::Initializer(variance_data));
501
502 Tensor beta_data(DT_FLOAT, TensorShape({2}));
503 test::FillValues<float>(&beta_data, {0.1f, 0.6f});
504 Output beta_op =
505 Const(root.WithOpName("beta_op"), Input::Initializer(beta_data));
506
507 Tensor gamma_data(DT_FLOAT, TensorShape({2}));
508 test::FillValues<float>(&gamma_data, {1.0f, 2.0f});
509 Output gamma_op =
510 Const(root.WithOpName("gamma_op"), Input::Initializer(gamma_data));
511
512 GraphDef original_graph_def;
513 TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
514
515 NodeDef batch_norm_node;
516 batch_norm_node.set_op("FusedBatchNorm");
517 batch_norm_node.set_name("output");
518 AddNodeInput("batch_to_space_op", &batch_norm_node);
519 AddNodeInput("gamma_op", &batch_norm_node);
520 AddNodeInput("beta_op", &batch_norm_node);
521 AddNodeInput("mean_op", &batch_norm_node);
522 AddNodeInput("variance_op", &batch_norm_node);
523 SetNodeAttr("T", DT_FLOAT, &batch_norm_node);
524 SetNodeAttr("epsilon", 0.00001f, &batch_norm_node);
525 SetNodeAttr("is_training", false, &batch_norm_node);
526 *(original_graph_def.mutable_node()->Add()) = batch_norm_node;
527
528 std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
529 TF_ASSERT_OK(original_session->Create(original_graph_def));
530 std::vector<Tensor> original_outputs;
531 TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
532
533 GraphDef fused_graph_def;
534 TF_ASSERT_OK(FoldOldBatchNorms(original_graph_def, {{}, {"output"}},
535 &fused_graph_def));
536
537 std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
538 TF_ASSERT_OK(fused_session->Create(fused_graph_def));
539 std::vector<Tensor> fused_outputs;
540 TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
541
542 test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
543
544 for (const NodeDef& node : fused_graph_def.node()) {
545 EXPECT_NE("FusedBatchNormWithBatchToSpace", node.op());
546 }
547 }
548
// Folding of the deprecated BatchNormWithGlobalNormalization after Conv2D.
TEST_F(FoldOldBatchNormsTest, TestFoldOldBatchNorms) {
  TestFoldOldBatchNorms();
}
552
// Folding of the modern FusedBatchNorm op after Conv2D.
TEST_F(FoldOldBatchNormsTest, TestFoldFusedBatchNorms) {
  TestFoldFusedBatchNorms();
}
556
TEST_F(FoldOldBatchNormsTest, TestFoldFusedBatchNormsWithConcat) {
  // split=true concatenates on axis 3 (channels), so the BatchNorm weights
  // and offsets must be split before being fused with each conv2d's weights.
  // (The previous comments here were swapped relative to the helper's
  // documented split semantics.)
  TestFoldFusedBatchNormsWithConcat(/*split=*/true);
  // split=false concatenates on axis 2 (not the channel axis), so the whole
  // weight and offset vectors are fused into each of the conv2d inputs.
  TestFoldFusedBatchNormsWithConcat(/*split=*/false);
}
565
// Folding when a BatchToSpaceND sits between the conv and the batch norm.
TEST_F(FoldOldBatchNormsTest, TestFoldFusedBatchNormsWithBatchToSpace) {
  TestFoldFusedBatchNormsWithBatchToSpace();
}
569
// Deprecated batch norm folded across a DepthwiseConv2dNative.
TEST_F(FoldOldBatchNormsTest, TestFoldOldBatchNormsAfterDepthwiseConv2dNative) {
  TestFoldOldBatchNormsAfterDepthwiseConv2dNative();
}
573
// FusedBatchNorm folded across a DepthwiseConv2dNative.
TEST_F(FoldOldBatchNormsTest,
       TestFoldFusedBatchNormsAfterDepthwiseConv2dNative) {
  TestFoldFusedBatchNormsAfterDepthwiseConv2dNative();
}
578
579 } // namespace graph_transforms
580 } // namespace tensorflow
581