/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/algorithm/container.h"
#include "absl/strings/match.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/nn_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
#include "tensorflow/core/public/session.h"

#if GOOGLE_CUDA
#include "third_party/gpus/cudnn/cudnn.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

template <typename T, typename U>
class FusedBatchNormExOpTestBase : public OpsTestBase {
 public:
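  // Opt in to cuDNN's spatial-persistent batch norm algorithm, which the
  // fused GPU kernel uses; set before any test graph is executed.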
  FusedBatchNormExOpTestBase() {
    setenv("TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT", "1", 1 /* replace */);
  }

 protected:
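  // Forward outputs mirror FusedBatchNormV3: the normalized tensor plus batch
  // statistics and three opaque reserve-space tensors consumed by the grad op.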
  struct FusedBatchNormOutputs {
    Tensor y;
    Tensor batch_mean;
    Tensor batch_variance;
    Tensor reserve_space_1;
    Tensor reserve_space_2;
    Tensor reserve_space_3;
  };

  struct FusedBatchNormGradOutputs {
    Tensor y_backprop;
    Tensor x_backprop;
    Tensor scale_backprop;
    Tensor offset_backprop;
    Tensor reserve_space_4;
    Tensor reserve_space_5;
  };

  using GraphRunner = std::function<void(
      const Tensor& y_backprop, const Tensor& input_data,
      const Tensor& scale_data, const Tensor& offset_data,
      const Tensor& mean_data, const Tensor& var_data,
      const Tensor& side_input_data, FusedBatchNormOutputs* forward,
      FusedBatchNormGradOutputs* backward)>;

  // Runs a TensorFlow graph defined by the root scope, and fetches the results
  // of the 'fetch' nodes into the outputs. The optional `add_nodes` parameter
  // allows defining nodes directly with a NodeDef for ops that are not
  // supported by the C++ API.
  // TODO(ezhulenev): RunAndFetch is also defined in the FusedConv2D and
  // FusedMatMul tests. Add a base class for all FusedABC kernels and remove
  // the code duplication.
  void RunAndFetch(const tensorflow::Scope& root,
                   const std::vector<string>& fetch,
                   std::vector<Tensor>* outputs, bool allow_gpu_device,
                   const std::vector<const NodeDef*> add_nodes = {}) {
    tensorflow::GraphDef graph;
    TF_ASSERT_OK(root.ToGraphDef(&graph));

    for (const NodeDef* add_node : add_nodes) {
      *graph.add_node() = *add_node;
    }

    // We really want to make sure that the graph is executed exactly as we
    // passed it to the session, so we disable various optimizations.
    tensorflow::SessionOptions session_options;

    // Disable common runtime constant folding.
    session_options.config.mutable_graph_options()
        ->mutable_optimizer_options()
        ->set_opt_level(OptimizerOptions::L0);

    // Disable Grappler optimizations for tests.
    tensorflow::RewriterConfig* cfg =
        session_options.config.mutable_graph_options()
            ->mutable_rewrite_options();
    cfg->set_constant_folding(tensorflow::RewriterConfig::OFF);
    cfg->set_layout_optimizer(tensorflow::RewriterConfig::OFF);
    cfg->set_remapping(tensorflow::RewriterConfig::OFF);

    std::unique_ptr<tensorflow::Session> session(
        tensorflow::NewSession(session_options));

    std::vector<DeviceAttributes> available_devices;
    TF_ASSERT_OK(session->ListDevices(&available_devices))
        << "Failed to get available session devices";

    // Check if the session has an available GPU device.
    const bool has_gpu_device =
        absl::c_any_of(available_devices, [](const DeviceAttributes& device) {
          return device.device_type() == DEVICE_GPU;
        });

    // Some of the `FusedABC` ops are implemented only for CPU, and in this
    // test we don't want to compare GPU vs CPU numbers, so we place all nodes
    // on the CPU unless a GPU is both allowed and available.
    const bool place_all_on_gpu = allow_gpu_device && has_gpu_device;

    const string device = place_all_on_gpu ? "/device:GPU:0" : "/device:CPU:0";
    for (NodeDef& mutable_node : *graph.mutable_node()) {
      mutable_node.set_device(device);
    }

    TF_ASSERT_OK(session->Create(graph));
    TF_ASSERT_OK(session->Run({}, fetch, {}, outputs));
  }

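  // Builds and runs the reference graph: a plain FusedBatchNormV3 followed by
  // separate Add (side input) and Relu (activation) nodes, plus the matching
  // FusedBatchNormGradV3 backward pass.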
  void RunFusedBatchNorm(const Tensor& y_backprop_data,
                         const Tensor& input_data, const Tensor& scale_data,
                         const Tensor& offset_data, const Tensor& mean_data,
                         const Tensor& var_data, const Tensor& side_input_data,
                         const TensorFormat data_format, bool is_training,
                         bool has_side_input, const string& activation_mode,
                         FusedBatchNormOutputs* forward,
                         FusedBatchNormGradOutputs* backward,
                         float epsilon = 0.1f) {
    Scope root = tensorflow::Scope::NewRootScope();

    Output y_backprop = ops::Const(root.WithOpName("y_backprop"),
                                   Input::Initializer(y_backprop_data));
    Output input =
        ops::Const(root.WithOpName("input"), Input::Initializer(input_data));
    Output scale =
        ops::Const(root.WithOpName("scale"), Input::Initializer(scale_data));
    Output offset =
        ops::Const(root.WithOpName("offset"), Input::Initializer(offset_data));
    Output mean =
        ops::Const(root.WithOpName("mean"), Input::Initializer(mean_data));
    Output var =
        ops::Const(root.WithOpName("var"), Input::Initializer(var_data));
    Output side_input = ops::Const(root.WithOpName("side_input"),
                                   Input::Initializer(side_input_data));

    ops::FusedBatchNormV3 fwd = ops::FusedBatchNormV3(
        root.WithOpName("fused_batch_norm"), input, scale, offset, mean, var,
        ops::FusedBatchNormV3::IsTraining(is_training)
            .Epsilon(epsilon)
            .DataFormat(ToString(data_format)));

    Output with_side_input;
    if (has_side_input) {
      with_side_input =
          ops::Add(root.WithOpName("with_side_input"), fwd.y, side_input);
    } else {
      with_side_input =
          ops::Identity(root.WithOpName("with_side_input"), fwd.y);
    }

    Output activation;
    if (activation_mode == "Relu") {
      activation =
          ops::Relu(root.WithOpName("with_activation"), with_side_input);
    } else {
      activation =
          ops::Identity(root.WithOpName("with_activation"), with_side_input);
    }

    Output activation_grad;
    if (activation_mode == "Relu") {
      activation_grad = ops::internal::ReluGrad(
          root.WithOpName("activation_grad"), y_backprop, activation);
    } else {
      activation_grad =
          ops::Identity(root.WithOpName("activation_grad"), y_backprop);
    }

    ops::FusedBatchNormGradV3 bwd = ops::FusedBatchNormGradV3(
        root.WithOpName("fused_batch_norm_grad"), activation_grad, input, scale,
        fwd.reserve_space_1, fwd.reserve_space_2, fwd.reserve_space_3,
        ops::FusedBatchNormGradV3::IsTraining(is_training)
            .Epsilon(epsilon)
            .DataFormat(ToString(data_format)));

    std::vector<Tensor> out_tensors;
    RunAndFetch(
        root,
        {"with_activation:0", "fused_batch_norm:1", "fused_batch_norm:2",
         "fused_batch_norm:3", "fused_batch_norm:4", "fused_batch_norm:5",
         "activation_grad:0", "fused_batch_norm_grad:0",
         "fused_batch_norm_grad:1", "fused_batch_norm_grad:2"},
        &out_tensors, /*allow_gpu_device=*/true);

    forward->y = out_tensors[0];
    forward->batch_mean = out_tensors[1];
    forward->batch_variance = out_tensors[2];
    forward->reserve_space_1 = out_tensors[3];
    forward->reserve_space_2 = out_tensors[4];
    forward->reserve_space_3 = out_tensors[5];

    backward->y_backprop = out_tensors[6];
    backward->x_backprop = out_tensors[7];
    backward->scale_backprop = out_tensors[8];
    backward->offset_backprop = out_tensors[9];
  }

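  // Builds and runs the same computation with the fused _FusedBatchNormEx
  // kernel, which folds the side input and activation into one op:
  //   y = activation(FusedBatchNorm(x) + side_input)
  // The op is not exposed through the C++ API, so its NodeDef is built by
  // hand and injected via RunAndFetch's `add_nodes`.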
  void RunFusedBatchNormEx(const Tensor& y_backprop_data,
                           const Tensor& input_data, const Tensor& scale_data,
                           const Tensor& offset_data, const Tensor& mean_data,
                           const Tensor& var_data,
                           const Tensor& side_input_data,
                           const TensorFormat data_format, bool is_training,
                           bool has_side_input, const string& activation_mode,
                           FusedBatchNormOutputs* forward,
                           FusedBatchNormGradOutputs* backward,
                           float epsilon = 0.1f) {
    Scope root = tensorflow::Scope::NewRootScope();

    DataType t_dtype = DataTypeToEnum<T>::v();
    DataType u_dtype = DataTypeToEnum<U>::v();

    Output y_backprop = ops::Const(root.WithOpName("y_backprop"),
                                   Input::Initializer(y_backprop_data));
    Output input =
        ops::Const(root.WithOpName("input"), Input::Initializer(input_data));
    Output scale =
        ops::Const(root.WithOpName("scale"), Input::Initializer(scale_data));
    Output offset =
        ops::Const(root.WithOpName("offset"), Input::Initializer(offset_data));
    Output mean =
        ops::Const(root.WithOpName("mean"), Input::Initializer(mean_data));
    Output var =
        ops::Const(root.WithOpName("var"), Input::Initializer(var_data));
    Output side_input = ops::Const(root.WithOpName("side_input"),
                                   Input::Initializer(side_input_data));
    Output empty =
        ops::Const(root.WithOpName("empty"),
                   Input::Initializer(Tensor(DataTypeToEnum<U>::value, {0})));

    int num_side_inputs = 0;
    std::vector<NodeDefBuilder::NodeOut> side_inputs;

    if (has_side_input) {
      num_side_inputs = 1;
      side_inputs.push_back({side_input.name(), 0, t_dtype});
    }

    NodeDef fused_batch_norm_ex;
    TF_EXPECT_OK(NodeDefBuilder("fused_batch_norm_ex", "_FusedBatchNormEx")
                     .Input({input.name(), 0, t_dtype})
                     .Input({scale.name(), 0, u_dtype})
                     .Input({offset.name(), 0, u_dtype})
                     .Input({mean.name(), 0, u_dtype})
                     .Input({var.name(), 0, u_dtype})
                     .Input(side_inputs)
                     .Attr("T", t_dtype)
                     .Attr("U", u_dtype)
                     .Attr("data_format", ToString(data_format))
                     .Attr("epsilon", epsilon)
                     .Attr("activation_mode", activation_mode)
                     .Attr("num_side_inputs", num_side_inputs)
                     .Attr("is_training", is_training)
                     .Finalize(&fused_batch_norm_ex));

    NodeDef activation_grad;
    if (activation_mode == "Relu") {
      TF_EXPECT_OK(NodeDefBuilder("activation_grad", "ReluGrad")
                       .Input({y_backprop.name(), 0, t_dtype})
                       .Input({fused_batch_norm_ex.name(), 0, t_dtype})
                       .Attr("T", t_dtype)
                       .Finalize(&activation_grad));
    } else {
      TF_EXPECT_OK(NodeDefBuilder("activation_grad", "Identity")
                       .Input({y_backprop.name(), 0, t_dtype})
                       .Attr("T", t_dtype)
                       .Finalize(&activation_grad));
    }

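    // Outputs 3..5 of _FusedBatchNormEx are the reserve-space tensors that
    // FusedBatchNormGradV3 expects as its last three inputs.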
    NodeDef fused_batch_norm_grad;
    TF_EXPECT_OK(NodeDefBuilder("fused_batch_norm_grad", "FusedBatchNormGradV3")
                     .Input({activation_grad.name(), 0, t_dtype})
                     .Input({input.name(), 0, t_dtype})
                     .Input({scale.name(), 0, u_dtype})
                     .Input({fused_batch_norm_ex.name(), 3, u_dtype})
                     .Input({fused_batch_norm_ex.name(), 4, u_dtype})
                     .Input({fused_batch_norm_ex.name(), 5, u_dtype})
                     .Attr("T", t_dtype)
                     .Attr("U", u_dtype)
                     .Attr("data_format", ToString(data_format))
                     .Attr("epsilon", epsilon)
                     .Attr("is_training", is_training)
                     .Finalize(&fused_batch_norm_grad));

    std::vector<Tensor> out_tensors;
    RunAndFetch(
        root,
        {"fused_batch_norm_ex:0", "fused_batch_norm_ex:1",
         "fused_batch_norm_ex:2", "fused_batch_norm_ex:3",
         "fused_batch_norm_ex:4", "fused_batch_norm_ex:5", "activation_grad:0",
         "fused_batch_norm_grad:0", "fused_batch_norm_grad:1",
         "fused_batch_norm_grad:2"},
        &out_tensors,
        /*allow_gpu_device=*/true,
        {&fused_batch_norm_ex, &activation_grad, &fused_batch_norm_grad});

    forward->y = out_tensors[0];
    forward->batch_mean = out_tensors[1];
    forward->batch_variance = out_tensors[2];
    forward->reserve_space_1 = out_tensors[3];
    forward->reserve_space_2 = out_tensors[4];
    forward->reserve_space_3 = out_tensors[5];

    backward->y_backprop = out_tensors[6];
    backward->x_backprop = out_tensors[7];
    backward->scale_backprop = out_tensors[8];
    backward->offset_backprop = out_tensors[9];
  }

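  // Generates random inputs of the given shape, runs both graph runners, and
  // checks that their forward outputs and gradients agree within tolerance.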
  void VerifyTensorsNear(int batch, int height, int width, int channels,
                         TensorFormat data_format, bool is_training,
                         const GraphRunner& run_default,
                         const GraphRunner& run_fused) {
    DataType t_dtype = DataTypeToEnum<T>::v();
    DataType u_dtype = DataTypeToEnum<U>::v();

    TensorShape input_shape =
        data_format == FORMAT_NHWC
            ? TensorShape({batch, height, width, channels})
            : TensorShape({batch, channels, height, width});

    Tensor input(t_dtype, input_shape);
    input.flat<T>().setRandom();
    input.flat<T>() -= input.flat<T>().constant(static_cast<T>(0.5));

    Tensor scale(u_dtype, {channels});
    scale.flat<U>().setRandom();

    Tensor offset(u_dtype, {channels});
    offset.flat<U>().setRandom();

    Tensor mean(u_dtype, {channels});
    mean.flat<U>().setRandom();

    Tensor var(u_dtype, {channels});
    var.flat<U>().setRandom();

    Tensor side_input(t_dtype, input_shape);
    side_input.flat<T>().setRandom();
    side_input.flat<T>() += side_input.flat<T>().constant(static_cast<T>(5.0));

    Tensor y_backprop(t_dtype, input_shape);
    y_backprop.flat<T>().setRandom();
    y_backprop.flat<T>() -= y_backprop.flat<T>().constant(static_cast<T>(0.5));

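    // In training mode the kernels compute batch statistics themselves, so
    // the mean/variance inputs are passed as empty tensors.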
    Tensor empty(u_dtype, {0});

    FusedBatchNormOutputs fbn_forward;
    FusedBatchNormOutputs fbn_ex_forward;

    FusedBatchNormGradOutputs fbn_backward;
    FusedBatchNormGradOutputs fbn_ex_backward;

    run_default(y_backprop, input, scale, offset, is_training ? empty : mean,
                is_training ? empty : var, side_input, &fbn_forward,
                &fbn_backward);

    // Write some garbage to `fbn_ex_forward` and `fbn_ex_backward` first to
    // make sure that the fused kernel actually writes the correct results to
    // memory.
    run_default(y_backprop, side_input, scale, offset,
                is_training ? empty : mean, is_training ? empty : var, input,
                &fbn_ex_forward, &fbn_ex_backward);

    run_fused(y_backprop, input, scale, offset, is_training ? empty : mean,
              is_training ? empty : var, side_input, &fbn_ex_forward,
              &fbn_ex_backward);

    std::vector<std::pair<Tensor, Tensor>> tensor_pairs;
    if (is_training) {
      tensor_pairs = {
          {fbn_forward.y, fbn_ex_forward.y},
          {fbn_forward.batch_mean, fbn_ex_forward.batch_mean},
          {fbn_forward.batch_variance, fbn_ex_forward.batch_variance},
          {fbn_forward.reserve_space_1, fbn_ex_forward.reserve_space_1},
          {fbn_forward.reserve_space_2, fbn_ex_forward.reserve_space_2},
          // NOTE(ezhulenev): We deliberately do not check `reserve_space_3`
          // because BatchNormEx with a fused side input has different data in
          // it, but we make sure that the final gradients are the same.
          {fbn_backward.y_backprop, fbn_ex_backward.y_backprop},
          {fbn_backward.x_backprop, fbn_ex_backward.x_backprop},
          {fbn_backward.scale_backprop, fbn_ex_backward.scale_backprop},
          {fbn_backward.offset_backprop, fbn_ex_backward.offset_backprop},
      };
    } else {
      tensor_pairs = {{fbn_forward.y, fbn_ex_forward.y}};
    }

    for (auto& pair : tensor_pairs) {
      const Tensor& fbn = pair.first;
      const Tensor& fbn_ex = pair.second;

      ASSERT_EQ(fbn.dtype(), fbn_ex.dtype());
      ASSERT_EQ(fbn.shape(), fbn_ex.shape());

      test::ExpectClose(fbn, fbn_ex, 1e-2);
    }
  }

  // Verifies that computing FusedBatchNormOp+{SideInput}+{Activation} is
  // identical to FusedBatchNormExOp[fused_ops={SideInput, Activation}].
  void VerifyFusedBatchNormEx(int batch, int height, int width, int channels,
                              TensorFormat data_format, bool is_training,
                              bool has_side_input,
                              const string& activation_mode) {
    const GraphRunner run_default =
        [&](const Tensor& y_backprop, const Tensor& input_data,
            const Tensor& scale_data, const Tensor& offset_data,
            const Tensor& mean_data, const Tensor& var_data,
            const Tensor& side_input_data, FusedBatchNormOutputs* fwd,
            FusedBatchNormGradOutputs* bwd) {
          this->RunFusedBatchNorm(y_backprop, input_data, scale_data,
                                  offset_data, mean_data, var_data,
                                  side_input_data, data_format, is_training,
                                  has_side_input, activation_mode, fwd, bwd);
        };

    const GraphRunner run_fused =
        [&](const Tensor& y_backprop, const Tensor& input_data,
            const Tensor& scale_data, const Tensor& offset_data,
            const Tensor& mean_data, const Tensor& var_data,
            const Tensor& side_input_data, FusedBatchNormOutputs* fwd,
            FusedBatchNormGradOutputs* bwd) {
          this->RunFusedBatchNormEx(y_backprop, input_data, scale_data,
                                    offset_data, mean_data, var_data,
                                    side_input_data, data_format, is_training,
                                    has_side_input, activation_mode, fwd, bwd);
        };

    VerifyTensorsNear(batch, height, width, channels, data_format, is_training,
                      run_default, run_fused);
  }
};

constexpr bool kInTraining = true;     // is_training == true
constexpr bool kInInference = false;   // is_training == false
constexpr bool kNoSideInput = false;   // side_input == false
constexpr bool kWithSideInput = true;  // side_input == true

// -------------------------------------------------------------------------- //
// FusedBatchNormEx[is_training=true].

#if defined(GOOGLE_CUDA) && (CUDNN_VERSION >= 7402)
template <typename T>
using FusedBatchNormExOpTrainingTest =
    FusedBatchNormExOpTestBase<T, float>;  // scale is always float

TYPED_TEST_SUITE_P(FusedBatchNormExOpTrainingTest);

TYPED_TEST_P(FusedBatchNormExOpTrainingTest, TrainingInNHWCTest) {
  this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInTraining,
                               kNoSideInput, "Identity");
}

TYPED_TEST_P(FusedBatchNormExOpTrainingTest, TrainingWithReluInNHWCTest) {
  this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInTraining,
                               kNoSideInput, "Relu");
}

TYPED_TEST_P(FusedBatchNormExOpTrainingTest,
             TrainingWithSideInputAndReluInNHWCTest) {
  this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInTraining,
                               kWithSideInput, "Relu");
}

REGISTER_TYPED_TEST_SUITE_P(FusedBatchNormExOpTrainingTest,  //
                            TrainingInNHWCTest,              //
                            TrainingWithReluInNHWCTest,      //
                            TrainingWithSideInputAndReluInNHWCTest);

using FusedBatchNormExTrainingDataTypes = ::testing::Types<Eigen::half>;
INSTANTIATE_TYPED_TEST_SUITE_P(Test, FusedBatchNormExOpTrainingTest,
                               FusedBatchNormExTrainingDataTypes);
#endif  // defined(GOOGLE_CUDA) && (CUDNN_VERSION >= 7402)

// -------------------------------------------------------------------------- //
// FusedBatchNormEx[is_training=false].

#if defined(GOOGLE_CUDA)
template <typename T>
using FusedBatchNormExOpInferenceTest =
    FusedBatchNormExOpTestBase<T, float>;  // scale is always float

TYPED_TEST_SUITE_P(FusedBatchNormExOpInferenceTest);

TYPED_TEST_P(FusedBatchNormExOpInferenceTest, InferenceInNHWCTest) {
  this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInInference,
                               kNoSideInput, "Identity");
}

TYPED_TEST_P(FusedBatchNormExOpInferenceTest, InferenceWithReluInNHWCTest) {
  this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInInference,
                               kNoSideInput, "Relu");
}

TYPED_TEST_P(FusedBatchNormExOpInferenceTest,
             InferenceWithSideInputAndReluInNHWCTest) {
  this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInInference,
                               kWithSideInput, "Relu");
}

REGISTER_TYPED_TEST_SUITE_P(FusedBatchNormExOpInferenceTest,  //
                            InferenceInNHWCTest,              //
                            InferenceWithReluInNHWCTest,      //
                            InferenceWithSideInputAndReluInNHWCTest);

using FusedBatchNormExInferenceDataTypes = ::testing::Types<Eigen::half, float>;
INSTANTIATE_TYPED_TEST_SUITE_P(Test, FusedBatchNormExOpInferenceTest,
                               FusedBatchNormExInferenceDataTypes);
#endif  // defined(GOOGLE_CUDA)

// -------------------------------------------------------------------------- //
// Performance benchmarks are below.                                          //
// -------------------------------------------------------------------------- //

using fp16 = Eigen::half;
using fp32 = float;
using SideInputAndActivation = std::pair<bool, string>;

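// Benchmark configurations: {has_side_input, activation_mode} factories.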
SideInputAndActivation Identity() { return {false, "Identity"}; }
SideInputAndActivation Relu() { return {false, "Relu"}; }
SideInputAndActivation AddAndRelu() { return {true, "Relu"}; }

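// Builds a benchmark graph containing a single _FusedBatchNormEx node fed by
// constant inputs of the given shape and data format.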
template <typename T>
static Graph* FusedBatchNormEx(int n, int h, int w, int c,
                               TensorFormat data_format, bool is_training,
                               std::function<SideInputAndActivation()> fn) {
  Graph* g = new Graph(OpRegistry::Global());

  DataType dtype = DataTypeToEnum<T>::value;
  Tensor x_t(dtype, data_format == FORMAT_NHWC ? TensorShape({n, h, w, c})
                                               : TensorShape({n, c, h, w}));
  x_t.flat<T>().setRandom();

  Tensor other_t(DT_FLOAT, TensorShape({c}));
  other_t.flat<float>().setRandom();

  Node* x = test::graph::Constant(g, x_t, "x");
  Node* other = test::graph::Constant(g, other_t, "other");
  Node* empty = test::graph::Constant(g, Tensor(DT_FLOAT, {0}), "empty");

  int num_side_inputs = 0;
  std::vector<NodeBuilder::NodeOut> side_inputs;

  SideInputAndActivation side_input_and_activation = fn();
  bool has_side_input = side_input_and_activation.first;
  string activation_mode = side_input_and_activation.second;

  if (has_side_input) {
    num_side_inputs = 1;
    side_inputs.push_back({x});
  }

  Node* fused_batch_norm;
  TF_CHECK_OK(NodeBuilder(g->NewName("fused_batch_norm"), "_FusedBatchNormEx")
                  .Input(x)
                  .Input(other)                        // scale
                  .Input(other)                        // offset
                  .Input(is_training ? empty : other)  // mean
                  .Input(is_training ? empty : other)  // variance
                  .Input(side_inputs)                  // side_input
                  .Attr("T", dtype)
                  .Attr("U", DT_FLOAT)
                  .Attr("epsilon", 0.001)
                  .Attr("data_format", ToString(data_format))
                  .Attr("activation_mode", activation_mode)
                  .Attr("num_side_inputs", num_side_inputs)
                  .Attr("is_training", is_training)
                  .Finalize(g, &fused_batch_norm));

  return g;
}

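// BM_NAME glues the shape, type, format, training flag, and activation into a
// unique benchmark symbol; BM_FusedBatchNorm defines and registers it.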
#define BM_CONCAT(a, b) a##_##b

#define BM_NAME(N, H, W, C, T, FORMAT, IS_TRAINING, A, DEVICE)          \
  BM_CONCAT(BM_FusedBatchNorm##_##DEVICE##_##T##_##N##_##H##_##W##_##C, \
            FORMAT##_##IS_TRAINING##_##A)

#define BM_FusedBatchNorm(N, H, W, C, T, FORMAT, IS_TRAINING, ACTIVATION,    \
                          DEVICE)                                            \
  static void BM_NAME(N, H, W, C, T, FORMAT, IS_TRAINING, ACTIVATION,        \
                      DEVICE)(::testing::benchmark::State & state) {         \
    test::Benchmark(#DEVICE,                                                 \
                    FusedBatchNormEx<T>(N, H, W, C, FORMAT_##FORMAT,         \
                                        IS_TRAINING, {ACTIVATION}),          \
                    /*old_benchmark_api=*/false)                             \
        .Run(state);                                                         \
    state.SetItemsProcessed(state.iterations() * N * H * W * C);             \
  }                                                                          \
  BENCHMARK(BM_NAME(N, H, W, C, T, FORMAT, IS_TRAINING, ACTIVATION, DEVICE)) \
      ->UseRealTime();

#if defined(GOOGLE_CUDA) && (CUDNN_VERSION >= 7402)
BM_FusedBatchNorm(64, 14, 14, 256, fp16, NHWC, true, Identity, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp16, NHWC, true, Relu, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp16, NHWC, true, AddAndRelu, gpu);
#endif  // defined(GOOGLE_CUDA) && (CUDNN_VERSION >= 7402)

#if defined(GOOGLE_CUDA)
BM_FusedBatchNorm(64, 14, 14, 256, fp16, NHWC, false, Identity, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp16, NHWC, false, Relu, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp16, NHWC, false, AddAndRelu, gpu);

BM_FusedBatchNorm(64, 14, 14, 256, fp16, NCHW, false, Identity, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp16, NCHW, false, Relu, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp16, NCHW, false, AddAndRelu, gpu);

BM_FusedBatchNorm(64, 14, 14, 256, fp32, NHWC, false, Identity, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp32, NHWC, false, Relu, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp32, NHWC, false, AddAndRelu, gpu);

BM_FusedBatchNorm(64, 14, 14, 256, fp32, NCHW, false, Identity, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp32, NCHW, false, Relu, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp32, NCHW, false, AddAndRelu, gpu);
#endif  // defined(GOOGLE_CUDA)

}  // namespace tensorflow