• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifdef INTEL_MKL
17 
18 #include "tensorflow/cc/ops/const_op.h"
19 #include "tensorflow/cc/ops/image_ops.h"
20 #include "tensorflow/cc/ops/nn_ops.h"
21 #include "tensorflow/cc/ops/standard_ops.h"
22 #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
23 #include "tensorflow/core/framework/fake_input.h"
24 #include "tensorflow/core/framework/node_def_builder.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/framework/types.pb.h"
27 #include "tensorflow/core/graph/mkl_graph_util.h"
28 #include "tensorflow/core/kernels/conv_ops_gpu.h"
29 #include "tensorflow/core/kernels/ops_testutil.h"
30 #include "tensorflow/core/kernels/ops_util.h"
31 #include "tensorflow/core/platform/test.h"
32 #include "tensorflow/core/platform/test_benchmark.h"
33 #include "tensorflow/core/platform/types.h"
34 #include "tensorflow/core/public/session.h"
35 #include "tensorflow/core/util/util.h"
36 
37 namespace tensorflow {
38 
39 // Helper class for converting MKL tensors to TF tensors and comparing to
40 // expected values
41 
42 static const uint8 dummy_tensor[] = {0, 0, 0, 0, 0, 0, 0, 0};
43 static const TensorShape dummy_shape({8});
44 
45 using GraphRunner = std::function<void(
46     const Tensor& input, const Tensor& scale, const Tensor& offset,
47     const Tensor& mean, const Tensor& variance,
48     const float exponential_avg_factor, const bool is_training, Tensor* output,
49     Tensor* batch_mean, Tensor* batch_var)>;
50 
51 using GraphRunnerGrad = std::function<void(
52     const Tensor& input, const Tensor& filter, const Tensor& y_backprop,
53     const Tensor& scale, const Tensor& mean, const Tensor& variance,
54     const Tensor& res_sp3, Tensor* output, Tensor* scale_backprop,
55     Tensor* offset_backprop, bool disable_grappler_opts)>;
56 
57 template <typename T>
58 class CommonTestUtilities : public OpsTestBase {
59  public:
PerformConversion(DataType dtype,const Tensor & tensor,const Tensor & mkl_meta_tensor,Tensor * output)60   void PerformConversion(DataType dtype, const Tensor& tensor,
61                          const Tensor& mkl_meta_tensor, Tensor* output) {
62     // Create an MKL to TF conversion node and execute it
63     TF_EXPECT_OK(NodeDefBuilder("mkl_to_tf_op", "_MklToTf")
64                      .Input(FakeInput(dtype))     // Input
65                      .Input(FakeInput(DT_UINT8))  // Mkl second tensor
66                      .Attr("T", dtype)
67                      .Attr("_kernel", "MklLayoutDependentOp")
68                      .Finalize(node_def()));
69     TF_EXPECT_OK(InitOp());
70     AddInputFromArray<T>(tensor.shape(), tensor.flat<T>());
71     AddInputFromArray<uint8>(mkl_meta_tensor.shape(),
72                              mkl_meta_tensor.flat<uint8>());
73     TF_ASSERT_OK(RunOpKernel());
74 
75     *output = *GetOutput(0);
76   }
77 
TestBody()78   void TestBody() {}
79 
VerifyTensorsClose(const float exponential_avg_factor,const bool is_training,const GraphRunner & run,const GraphRunner & run_mkl)80   static void VerifyTensorsClose(const float exponential_avg_factor,
81                                  const bool is_training, const GraphRunner& run,
82                                  const GraphRunner& run_mkl) {
83     int batch = 1;
84     int height = 10;
85     int width = 10;
86     int depth = 3;
87     DataType dtype = DataTypeToEnum<T>::v();
88 
89     Tensor input(dtype, {batch, height, width, depth});
90     input.flat<T>() = input.flat<T>().template setRandom<random_gen_>();
91     Tensor scale(dtype, {depth});
92     scale.flat<T>() = scale.flat<T>().template setRandom<random_gen_>();
93     Tensor offset(dtype, {depth});
94     offset.flat<T>() = offset.flat<T>().template setRandom<random_gen_>();
95 
96     if (is_training && (exponential_avg_factor == 1.0)) {
97       depth = 0;
98     }
99     Tensor mean(dtype, {depth});
100     mean.flat<T>() = mean.flat<T>().template setRandom<random_gen_>();
101     Tensor variance(dtype, {depth});
102     variance.flat<T>() =
103         variance.flat<T>().template setRandom<random_gen_>().abs();
104 
105     Tensor output;
106     Tensor batch_mean;
107     Tensor batch_var;
108     Tensor mkl_output;
109     Tensor mkl_batch_mean;
110     Tensor mkl_batch_var;
111 
112     run(input, scale, offset, mean, variance, exponential_avg_factor,
113         is_training, &output, &batch_mean, &batch_var);
114     run_mkl(input, scale, offset, mean, variance, exponential_avg_factor,
115             is_training, &mkl_output, &mkl_batch_mean, &mkl_batch_var);
116 
117     ASSERT_EQ(output.dtype(), mkl_output.dtype());
118     ASSERT_EQ(output.shape(), mkl_output.shape());
119     ASSERT_EQ(batch_mean.dtype(), mkl_batch_mean.dtype());
120     ASSERT_EQ(batch_mean.shape(), mkl_batch_mean.shape());
121     ASSERT_EQ(batch_var.dtype(), mkl_batch_var.dtype());
122     ASSERT_EQ(batch_var.shape(), mkl_batch_var.shape());
123 
124     test::ExpectClose(output, mkl_output, 1e-5);
125     test::ExpectClose(batch_mean, mkl_batch_mean, 1e-5);
126     test::ExpectClose(batch_var, mkl_batch_var, 1e-5);
127   }
128 
VerifyTensorsCloseForGrad(const float epsilon,const GraphRunnerGrad & run,const GraphRunnerGrad & run_mkl)129   static void VerifyTensorsCloseForGrad(const float epsilon,
130                                         const GraphRunnerGrad& run,
131                                         const GraphRunnerGrad& run_mkl) {
132     int batch = 2;
133     int height = 8;
134     int width = 8;
135     int depth = 1;
136     int filter_height = 3;
137     int filter_width = 3;
138     int in_channels = 1;
139     int out_channels = 6;
140     DataType dtype = DataTypeToEnum<T>::v();
141 
142     Tensor input(dtype, {batch, height, width, depth});
143     input.flat<T>() = input.flat<T>().template setRandom<random_gen_>();
144     Tensor filter(dtype,
145                   {filter_height, filter_width, in_channels, out_channels});
146     filter.flat<T>() = filter.flat<T>().template setRandom<random_gen_>();
147 
148     Tensor y_backprop(dtype, {batch, height, width, out_channels});
149     y_backprop.flat<T>() =
150         y_backprop.flat<T>().template setRandom<random_gen_>();
151     Tensor scale(dtype, {out_channels});
152     scale.flat<T>() = scale.flat<T>().template setRandom<random_gen_>();
153     Tensor mean(dtype, {out_channels});
154     mean.flat<T>() = mean.flat<T>().template setRandom<random_gen_>();
155     Tensor variance(dtype, {out_channels});
156     variance.flat<T>() =
157         variance.flat<T>().template setRandom<random_gen_>().abs();
158     Tensor res_sp3(dtype, {out_channels});
159     res_sp3.flat<T>() =
160         res_sp3.flat<T>().template setRandom<random_gen_>().abs();
161 
162     Tensor output;
163     Tensor scale_backprop;
164     Tensor offset_backprop;
165     Tensor mkl_output;
166     Tensor mkl_scale_backprop;
167     Tensor mkl_offset_backprop;
168 
169     run(input, filter, y_backprop, scale, mean, variance, res_sp3, &output,
170         &scale_backprop, &offset_backprop, epsilon);
171 
172     run_mkl(input, filter, y_backprop, scale, mean, variance, res_sp3,
173             &mkl_output, &mkl_scale_backprop, &mkl_offset_backprop, epsilon);
174 
175     ASSERT_EQ(output.dtype(), mkl_output.dtype());
176     ASSERT_EQ(output.shape(), mkl_output.shape());
177     ASSERT_EQ(scale_backprop.dtype(), mkl_scale_backprop.dtype());
178     ASSERT_EQ(scale_backprop.shape(), mkl_scale_backprop.shape());
179     ASSERT_EQ(offset_backprop.dtype(), mkl_offset_backprop.dtype());
180     ASSERT_EQ(offset_backprop.shape(), mkl_offset_backprop.shape());
181 
182     test::ExpectClose(output, mkl_output, 1e-5);
183     test::ExpectClose(scale_backprop, mkl_scale_backprop, 1e-5);
184     test::ExpectClose(offset_backprop, mkl_offset_backprop, 1e-5);
185   }
186 
187  private:
188   using random_gen_ = Eigen::internal::NormalRandomGenerator<T>;
189 };
190 
191 template <typename T>
192 class Conv2DOpTest : public OpsTestBase {
TestBody()193   void TestBody() {}
194 
195  public:
RunConv2D(const Tensor & input,const Tensor & filter,Tensor * output,Tensor * meta_output)196   void RunConv2D(const Tensor& input, const Tensor& filter, Tensor* output,
197                  Tensor* meta_output) {
198     DataType dtype = DataTypeToEnum<T>::v();
199 
200     TF_EXPECT_OK(NodeDefBuilder("MklConv2D", "_MklConv2D")
201                      .Input(FakeInput(dtype))
202                      .Input(FakeInput(dtype))
203                      .Input(FakeInput(DT_UINT8))
204                      .Input(FakeInput(DT_UINT8))
205                      .Attr("strides", {1, 1, 1, 1})
206                      .Attr("padding", "SAME")
207                      .Attr("data_format", "NHWC")
208                      .Attr("_kernel", "MklLayoutDependentOp")
209                      .Finalize(node_def()));
210     TF_EXPECT_OK(InitOp());
211     AddInputFromArray<T>(input.shape(), input.flat<T>());
212     AddInputFromArray<T>(filter.shape(), filter.flat<T>());
213     for (int i = 0; i < 2; ++i)
214       AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
215     TF_ASSERT_OK(RunOpKernel());
216 
217     *output = *GetOutput(0);
218     *meta_output = *GetOutput(2);
219   }
220 };
221 
222 template <typename T>
223 class FusedBatchNormOpTest : public OpsTestBase {
224  protected:
VerifyFusedBatchNorm(const float exponential_avg_factor,const bool is_training)225   void VerifyFusedBatchNorm(const float exponential_avg_factor,
226                             const bool is_training) {
227     const GraphRunner run = [this](const Tensor& input, const Tensor& scale,
228                                    const Tensor& offset, const Tensor& mean,
229                                    const Tensor& variance,
230                                    const float exponential_avg_factor,
231                                    const bool is_training, Tensor* output,
232                                    Tensor* batch_mean, Tensor* batch_var) {
233       auto root = tensorflow::Scope::NewRootScope();
234       auto input_op =
235           ops::Const(root.WithOpName("input"), Input::Initializer(input));
236       auto scale_op =
237           ops::Const(root.WithOpName("scale"), Input::Initializer(scale));
238       auto offset_op =
239           ops::Const(root.WithOpName("offset"), Input::Initializer(offset));
240       auto mean_op =
241           ops::Const(root.WithOpName("mean"), Input::Initializer(mean));
242       auto var_op =
243           ops::Const(root.WithOpName("variance"), Input::Initializer(variance));
244 
245       ops::FusedBatchNorm::Attrs attr;
246       attr = attr.IsTraining(is_training);
247       attr = attr.ExponentialAvgFactor(exponential_avg_factor);
248       attr = attr.Epsilon(0.001);
249       auto bn = ops::FusedBatchNorm(root.WithOpName("FusedBatchNorm"), input_op,
250                                     scale_op, offset_op, mean_op, var_op, attr);
251       auto y = ops::Identity(root.WithOpName("y"), bn.y);
252       auto y_batch_mean =
253           ops::Identity(root.WithOpName("y_batch_mean"), bn.batch_mean);
254       auto y_batch_var =
255           ops::Identity(root.WithOpName("y_batch_var"), bn.batch_variance);
256 
257       tensorflow::GraphDef graph;
258       TF_ASSERT_OK(root.ToGraphDef(&graph));
259 
260       std::unique_ptr<tensorflow::Session> session(
261           tensorflow::NewSession(tensorflow::SessionOptions()));
262       TF_ASSERT_OK(session->Create(graph));
263 
264       std::vector<Tensor> output_tensors;
265       TF_ASSERT_OK(session->Run({}, {"y", "y_batch_mean", "y_batch_var"}, {},
266                                 &output_tensors));
267 
268       *output = output_tensors[0];
269       *batch_mean = output_tensors[1];
270       *batch_var = output_tensors[2];
271     };
272 
273     const GraphRunner run_mkl = [this](const Tensor& input, const Tensor& scale,
274                                        const Tensor& offset, const Tensor& mean,
275                                        const Tensor& variance,
276                                        const float exponential_avg_factor,
277                                        const bool is_training, Tensor* output,
278                                        Tensor* batch_mean, Tensor* batch_var) {
279       DataType dtype = DataTypeToEnum<T>::v();
280       if (!NativeFormatEnabled()) {
281         TF_EXPECT_OK(NodeDefBuilder("MklFusedBatchNorm", "_MklFusedBatchNorm")
282                          .Input(FakeInput(dtype))
283                          .Input(FakeInput(DT_FLOAT))
284                          .Input(FakeInput(DT_FLOAT))
285                          .Input(FakeInput(DT_FLOAT))
286                          .Input(FakeInput(DT_FLOAT))
287                          .Input(FakeInput(DT_UINT8))
288                          .Input(FakeInput(DT_UINT8))
289                          .Input(FakeInput(DT_UINT8))
290                          .Input(FakeInput(DT_UINT8))
291                          .Input(FakeInput(DT_UINT8))
292                          .Attr("exponential_avg_factor", exponential_avg_factor)
293                          .Attr("epsilon", 0.001)
294                          .Attr("is_training", is_training)
295                          .Attr("_kernel", "MklLayoutDependentOp")
296                          .Finalize(node_def()));
297       } else {
298         TF_EXPECT_OK(NodeDefBuilder("MklNativeFusedBatchNorm",
299                                     "_MklNativeFusedBatchNorm")
300                          .Input(FakeInput(dtype))
301                          .Input(FakeInput(DT_FLOAT))
302                          .Input(FakeInput(DT_FLOAT))
303                          .Input(FakeInput(DT_FLOAT))
304                          .Input(FakeInput(DT_FLOAT))
305                          .Attr("exponential_avg_factor", exponential_avg_factor)
306                          .Attr("epsilon", 0.001)
307                          .Attr("is_training", is_training)
308                          .Attr("_kernel", "MklNameChangeOp")
309                          .Finalize(node_def()));
310       }
311       TF_EXPECT_OK(InitOp());
312 
313       AddInputFromArray<T>(input.shape(), input.flat<T>());
314       AddInputFromArray<float>(scale.shape(), scale.flat<float>());
315       AddInputFromArray<float>(offset.shape(), offset.flat<float>());
316       AddInputFromArray<float>(mean.shape(), mean.flat<float>());
317       AddInputFromArray<float>(variance.shape(), variance.flat<float>());
318       if (!NativeFormatEnabled()) {
319         for (int i = 0; i < 5; ++i)
320           AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
321       }
322       TF_ASSERT_OK(RunOpKernel());
323 
324       if (!NativeFormatEnabled()) {
325         CommonTestUtilities<T> test_util;
326         test_util.PerformConversion(dtype, *GetOutput(0), *GetOutput(5),
327                                     output);
328 
329         CommonTestUtilities<T> test_util_mean;
330         test_util_mean.PerformConversion(dtype, *GetOutput(1), *GetOutput(6),
331                                          batch_mean);
332 
333         CommonTestUtilities<T> test_util_var;
334         test_util_var.PerformConversion(dtype, *GetOutput(2), *GetOutput(7),
335                                         batch_var);
336       } else {
337         *output = *GetOutput(0);
338         *batch_mean = *GetOutput(1);
339         *batch_var = *GetOutput(2);
340       }
341     };
342 
343     CommonTestUtilities<T>::VerifyTensorsClose(exponential_avg_factor,
344                                                is_training, run, run_mkl);
345   }
346 
VerifyFusedBatchNormGradWithConv2D(const float epsilon)347   void VerifyFusedBatchNormGradWithConv2D(const float epsilon) {
348 #ifdef ENABLE_MKL
349     // This test only runs with MKL blocked format.
350     const GraphRunnerGrad run =
351         [this](const Tensor& input, const Tensor& filter,
352                const Tensor& y_backprop, const Tensor& scale,
353                const Tensor& mean, const Tensor& variance,
354                const Tensor& res_sp3, Tensor* x_backprop_tensor,
355                Tensor* scale_backprop_tensor, Tensor* offset_backprop_tensor,
356                const float epsilon) {
357           auto root = tensorflow::Scope::NewRootScope();
358 
359           auto input_op =
360               ops::Const(root.WithOpName("input"), Input::Initializer(input));
361           auto filter_op =
362               ops::Const(root.WithOpName("filter"), Input::Initializer(filter));
363           ops::Conv2D::Attrs conv_attr;
364           conv_attr = conv_attr.DataFormat("NHWC");
365           auto conv = ops::Conv2D(root.WithOpName("Conv"), input_op, filter_op,
366                                   {1, 1, 1, 1}, "SAME", conv_attr);
367           // -------------------------------------------------------------
368           auto y_backprop_op = ops::Const(root.WithOpName("y_backprop"),
369                                           Input::Initializer(y_backprop));
370           auto scale_op =
371               ops::Const(root.WithOpName("scale"), Input::Initializer(scale));
372           auto mean_op =
373               ops::Const(root.WithOpName("mean"), Input::Initializer(mean));
374           auto var_op = ops::Const(root.WithOpName("variance"),
375                                    Input::Initializer(variance));
376           auto res_sp3_op = ops::Const(root.WithOpName("reserve_space_3"),
377                                        Input::Initializer(res_sp3));
378           ops::FusedBatchNormGradV3::Attrs bn_attr;
379           bn_attr = bn_attr.IsTraining(true);
380           bn_attr = bn_attr.Epsilon(epsilon);
381           bn_attr = bn_attr.DataFormat("NHWC");
382           auto bn = ops::FusedBatchNormGradV3(
383               root.WithOpName("FusedBatchNormGrad"), y_backprop_op, conv,
384               scale_op, mean_op, var_op, res_sp3_op, bn_attr);
385 
386           auto x_backprop =
387               ops::Identity(root.WithOpName("x_backprop"), bn.x_backprop);
388           auto scale_backprop = ops::Identity(root.WithOpName("scale_backprop"),
389                                               bn.scale_backprop);
390           auto offset_backprop = ops::Identity(
391               root.WithOpName("offset_backprop"), bn.offset_backprop);
392 
393           tensorflow::GraphDef graph;
394           TF_ASSERT_OK(root.ToGraphDef(&graph));
395 
396           tensorflow::SessionOptions session_options;
397           std::unique_ptr<tensorflow::Session> session(
398               tensorflow::NewSession(session_options));
399           TF_ASSERT_OK(session->Create(graph));
400 
401           std::vector<Tensor> output_tensors;
402           TF_ASSERT_OK(session->Run(
403               {}, {"x_backprop", "scale_backprop", "offset_backprop"}, {},
404               &output_tensors));
405 
406           *x_backprop_tensor = output_tensors[0];
407           *scale_backprop_tensor = output_tensors[1];
408           *offset_backprop_tensor = output_tensors[2];
409         };
410 
411     const GraphRunnerGrad run_mkl =
412         [this](const Tensor& input, const Tensor& filter,
413                const Tensor& y_backprop, const Tensor& scale,
414                const Tensor& mean, const Tensor& variance,
415                const Tensor& res_sp3, Tensor* x_backprop_tensor,
416                Tensor* scale_backprop_tensor, Tensor* offset_backprop_tensor,
417                const float epsilon) {
418           Tensor conv2d_output, conv2d_meta_output;
419           Conv2DOpTest<T> conv2d_test;
420           conv2d_test.RunConv2D(input, filter, &conv2d_output,
421                                 &conv2d_meta_output);
422 
423           DataType dtype = DataTypeToEnum<T>::v();
424           TF_EXPECT_OK(
425               NodeDefBuilder("MklFusedBatchNorm", "_MklFusedBatchNormGradV3")
426                   .Input(FakeInput(dtype))
427                   .Input(FakeInput(dtype))
428                   .Input(FakeInput(DT_FLOAT))
429                   .Input(FakeInput(DT_FLOAT))
430                   .Input(FakeInput(DT_FLOAT))
431                   .Input(FakeInput(DT_FLOAT))
432                   .Input(FakeInput(DT_UINT8))
433                   .Input(FakeInput(DT_UINT8))
434                   .Input(FakeInput(DT_UINT8))
435                   .Input(FakeInput(DT_UINT8))
436                   .Input(FakeInput(DT_UINT8))
437                   .Input(FakeInput(DT_UINT8))
438                   .Attr("epsilon", epsilon)
439                   .Attr("is_training", true)
440                   .Attr("data_format", "NHWC")
441                   .Attr("_kernel", "MklLayoutDependentOp")
442                   .Finalize(node_def()));
443           TF_EXPECT_OK(InitOp());
444 
445           AddInputFromArray<T>(y_backprop.shape(), y_backprop.flat<T>());
446           AddInputFromArray<T>(conv2d_output.shape(), conv2d_output.flat<T>());
447           AddInputFromArray<float>(scale.shape(), scale.flat<float>());
448           AddInputFromArray<float>(mean.shape(), mean.flat<float>());
449           AddInputFromArray<float>(variance.shape(), variance.flat<float>());
450           AddInputFromArray<float>(res_sp3.shape(), res_sp3.flat<float>());
451           AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
452           AddInputFromArray<uint8>(conv2d_meta_output.shape(),
453                                    conv2d_meta_output.flat<uint8>());
454           AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
455           AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
456           AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
457           AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
458           TF_ASSERT_OK(RunOpKernel());
459 
460           CommonTestUtilities<T> test_util;
461           test_util.PerformConversion(dtype, *GetOutput(0), *GetOutput(5),
462                                       x_backprop_tensor);
463 
464           CommonTestUtilities<T> test_util_mean;
465           test_util_mean.PerformConversion(dtype, *GetOutput(1), *GetOutput(6),
466                                            scale_backprop_tensor);
467 
468           CommonTestUtilities<T> test_util_var;
469           test_util_var.PerformConversion(dtype, *GetOutput(2), *GetOutput(7),
470                                           offset_backprop_tensor);
471         };
472 
473     CommonTestUtilities<T>::VerifyTensorsCloseForGrad(epsilon, run, run_mkl);
474 #endif  // ENABLE_MKL
475   }
476 };
477 
478 TYPED_TEST_SUITE_P(FusedBatchNormOpTest);
479 
TYPED_TEST_P(FusedBatchNormOpTest,Training)480 TYPED_TEST_P(FusedBatchNormOpTest, Training) {
481   const float exponential_avg_factor = 1.0;
482   const bool is_training = true;
483   this->VerifyFusedBatchNorm(exponential_avg_factor, is_training);
484 }
485 
TYPED_TEST_P(FusedBatchNormOpTest,TrainingRunningMean)486 TYPED_TEST_P(FusedBatchNormOpTest, TrainingRunningMean) {
487   const float exponential_avg_factor = 0.5;
488   const bool is_training = true;
489   this->VerifyFusedBatchNorm(exponential_avg_factor, is_training);
490 }
491 
TYPED_TEST_P(FusedBatchNormOpTest,Inference)492 TYPED_TEST_P(FusedBatchNormOpTest, Inference) {
493   const float exponential_avg_factor = 1.0;
494   const bool is_training = false;
495   this->VerifyFusedBatchNorm(exponential_avg_factor, is_training);
496 }
497 
TYPED_TEST_P(FusedBatchNormOpTest,InferenceIgnoreAvgFactor)498 TYPED_TEST_P(FusedBatchNormOpTest, InferenceIgnoreAvgFactor) {
499   const float exponential_avg_factor = 0.5;
500   const bool is_training = false;
501   this->VerifyFusedBatchNorm(exponential_avg_factor, is_training);
502 }
503 
TYPED_TEST_P(FusedBatchNormOpTest,FusedBatchNormGradV3)504 TYPED_TEST_P(FusedBatchNormOpTest, FusedBatchNormGradV3) {
505   const float epsilon = 0.001;
506   this->VerifyFusedBatchNormGradWithConv2D(epsilon);
507 }
508 
509 REGISTER_TYPED_TEST_SUITE_P(FusedBatchNormOpTest, Training, TrainingRunningMean,
510                             Inference, InferenceIgnoreAvgFactor,
511                             FusedBatchNormGradV3);
512 
513 using FusedBatchNormDataTypes = ::testing::Types<float>;
514 INSTANTIATE_TYPED_TEST_SUITE_P(Test, FusedBatchNormOpTest,
515                                FusedBatchNormDataTypes);
516 
517 }  // namespace tensorflow
518 
519 #endif  // INTEL_MKL
520