/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef INTEL_MKL

#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/mkl_graph_util.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

// Helper class for converting MKL tensors to TF tensors and comparing to
// expected values.

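// MKL layout-dependent ops expect an extra uint8 metadata tensor alongside
// each data input; for inputs that are already in plain TF layout, a small
// dummy tensor stands in for that metadata.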
static const uint8 dummy_tensor[] = {0, 0, 0, 0, 0, 0, 0, 0};
static const TensorShape dummy_shape({8});

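// Runs a forward FusedBatchNorm graph (reference TF kernel or MKL kernel) on
// the given inputs and returns y, batch_mean, and batch_var.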
using GraphRunner = std::function<void(
    const Tensor& input, const Tensor& scale, const Tensor& offset,
    const Tensor& mean, const Tensor& variance,
    const float exponential_avg_factor, const bool is_training, Tensor* output,
    Tensor* batch_mean, Tensor* batch_var)>;

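// Runs a Conv2D followed by FusedBatchNormGradV3 (reference or MKL) and
// returns x_backprop, scale_backprop, and offset_backprop.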
using GraphRunnerGrad = std::function<void(
    const Tensor& input, const Tensor& filter, const Tensor& y_backprop,
    const Tensor& scale, const Tensor& mean, const Tensor& variance,
    const Tensor& res_sp3, Tensor* output, Tensor* scale_backprop,
    Tensor* offset_backprop, const float epsilon)>;

template <typename T>
class CommonTestUtilities : public OpsTestBase {
 public:
  void PerformConversion(DataType dtype, const Tensor& tensor,
                         const Tensor& mkl_meta_tensor, Tensor* output) {
    // Create an MKL-to-TF conversion node and execute it.
    TF_EXPECT_OK(NodeDefBuilder("mkl_to_tf_op", "_MklToTf")
                     .Input(FakeInput(dtype))     // Input
                     .Input(FakeInput(DT_UINT8))  // MKL second tensor
                     .Attr("T", dtype)
                     .Attr("_kernel", "MklLayoutDependentOp")
                     .Finalize(node_def()));
    TF_EXPECT_OK(InitOp());
    AddInputFromArray<T>(tensor.shape(), tensor.flat<T>());
    AddInputFromArray<uint8>(mkl_meta_tensor.shape(),
                             mkl_meta_tensor.flat<uint8>());
    TF_ASSERT_OK(RunOpKernel());

    *output = *GetOutput(0);
  }

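  // OpsTestBase inherits a pure-virtual TestBody() from ::testing::Test; an
  // empty override lets this utility be instantiated directly.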
  void TestBody() {}

  static void VerifyTensorsClose(const float exponential_avg_factor,
                                 const bool is_training, const GraphRunner& run,
                                 const GraphRunner& run_mkl) {
    int batch = 1;
    int height = 10;
    int width = 10;
    int depth = 3;
    DataType dtype = DataTypeToEnum<T>::v();

    Tensor input(dtype, {batch, height, width, depth});
    input.flat<T>() = input.flat<T>().template setRandom<random_gen_>();
    Tensor scale(dtype, {depth});
    scale.flat<T>() = scale.flat<T>().template setRandom<random_gen_>();
    Tensor offset(dtype, {depth});
    offset.flat<T>() = offset.flat<T>().template setRandom<random_gen_>();

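    // When training with exponential_avg_factor == 1.0 the op does not read
    // the incoming running mean/variance, so zero-element tensors are passed.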
    if (is_training && (exponential_avg_factor == 1.0)) {
      depth = 0;
    }
    Tensor mean(dtype, {depth});
    mean.flat<T>() = mean.flat<T>().template setRandom<random_gen_>();
    Tensor variance(dtype, {depth});
    variance.flat<T>() =
        variance.flat<T>().template setRandom<random_gen_>().abs();

    Tensor output;
    Tensor batch_mean;
    Tensor batch_var;
    Tensor mkl_output;
    Tensor mkl_batch_mean;
    Tensor mkl_batch_var;

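    // Run the reference and MKL implementations on identical inputs, then
    // require matching dtypes/shapes and numerically close values.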
    run(input, scale, offset, mean, variance, exponential_avg_factor,
        is_training, &output, &batch_mean, &batch_var);
    run_mkl(input, scale, offset, mean, variance, exponential_avg_factor,
            is_training, &mkl_output, &mkl_batch_mean, &mkl_batch_var);

    ASSERT_EQ(output.dtype(), mkl_output.dtype());
    ASSERT_EQ(output.shape(), mkl_output.shape());
    ASSERT_EQ(batch_mean.dtype(), mkl_batch_mean.dtype());
    ASSERT_EQ(batch_mean.shape(), mkl_batch_mean.shape());
    ASSERT_EQ(batch_var.dtype(), mkl_batch_var.dtype());
    ASSERT_EQ(batch_var.shape(), mkl_batch_var.shape());

    test::ExpectClose(output, mkl_output, 1e-5);
    test::ExpectClose(batch_mean, mkl_batch_mean, 1e-5);
    test::ExpectClose(batch_var, mkl_batch_var, 1e-5);
  }

  static void VerifyTensorsCloseForGrad(const float epsilon,
                                        const GraphRunnerGrad& run,
                                        const GraphRunnerGrad& run_mkl) {
    int batch = 2;
    int height = 8;
    int width = 8;
    int depth = 1;
    int filter_height = 3;
    int filter_width = 3;
    int in_channels = 1;
    int out_channels = 6;
    DataType dtype = DataTypeToEnum<T>::v();

    Tensor input(dtype, {batch, height, width, depth});
    input.flat<T>() = input.flat<T>().template setRandom<random_gen_>();
    Tensor filter(dtype,
                  {filter_height, filter_width, in_channels, out_channels});
    filter.flat<T>() = filter.flat<T>().template setRandom<random_gen_>();

    Tensor y_backprop(dtype, {batch, height, width, out_channels});
    y_backprop.flat<T>() =
        y_backprop.flat<T>().template setRandom<random_gen_>();
    Tensor scale(dtype, {out_channels});
    scale.flat<T>() = scale.flat<T>().template setRandom<random_gen_>();
    Tensor mean(dtype, {out_channels});
    mean.flat<T>() = mean.flat<T>().template setRandom<random_gen_>();
    Tensor variance(dtype, {out_channels});
    variance.flat<T>() =
        variance.flat<T>().template setRandom<random_gen_>().abs();
    Tensor res_sp3(dtype, {out_channels});
    res_sp3.flat<T>() =
        res_sp3.flat<T>().template setRandom<random_gen_>().abs();

    Tensor output;
    Tensor scale_backprop;
    Tensor offset_backprop;
    Tensor mkl_output;
    Tensor mkl_scale_backprop;
    Tensor mkl_offset_backprop;

    run(input, filter, y_backprop, scale, mean, variance, res_sp3, &output,
        &scale_backprop, &offset_backprop, epsilon);

    run_mkl(input, filter, y_backprop, scale, mean, variance, res_sp3,
            &mkl_output, &mkl_scale_backprop, &mkl_offset_backprop, epsilon);

    ASSERT_EQ(output.dtype(), mkl_output.dtype());
    ASSERT_EQ(output.shape(), mkl_output.shape());
    ASSERT_EQ(scale_backprop.dtype(), mkl_scale_backprop.dtype());
    ASSERT_EQ(scale_backprop.shape(), mkl_scale_backprop.shape());
    ASSERT_EQ(offset_backprop.dtype(), mkl_offset_backprop.dtype());
    ASSERT_EQ(offset_backprop.shape(), mkl_offset_backprop.shape());

    test::ExpectClose(output, mkl_output, 1e-5);
    test::ExpectClose(scale_backprop, mkl_scale_backprop, 1e-5);
    test::ExpectClose(offset_backprop, mkl_offset_backprop, 1e-5);
  }

 private:
  using random_gen_ = Eigen::internal::NormalRandomGenerator<T>;
};

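// Builds and runs an _MklConv2D node so that its blocked-format output (and
// the accompanying layout metadata tensor) can feed the MKL batch-norm
// gradient test below. Instantiated directly, hence the empty TestBody().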
template <typename T>
class Conv2DOpTest : public OpsTestBase {
  void TestBody() {}

 public:
  void RunConv2D(const Tensor& input, const Tensor& filter, Tensor* output,
                 Tensor* meta_output) {
    DataType dtype = DataTypeToEnum<T>::v();

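    // _MklConv2D takes an extra uint8 layout-metadata input per data input;
    // dummy tensors are fed for the TF-layout input and filter.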
    TF_EXPECT_OK(NodeDefBuilder("MklConv2D", "_MklConv2D")
                     .Input(FakeInput(dtype))
                     .Input(FakeInput(dtype))
                     .Input(FakeInput(DT_UINT8))
                     .Input(FakeInput(DT_UINT8))
                     .Attr("strides", {1, 1, 1, 1})
                     .Attr("padding", "SAME")
                     .Attr("data_format", "NHWC")
                     .Attr("_kernel", "MklLayoutDependentOp")
                     .Finalize(node_def()));
    TF_EXPECT_OK(InitOp());
    AddInputFromArray<T>(input.shape(), input.flat<T>());
    AddInputFromArray<T>(filter.shape(), filter.flat<T>());
    for (int i = 0; i < 2; ++i)
      AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
    TF_ASSERT_OK(RunOpKernel());

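    // Output 0 is the convolution result; output 2 carries its MKL layout
    // metadata, which the gradient test passes along with the data tensor.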
    *output = *GetOutput(0);
    *meta_output = *GetOutput(2);
  }
};

template <typename T>
class FusedBatchNormOpTest : public OpsTestBase {
 protected:
  void VerifyFusedBatchNorm(const float exponential_avg_factor,
                            const bool is_training) {
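    // Reference implementation: build a plain FusedBatchNorm graph and run it
    // in a Session.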
    const GraphRunner run = [this](const Tensor& input, const Tensor& scale,
                                   const Tensor& offset, const Tensor& mean,
                                   const Tensor& variance,
                                   const float exponential_avg_factor,
                                   const bool is_training, Tensor* output,
                                   Tensor* batch_mean, Tensor* batch_var) {
      auto root = tensorflow::Scope::NewRootScope();
      auto input_op =
          ops::Const(root.WithOpName("input"), Input::Initializer(input));
      auto scale_op =
          ops::Const(root.WithOpName("scale"), Input::Initializer(scale));
      auto offset_op =
          ops::Const(root.WithOpName("offset"), Input::Initializer(offset));
      auto mean_op =
          ops::Const(root.WithOpName("mean"), Input::Initializer(mean));
      auto var_op =
          ops::Const(root.WithOpName("variance"), Input::Initializer(variance));

      ops::FusedBatchNorm::Attrs attr;
      attr = attr.IsTraining(is_training);
      attr = attr.ExponentialAvgFactor(exponential_avg_factor);
      attr = attr.Epsilon(0.001);
      auto bn = ops::FusedBatchNorm(root.WithOpName("FusedBatchNorm"), input_op,
                                    scale_op, offset_op, mean_op, var_op, attr);
      auto y = ops::Identity(root.WithOpName("y"), bn.y);
      auto y_batch_mean =
          ops::Identity(root.WithOpName("y_batch_mean"), bn.batch_mean);
      auto y_batch_var =
          ops::Identity(root.WithOpName("y_batch_var"), bn.batch_variance);

      tensorflow::GraphDef graph;
      TF_ASSERT_OK(root.ToGraphDef(&graph));

      std::unique_ptr<tensorflow::Session> session(
          tensorflow::NewSession(tensorflow::SessionOptions()));
      TF_ASSERT_OK(session->Create(graph));

      std::vector<Tensor> output_tensors;
      TF_ASSERT_OK(session->Run({}, {"y", "y_batch_mean", "y_batch_var"}, {},
                                &output_tensors));

      *output = output_tensors[0];
      *batch_mean = output_tensors[1];
      *batch_var = output_tensors[2];
    };

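    // MKL implementation: depending on whether native-format kernels are
    // enabled, exercise either _MklFusedBatchNorm (blocked layout, with uint8
    // metadata inputs) or _MklNativeFusedBatchNorm (plain TF layout).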
    const GraphRunner run_mkl = [this](const Tensor& input, const Tensor& scale,
                                       const Tensor& offset, const Tensor& mean,
                                       const Tensor& variance,
                                       const float exponential_avg_factor,
                                       const bool is_training, Tensor* output,
                                       Tensor* batch_mean, Tensor* batch_var) {
      DataType dtype = DataTypeToEnum<T>::v();
      if (!NativeFormatEnabled()) {
        TF_EXPECT_OK(NodeDefBuilder("MklFusedBatchNorm", "_MklFusedBatchNorm")
                         .Input(FakeInput(dtype))
                         .Input(FakeInput(DT_FLOAT))
                         .Input(FakeInput(DT_FLOAT))
                         .Input(FakeInput(DT_FLOAT))
                         .Input(FakeInput(DT_FLOAT))
                         .Input(FakeInput(DT_UINT8))
                         .Input(FakeInput(DT_UINT8))
                         .Input(FakeInput(DT_UINT8))
                         .Input(FakeInput(DT_UINT8))
                         .Input(FakeInput(DT_UINT8))
                         .Attr("exponential_avg_factor", exponential_avg_factor)
                         .Attr("epsilon", 0.001)
                         .Attr("is_training", is_training)
                         .Attr("_kernel", "MklLayoutDependentOp")
                         .Finalize(node_def()));
      } else {
        TF_EXPECT_OK(NodeDefBuilder("MklNativeFusedBatchNorm",
                                    "_MklNativeFusedBatchNorm")
                         .Input(FakeInput(dtype))
                         .Input(FakeInput(DT_FLOAT))
                         .Input(FakeInput(DT_FLOAT))
                         .Input(FakeInput(DT_FLOAT))
                         .Input(FakeInput(DT_FLOAT))
                         .Attr("exponential_avg_factor", exponential_avg_factor)
                         .Attr("epsilon", 0.001)
                         .Attr("is_training", is_training)
                         .Attr("_kernel", "MklNameChangeOp")
                         .Finalize(node_def()));
      }
      TF_EXPECT_OK(InitOp());

      AddInputFromArray<T>(input.shape(), input.flat<T>());
      AddInputFromArray<float>(scale.shape(), scale.flat<float>());
      AddInputFromArray<float>(offset.shape(), offset.flat<float>());
      AddInputFromArray<float>(mean.shape(), mean.flat<float>());
      AddInputFromArray<float>(variance.shape(), variance.flat<float>());
      if (!NativeFormatEnabled()) {
        for (int i = 0; i < 5; ++i)
          AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
      }
      TF_ASSERT_OK(RunOpKernel());

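      // In blocked format, outputs 5-7 hold the MKL layout metadata for
      // outputs 0-2; convert each pair back to a plain TF tensor before
      // comparing.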
      if (!NativeFormatEnabled()) {
        CommonTestUtilities<T> test_util;
        test_util.PerformConversion(dtype, *GetOutput(0), *GetOutput(5),
                                    output);

        CommonTestUtilities<T> test_util_mean;
        test_util_mean.PerformConversion(dtype, *GetOutput(1), *GetOutput(6),
                                         batch_mean);

        CommonTestUtilities<T> test_util_var;
        test_util_var.PerformConversion(dtype, *GetOutput(2), *GetOutput(7),
                                        batch_var);
      } else {
        *output = *GetOutput(0);
        *batch_mean = *GetOutput(1);
        *batch_var = *GetOutput(2);
      }
    };

    CommonTestUtilities<T>::VerifyTensorsClose(exponential_avg_factor,
                                               is_training, run, run_mkl);
  }

  void VerifyFusedBatchNormGradWithConv2D(const float epsilon) {
#ifdef ENABLE_MKL
    // This test only runs with MKL blocked format.
    const GraphRunnerGrad run =
        [this](const Tensor& input, const Tensor& filter,
               const Tensor& y_backprop, const Tensor& scale,
               const Tensor& mean, const Tensor& variance,
               const Tensor& res_sp3, Tensor* x_backprop_tensor,
               Tensor* scale_backprop_tensor, Tensor* offset_backprop_tensor,
               const float epsilon) {
          auto root = tensorflow::Scope::NewRootScope();

          auto input_op =
              ops::Const(root.WithOpName("input"), Input::Initializer(input));
          auto filter_op =
              ops::Const(root.WithOpName("filter"), Input::Initializer(filter));
          ops::Conv2D::Attrs conv_attr;
          conv_attr = conv_attr.DataFormat("NHWC");
          auto conv = ops::Conv2D(root.WithOpName("Conv"), input_op, filter_op,
                                  {1, 1, 1, 1}, "SAME", conv_attr);
          // -------------------------------------------------------------
          auto y_backprop_op = ops::Const(root.WithOpName("y_backprop"),
                                          Input::Initializer(y_backprop));
          auto scale_op =
              ops::Const(root.WithOpName("scale"), Input::Initializer(scale));
          auto mean_op =
              ops::Const(root.WithOpName("mean"), Input::Initializer(mean));
          auto var_op = ops::Const(root.WithOpName("variance"),
                                   Input::Initializer(variance));
          auto res_sp3_op = ops::Const(root.WithOpName("reserve_space_3"),
                                       Input::Initializer(res_sp3));
          ops::FusedBatchNormGradV3::Attrs bn_attr;
          bn_attr = bn_attr.IsTraining(true);
          bn_attr = bn_attr.Epsilon(epsilon);
          bn_attr = bn_attr.DataFormat("NHWC");
          auto bn = ops::FusedBatchNormGradV3(
              root.WithOpName("FusedBatchNormGrad"), y_backprop_op, conv,
              scale_op, mean_op, var_op, res_sp3_op, bn_attr);

          auto x_backprop =
              ops::Identity(root.WithOpName("x_backprop"), bn.x_backprop);
          auto scale_backprop = ops::Identity(root.WithOpName("scale_backprop"),
                                              bn.scale_backprop);
          auto offset_backprop = ops::Identity(
              root.WithOpName("offset_backprop"), bn.offset_backprop);

          tensorflow::GraphDef graph;
          TF_ASSERT_OK(root.ToGraphDef(&graph));

          tensorflow::SessionOptions session_options;
          std::unique_ptr<tensorflow::Session> session(
              tensorflow::NewSession(session_options));
          TF_ASSERT_OK(session->Create(graph));

          std::vector<Tensor> output_tensors;
          TF_ASSERT_OK(session->Run(
              {}, {"x_backprop", "scale_backprop", "offset_backprop"}, {},
              &output_tensors));

          *x_backprop_tensor = output_tensors[0];
          *scale_backprop_tensor = output_tensors[1];
          *offset_backprop_tensor = output_tensors[2];
        };

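    // MKL path: first run _MklConv2D to obtain a blocked-format activation
    // tensor plus its layout metadata, then feed both into
    // _MklFusedBatchNormGradV3.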
    const GraphRunnerGrad run_mkl =
        [this](const Tensor& input, const Tensor& filter,
               const Tensor& y_backprop, const Tensor& scale,
               const Tensor& mean, const Tensor& variance,
               const Tensor& res_sp3, Tensor* x_backprop_tensor,
               Tensor* scale_backprop_tensor, Tensor* offset_backprop_tensor,
               const float epsilon) {
          Tensor conv2d_output, conv2d_meta_output;
          Conv2DOpTest<T> conv2d_test;
          conv2d_test.RunConv2D(input, filter, &conv2d_output,
                                &conv2d_meta_output);

          DataType dtype = DataTypeToEnum<T>::v();
          TF_EXPECT_OK(
              NodeDefBuilder("MklFusedBatchNorm", "_MklFusedBatchNormGradV3")
                  .Input(FakeInput(dtype))
                  .Input(FakeInput(dtype))
                  .Input(FakeInput(DT_FLOAT))
                  .Input(FakeInput(DT_FLOAT))
                  .Input(FakeInput(DT_FLOAT))
                  .Input(FakeInput(DT_FLOAT))
                  .Input(FakeInput(DT_UINT8))
                  .Input(FakeInput(DT_UINT8))
                  .Input(FakeInput(DT_UINT8))
                  .Input(FakeInput(DT_UINT8))
                  .Input(FakeInput(DT_UINT8))
                  .Input(FakeInput(DT_UINT8))
                  .Attr("epsilon", epsilon)
                  .Attr("is_training", true)
                  .Attr("data_format", "NHWC")
                  .Attr("_kernel", "MklLayoutDependentOp")
                  .Finalize(node_def()));
          TF_EXPECT_OK(InitOp());

          AddInputFromArray<T>(y_backprop.shape(), y_backprop.flat<T>());
          AddInputFromArray<T>(conv2d_output.shape(), conv2d_output.flat<T>());
          AddInputFromArray<float>(scale.shape(), scale.flat<float>());
          AddInputFromArray<float>(mean.shape(), mean.flat<float>());
          AddInputFromArray<float>(variance.shape(), variance.flat<float>());
          AddInputFromArray<float>(res_sp3.shape(), res_sp3.flat<float>());
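          // Layout-metadata inputs: only the conv output (input 1) carries
          // real MKL metadata; the remaining inputs get dummy tensors.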
          AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
          AddInputFromArray<uint8>(conv2d_meta_output.shape(),
                                   conv2d_meta_output.flat<uint8>());
          AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
          AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
          AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
          AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
          TF_ASSERT_OK(RunOpKernel());

          CommonTestUtilities<T> test_util;
          test_util.PerformConversion(dtype, *GetOutput(0), *GetOutput(5),
                                      x_backprop_tensor);

          CommonTestUtilities<T> test_util_mean;
          test_util_mean.PerformConversion(dtype, *GetOutput(1), *GetOutput(6),
                                           scale_backprop_tensor);

          CommonTestUtilities<T> test_util_var;
          test_util_var.PerformConversion(dtype, *GetOutput(2), *GetOutput(7),
                                          offset_backprop_tensor);
        };

    CommonTestUtilities<T>::VerifyTensorsCloseForGrad(epsilon, run, run_mkl);
#endif  // ENABLE_MKL
  }
};

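// Forward tests cover the four combinations of is_training and
// exponential_avg_factor (1.0 vs. a running-mean update factor of 0.5); the
// gradient test checks FusedBatchNormGradV3 against the reference kernel.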
TYPED_TEST_SUITE_P(FusedBatchNormOpTest);

TYPED_TEST_P(FusedBatchNormOpTest, Training) {
  const float exponential_avg_factor = 1.0;
  const bool is_training = true;
  this->VerifyFusedBatchNorm(exponential_avg_factor, is_training);
}

TYPED_TEST_P(FusedBatchNormOpTest, TrainingRunningMean) {
  const float exponential_avg_factor = 0.5;
  const bool is_training = true;
  this->VerifyFusedBatchNorm(exponential_avg_factor, is_training);
}

TYPED_TEST_P(FusedBatchNormOpTest, Inference) {
  const float exponential_avg_factor = 1.0;
  const bool is_training = false;
  this->VerifyFusedBatchNorm(exponential_avg_factor, is_training);
}

TYPED_TEST_P(FusedBatchNormOpTest, InferenceIgnoreAvgFactor) {
  const float exponential_avg_factor = 0.5;
  const bool is_training = false;
  this->VerifyFusedBatchNorm(exponential_avg_factor, is_training);
}

TYPED_TEST_P(FusedBatchNormOpTest, FusedBatchNormGradV3) {
  const float epsilon = 0.001;
  this->VerifyFusedBatchNormGradWithConv2D(epsilon);
}

REGISTER_TYPED_TEST_SUITE_P(FusedBatchNormOpTest, Training, TrainingRunningMean,
                            Inference, InferenceIgnoreAvgFactor,
                            FusedBatchNormGradV3);

using FusedBatchNormDataTypes = ::testing::Types<float>;
INSTANTIATE_TYPED_TEST_SUITE_P(Test, FusedBatchNormOpTest,
                               FusedBatchNormDataTypes);

}  // namespace tensorflow

#endif  // INTEL_MKL