1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "absl/algorithm/container.h"
17 #include "absl/strings/match.h"
18 #include "tensorflow/cc/ops/const_op.h"
19 #include "tensorflow/cc/ops/nn_ops.h"
20 #include "tensorflow/cc/ops/nn_ops_internal.h"
21 #include "tensorflow/cc/ops/standard_ops.h"
22 #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
23 #include "tensorflow/core/framework/node_def_builder.h"
24 #include "tensorflow/core/framework/tensor.h"
25 #include "tensorflow/core/framework/types.pb.h"
26 #include "tensorflow/core/graph/node_builder.h"
27 #include "tensorflow/core/kernels/ops_testutil.h"
28 #include "tensorflow/core/kernels/ops_util.h"
29 #include "tensorflow/core/lib/core/status_test_util.h"
30 #include "tensorflow/core/platform/env.h"
31 #include "tensorflow/core/platform/test.h"
32 #include "tensorflow/core/platform/test_benchmark.h"
33 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
34 #include "tensorflow/core/public/session.h"
35 
36 #if GOOGLE_CUDA
37 #include "third_party/gpus/cudnn/cudnn.h"
38 #endif  // GOOGLE_CUDA
39 
40 namespace tensorflow {
41 
42 template <typename T, typename U>
43 class FusedBatchNormExOpTestBase : public OpsTestBase {
44  public:
FusedBatchNormExOpTestBase()45   FusedBatchNormExOpTestBase() {
46     setenv("TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT", "1", 1 /* replace */);
47   }
48 
49  protected:
50   struct FusedBatchNormOutputs {
51     Tensor y;
52     Tensor batch_mean;
53     Tensor batch_variance;
54     Tensor reserve_space_1;
55     Tensor reserve_space_2;
56     Tensor reserve_space_3;
57   };
58 
59   struct FusedBatchNormGradOutputs {
60     Tensor y_backprop;
61     Tensor x_backprop;
62     Tensor scale_backprop;
63     Tensor offset_backprop;
64     Tensor reserve_space_4;
65     Tensor reserve_space_5;
66   };
67 
68   using GraphRunner = std::function<void(
69       const Tensor& y_backprop, const Tensor& input_data,
70       const Tensor& scale_data, const Tensor& offset_data,
71       const Tensor& mean_data, const Tensor& var_data,
72       const Tensor& side_input_data, FusedBatchNormOutputs* forward,
73       FusedBatchNormGradOutputs* backward)>;
74 
75   // Runs a Tensorflow graph defined by the root scope, and fetches the result
76   // of 'fetch' node into the outputs. Optional `add_nodes` parameter
77   // allows to define nodes directly using a NodeDef for the ops that are
78   // not supported by the C++ Api.
79   // TODO(ezhulenev): RunAndFetch defined in FusedConv2D and FusedMatMul tests.
80   // Add a base class for all FusedABC kernels and remove code duplication.
RunAndFetch(const tensorflow::Scope & root,const std::vector<string> & fetch,std::vector<Tensor> * outputs,bool allow_gpu_device,const std::vector<const NodeDef * > add_nodes={})81   void RunAndFetch(const tensorflow::Scope& root,
82                    const std::vector<string>& fetch,
83                    std::vector<Tensor>* outputs, bool allow_gpu_device,
84                    const std::vector<const NodeDef*> add_nodes = {}) {
85     tensorflow::GraphDef graph;
86     TF_ASSERT_OK(root.ToGraphDef(&graph));
87 
88     for (const NodeDef* add_node : add_nodes) {
89       *graph.add_node() = *add_node;
90     }
91 
92     // We really want to make sure that graph executed exactly as we passed it
93     // to the session, so we disable various optimizations.
94     tensorflow::SessionOptions session_options;
95 
96     // Disable common runtime constant folding.
97     session_options.config.mutable_graph_options()
98         ->mutable_optimizer_options()
99         ->set_opt_level(OptimizerOptions::L0);
100 
101     // Disable Grappler optimizations for tests.
102     tensorflow::RewriterConfig* cfg =
103         session_options.config.mutable_graph_options()
104             ->mutable_rewrite_options();
105     cfg->set_constant_folding(tensorflow::RewriterConfig::OFF);
106     cfg->set_layout_optimizer(tensorflow::RewriterConfig::OFF);
107     cfg->set_remapping(tensorflow::RewriterConfig::OFF);
108 
109     std::unique_ptr<tensorflow::Session> session(
110         tensorflow::NewSession(session_options));
111 
112     std::vector<DeviceAttributes> available_devices;
113     TF_ASSERT_OK(session->ListDevices(&available_devices))
114         << "Failed to get available session devices";
115 
116     // Check if session has an available GPU device.
117     const bool has_gpu_device =
__anoncb8c21e60102(const DeviceAttributes& device) 118         absl::c_any_of(available_devices, [](const DeviceAttributes& device) {
119           return device.device_type() == DEVICE_GPU;
120         });
121 
122     // Some of the `FusedABC` ops are implemented only for CPU, and in this test
123     // we don't want to compare GPU vs CPU numbers, so place all nodes on CPU in
124     // this case.
125     const bool place_all_on_gpu = allow_gpu_device && has_gpu_device;
126 
127     const string device = place_all_on_gpu ? "/device:GPU:0" : "/device:CPU:0";
128     for (NodeDef& mutable_node : *graph.mutable_node()) {
129       mutable_node.set_device(device);
130     }
131 
132     TF_ASSERT_OK(session->Create(graph));
133     TF_ASSERT_OK(session->Run({}, fetch, {}, outputs));
134   }
135 
RunFusedBatchNorm(const Tensor & y_backprop_data,const Tensor & input_data,const Tensor & scale_data,const Tensor & offset_data,const Tensor & mean_data,const Tensor & var_data,const Tensor & side_input_data,const TensorFormat data_format,bool is_training,bool has_side_input,const string & activation_mode,FusedBatchNormOutputs * forward,FusedBatchNormGradOutputs * backward,float epsilon=0.1f)136   void RunFusedBatchNorm(const Tensor& y_backprop_data,
137                          const Tensor& input_data, const Tensor& scale_data,
138                          const Tensor& offset_data, const Tensor& mean_data,
139                          const Tensor& var_data, const Tensor& side_input_data,
140                          const TensorFormat data_format, bool is_training,
141                          bool has_side_input, const string& activation_mode,
142                          FusedBatchNormOutputs* forward,
143                          FusedBatchNormGradOutputs* backward,
144                          float epsilon = 0.1f) {
145     Scope root = tensorflow::Scope::NewRootScope();
146 
147     Output y_backprop = ops::Const(root.WithOpName("y_backprop"),
148                                    Input::Initializer(y_backprop_data));
149     Output input =
150         ops::Const(root.WithOpName("input"), Input::Initializer(input_data));
151     Output scale =
152         ops::Const(root.WithOpName("scale"), Input::Initializer(scale_data));
153     Output offset =
154         ops::Const(root.WithOpName("offset"), Input::Initializer(offset_data));
155     Output mean =
156         ops::Const(root.WithOpName("mean"), Input::Initializer(mean_data));
157     Output var =
158         ops::Const(root.WithOpName("var"), Input::Initializer(var_data));
159     Output side_input = ops::Const(root.WithOpName("side_input"),
160                                    Input::Initializer(side_input_data));
161 
162     ops::FusedBatchNormV3 fwd = ops::FusedBatchNormV3(
163         root.WithOpName("fused_batch_norm"), input, scale, offset, mean, var,
164         ops::FusedBatchNormV3::IsTraining(is_training)
165             .Epsilon(epsilon)
166             .DataFormat(ToString(data_format)));
167 
168     Output with_side_input;
169     if (has_side_input) {
170       with_side_input =
171           ops::Add(root.WithOpName("with_side_input"), fwd.y, side_input);
172     } else {
173       with_side_input =
174           ops::Identity(root.WithOpName("with_side_input"), fwd.y);
175     }
176 
177     Output activation;
178     if (activation_mode == "Relu") {
179       activation =
180           ops::Relu(root.WithOpName("with_activation"), with_side_input);
181     } else {
182       activation =
183           ops::Identity(root.WithOpName("with_activation"), with_side_input);
184     }
185 
186     Output activation_grad;
187     if (activation_mode == "Relu") {
188       activation_grad = ops::internal::ReluGrad(
189           root.WithOpName("activation_grad"), y_backprop, activation);
190     } else {
191       activation_grad =
192           ops::Identity(root.WithOpName("activation_grad"), y_backprop);
193     }
194 
195     ops::FusedBatchNormGradV3 bwd = ops::FusedBatchNormGradV3(
196         root.WithOpName("fused_batch_norm_grad"), activation_grad, input, scale,
197         fwd.reserve_space_1, fwd.reserve_space_2, fwd.reserve_space_3,
198         ops::FusedBatchNormGradV3::IsTraining(is_training)
199             .Epsilon(epsilon)
200             .DataFormat(ToString(data_format)));
201 
202     std::vector<Tensor> out_tensors;
203     RunAndFetch(
204         root,
205         {"with_activation:0", "fused_batch_norm:1", "fused_batch_norm:2",
206          "fused_batch_norm:3", "fused_batch_norm:4", "fused_batch_norm:5",
207          "fused_batch_norm_grad:0", "fused_batch_norm_grad:1",
208          "fused_batch_norm_grad:2"},
209         &out_tensors, /*allow_gpu_device=*/true);
210 
211     forward->y = out_tensors[0];
212     forward->batch_mean = out_tensors[1];
213     forward->batch_variance = out_tensors[2];
214     forward->reserve_space_1 = out_tensors[3];
215     forward->reserve_space_2 = out_tensors[4];
216     forward->reserve_space_3 = out_tensors[5];
217 
218     backward->x_backprop = out_tensors[6];
219     backward->scale_backprop = out_tensors[7];
220     backward->offset_backprop = out_tensors[8];
221   }
222 
RunFusedBatchNormEx(const Tensor & y_backprop_data,const Tensor & input_data,const Tensor & scale_data,const Tensor & offset_data,const Tensor & mean_data,const Tensor & var_data,const Tensor & side_input_data,const TensorFormat data_format,bool is_training,bool has_side_input,const string & activation_mode,FusedBatchNormOutputs * forward,FusedBatchNormGradOutputs * backward,float epsilon=0.1f)223   void RunFusedBatchNormEx(const Tensor& y_backprop_data,
224                            const Tensor& input_data, const Tensor& scale_data,
225                            const Tensor& offset_data, const Tensor& mean_data,
226                            const Tensor& var_data,
227                            const Tensor& side_input_data,
228                            const TensorFormat data_format, bool is_training,
229                            bool has_side_input, const string& activation_mode,
230                            FusedBatchNormOutputs* forward,
231                            FusedBatchNormGradOutputs* backward,
232                            float epsilon = 0.1f) {
233     Scope root = tensorflow::Scope::NewRootScope();
234 
235     DataType t_dtype = DataTypeToEnum<T>::v();
236     DataType u_dtype = DataTypeToEnum<U>::v();
237 
238     Output y_backprop = ops::Const(root.WithOpName("y_backprop"),
239                                    Input::Initializer(y_backprop_data));
240     Output input =
241         ops::Const(root.WithOpName("input"), Input::Initializer(input_data));
242     Output scale =
243         ops::Const(root.WithOpName("scale"), Input::Initializer(scale_data));
244     Output offset =
245         ops::Const(root.WithOpName("offset"), Input::Initializer(offset_data));
246     Output mean =
247         ops::Const(root.WithOpName("mean"), Input::Initializer(mean_data));
248     Output var =
249         ops::Const(root.WithOpName("var"), Input::Initializer(var_data));
250     Output side_input = ops::Const(root.WithOpName("side_input"),
251                                    Input::Initializer(side_input_data));
252     Output empty =
253         ops::Const(root.WithOpName("empty"),
254                    Input::Initializer(Tensor(DataTypeToEnum<U>::value, {0})));
255 
256     int num_side_inputs = 0;
257     std::vector<NodeDefBuilder::NodeOut> side_inputs;
258 
259     if (has_side_input) {
260       num_side_inputs = 1;
261       side_inputs.push_back({side_input.name(), 0, t_dtype});
262     }
263 
264     NodeDef fused_batch_norm_ex;
265     TF_EXPECT_OK(NodeDefBuilder("fused_batch_norm_ex", "_FusedBatchNormEx")
266                      .Input({input.name(), 0, t_dtype})
267                      .Input({scale.name(), 0, u_dtype})
268                      .Input({offset.name(), 0, u_dtype})
269                      .Input({mean.name(), 0, u_dtype})
270                      .Input({var.name(), 0, u_dtype})
271                      .Input(side_inputs)
272                      .Attr("T", t_dtype)
273                      .Attr("U", u_dtype)
274                      .Attr("data_format", ToString(data_format))
275                      .Attr("epsilon", epsilon)
276                      .Attr("activation_mode", activation_mode)
277                      .Attr("num_side_inputs", num_side_inputs)
278                      .Attr("is_training", is_training)
279                      .Finalize(&fused_batch_norm_ex));
280 
281     NodeDef fused_batch_norm_grad;
282     NodeDef activation_grad;
283     std::vector<Tensor> out_tensors;
284     std::vector<const NodeDef*> add_nodes;
285     if (is_training) {
286       TF_EXPECT_OK(
287           NodeDefBuilder("fused_batch_norm_grad", "_FusedBatchNormGradEx")
288               .Input({y_backprop.name(), 0, t_dtype})
289               .Input({input.name(), 0, t_dtype})
290               .Input({scale.name(), 0, u_dtype})
291               .Input({fused_batch_norm_ex.name(), 3, u_dtype})
292               .Input({fused_batch_norm_ex.name(), 4, u_dtype})
293               .Input({fused_batch_norm_ex.name(), 5, u_dtype})
294               .Input({offset.name(), 0, u_dtype})
295               .Input({fused_batch_norm_ex.name(), 0, t_dtype})
296               .Attr("T", t_dtype)
297               .Attr("U", u_dtype)
298               .Attr("data_format", ToString(data_format))
299               .Attr("epsilon", epsilon)
300               .Attr("activation_mode", activation_mode)
301               .Attr("num_side_inputs", num_side_inputs)
302               .Attr("is_training", is_training)
303               .Finalize(&fused_batch_norm_grad));
304       add_nodes = {&fused_batch_norm_ex, &fused_batch_norm_grad};
305     } else {
306       if (activation_mode == "Relu") {
307         TF_EXPECT_OK(NodeDefBuilder("activation_grad", "ReluGrad")
308                          .Input({y_backprop.name(), 0, t_dtype})
309                          .Input({fused_batch_norm_ex.name(), 0, t_dtype})
310                          .Attr("T", t_dtype)
311                          .Finalize(&activation_grad));
312       } else {
313         TF_EXPECT_OK(NodeDefBuilder("activation_grad", "Identity")
314                          .Input({y_backprop.name(), 0, t_dtype})
315                          .Attr("T", t_dtype)
316                          .Finalize(&activation_grad));
317       }
318       TF_EXPECT_OK(
319           NodeDefBuilder("fused_batch_norm_grad", "FusedBatchNormGradV3")
320               .Input({activation_grad.name(), 0, t_dtype})
321               .Input({input.name(), 0, t_dtype})
322               .Input({scale.name(), 0, u_dtype})
323               .Input({fused_batch_norm_ex.name(), 3, u_dtype})
324               .Input({fused_batch_norm_ex.name(), 4, u_dtype})
325               .Input({fused_batch_norm_ex.name(), 5, u_dtype})
326               .Attr("T", t_dtype)
327               .Attr("U", u_dtype)
328               .Attr("data_format", ToString(data_format))
329               .Attr("epsilon", epsilon)
330               .Attr("is_training", is_training)
331               .Finalize(&fused_batch_norm_grad));
332       add_nodes = {&fused_batch_norm_ex, &activation_grad,
333                    &fused_batch_norm_grad};
334     }
335 
336     RunAndFetch(root,
337                 {"fused_batch_norm_ex:0", "fused_batch_norm_ex:1",
338                  "fused_batch_norm_ex:2", "fused_batch_norm_ex:3",
339                  "fused_batch_norm_ex:4", "fused_batch_norm_ex:5",
340                  "fused_batch_norm_grad:0", "fused_batch_norm_grad:1",
341                  "fused_batch_norm_grad:2"},
342                 &out_tensors,
343                 /*allow_gpu_device=*/true, add_nodes);
344 
345     forward->y = out_tensors[0];
346     forward->batch_mean = out_tensors[1];
347     forward->batch_variance = out_tensors[2];
348     forward->reserve_space_1 = out_tensors[3];
349     forward->reserve_space_2 = out_tensors[4];
350     forward->reserve_space_3 = out_tensors[5];
351 
352     backward->x_backprop = out_tensors[6];
353     backward->scale_backprop = out_tensors[7];
354     backward->offset_backprop = out_tensors[8];
355   }
356 
VerifyTensorsNear(int batch,int height,int width,int channels,TensorFormat data_format,bool is_training,const GraphRunner & run_default,const GraphRunner & run_fused)357   void VerifyTensorsNear(int batch, int height, int width, int channels,
358                          TensorFormat data_format, bool is_training,
359                          const GraphRunner& run_default,
360                          const GraphRunner& run_fused) {
361     DataType t_dtype = DataTypeToEnum<T>::v();
362     DataType u_dtype = DataTypeToEnum<U>::v();
363 
364     TensorShape input_shape =
365         data_format == FORMAT_NHWC
366             ? TensorShape({batch, height, width, channels})
367             : TensorShape({batch, channels, height, width});
368 
369     Tensor input(t_dtype, input_shape);
370     input.flat<T>().setRandom();
371     input.flat<T>() -= input.flat<T>().constant(static_cast<T>(0.5));
372 
373     Tensor scale(u_dtype, {channels});
374     scale.flat<U>().setRandom();
375 
376     Tensor offset(u_dtype, {channels});
377     offset.flat<U>().setRandom();
378 
379     Tensor mean(u_dtype, {channels});
380     mean.flat<U>().setRandom();
381 
382     Tensor var(u_dtype, {channels});
383     var.flat<U>().setRandom();
384 
385     Tensor side_input(t_dtype, input_shape);
386     side_input.flat<T>().setRandom();
387     side_input.flat<T>() += side_input.flat<T>().constant(static_cast<T>(5.0));
388 
389     Tensor y_backprop(t_dtype, input_shape);
390     y_backprop.flat<T>().setRandom();
391     y_backprop.flat<T>() -= y_backprop.flat<T>().constant(static_cast<T>(0.5));
392 
393     Tensor empty(u_dtype, {0});
394 
395     FusedBatchNormOutputs fbn_forward;
396     FusedBatchNormOutputs fbn_ex_forward;
397 
398     FusedBatchNormGradOutputs fbn_backward;
399     FusedBatchNormGradOutputs fbn_ex_backward;
400 
401     run_default(y_backprop, input, scale, offset, is_training ? empty : mean,
402                 is_training ? empty : var, side_input, &fbn_forward,
403                 &fbn_backward);
404 
405     // Write some garbage to the `fbn_ex_forward` and `fbn_ex_backward` first to
406     // make sure that fused kernel actually writes correct results to memory.
407     run_default(y_backprop, side_input, scale, offset,
408                 is_training ? empty : mean, is_training ? empty : var, input,
409                 &fbn_ex_forward, &fbn_ex_backward);
410 
411     run_fused(y_backprop, input, scale, offset, is_training ? empty : mean,
412               is_training ? empty : var, side_input, &fbn_ex_forward,
413               &fbn_ex_backward);
414 
415     std::vector<std::pair<Tensor, Tensor>> tensor_pairs;
416     if (is_training) {
417       tensor_pairs = {
418           {fbn_forward.y, fbn_ex_forward.y},
419           {fbn_forward.batch_mean, fbn_ex_forward.batch_mean},
420           {fbn_forward.batch_variance, fbn_ex_forward.batch_variance},
421           {fbn_forward.reserve_space_1, fbn_ex_forward.reserve_space_1},
422           {fbn_forward.reserve_space_2, fbn_ex_forward.reserve_space_2},
423           // NOTE(ezhulenev): We deliberately do not check `reserved_space_3`
424           // because BatchNormEx with fused side input has different data in it,
425           // but we make sure that final gradients are the same.
426           {fbn_backward.y_backprop, fbn_ex_backward.y_backprop},
427           {fbn_backward.x_backprop, fbn_ex_backward.x_backprop},
428           {fbn_backward.scale_backprop, fbn_ex_backward.scale_backprop},
429           {fbn_backward.offset_backprop, fbn_ex_backward.offset_backprop},
430       };
431     } else {
432       tensor_pairs = {{fbn_forward.y, fbn_ex_forward.y}};
433     }
434 
435     for (auto& pair : tensor_pairs) {
436       const Tensor& fbn = pair.first;
437       const Tensor& fbn_ex = pair.second;
438 
439       ASSERT_EQ(fbn.dtype(), fbn_ex.dtype());
440       ASSERT_EQ(fbn.shape(), fbn_ex.shape());
441 
442       test::ExpectClose(fbn, fbn_ex, 1e-2);
443     }
444   }
445 
446   // Verifies that computing FusedBatchNormOp+{SideInput}+{Activation} is
447   // identical to FusedBatchNormExOp[fused_ops={SideInput, Activation}].
VerifyFusedBatchNormEx(int batch,int height,int width,int channels,TensorFormat data_format,bool is_training,bool has_side_input,const string & activation_mode)448   void VerifyFusedBatchNormEx(int batch, int height, int width, int channels,
449                               TensorFormat data_format, bool is_training,
450                               bool has_side_input,
451                               const string& activation_mode) {
452     const GraphRunner run_default =
453         [&](const Tensor& y_backprop, const Tensor& input_data,
454             const Tensor& scale_data, const Tensor& offset_data,
455             const Tensor& mean_data, const Tensor& var_data,
456             const Tensor& side_input_data, FusedBatchNormOutputs* fwd,
457             FusedBatchNormGradOutputs* bwd) {
458           this->RunFusedBatchNorm(y_backprop, input_data, scale_data,
459                                   offset_data, mean_data, var_data,
460                                   side_input_data, data_format, is_training,
461                                   has_side_input, activation_mode, fwd, bwd);
462         };
463 
464     const GraphRunner run_inference =
465         [&](const Tensor& y_backprop, const Tensor& input_data,
466             const Tensor& scale_data, const Tensor& offset_data,
467             const Tensor& mean_data, const Tensor& var_data,
468             const Tensor& side_input_data, FusedBatchNormOutputs* fwd,
469             FusedBatchNormGradOutputs* bwd) {
470           this->RunFusedBatchNormEx(y_backprop, input_data, scale_data,
471                                     offset_data, mean_data, var_data,
472                                     side_input_data, data_format, is_training,
473                                     has_side_input, activation_mode, fwd, bwd);
474         };
475 
476     VerifyTensorsNear(batch, height, width, channels, data_format, is_training,
477                       run_default, run_inference);
478   }
479 };
480 
// Named values for the boolean arguments of VerifyFusedBatchNormEx, so that
// test call sites read as prose instead of bare true/false.
// Values for `is_training`:
constexpr bool kInTraining = true;
constexpr bool kInInference = false;
// Values for `has_side_input`:
constexpr bool kNoSideInput = false;
constexpr bool kWithSideInput = true;
485 
486 // -------------------------------------------------------------------------- //
487 // FusedBatchNormEx[is_training=true].
488 
// Training-mode tests are compiled only for CUDA builds whose CUDNN_VERSION is
// at least 7402.
#if defined(GOOGLE_CUDA) && (CUDNN_VERSION >= 7402)
// `T` is the x/y element type; scale/offset/mean/variance are always float.
template <typename T>
using FusedBatchNormExOpTrainingTest =
    FusedBatchNormExOpTestBase<T, float>;  // scale is always float

TYPED_TEST_SUITE_P(FusedBatchNormExOpTrainingTest);

// Each test compares fused vs. unfused results for batch=4, height=28,
// width=28, channels=256 in NHWC layout.

// No side input, "Identity" activation.
TYPED_TEST_P(FusedBatchNormExOpTrainingTest, TrainingInNHWCTest) {
  this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInTraining,
                               kNoSideInput, "Identity");
}

// No side input, "Relu" activation.
TYPED_TEST_P(FusedBatchNormExOpTrainingTest, TrainingWithReluInNHWCTest) {
  this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInTraining,
                               kNoSideInput, "Relu");
}

// Side input added before the "Relu" activation.
TYPED_TEST_P(FusedBatchNormExOpTrainingTest,
             TrainingWithSideInputAndReluInNHWCTest) {
  this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInTraining,
                               kWithSideInput, "Relu");
}

REGISTER_TYPED_TEST_SUITE_P(FusedBatchNormExOpTrainingTest,  //
                            TrainingInNHWCTest,              //
                            TrainingWithReluInNHWCTest,      //
                            TrainingWithSideInputAndReluInNHWCTest);

// Training mode is instantiated for half precision inputs only.
using FusedBatchNormExTrainingDataTypes = ::testing::Types<Eigen::half>;
INSTANTIATE_TYPED_TEST_SUITE_P(Test, FusedBatchNormExOpTrainingTest,
                               FusedBatchNormExTrainingDataTypes);
#endif  // defined(GOOGLE_CUDA) && (CUDNN_VERSION >= 7402)
521 
522 // -------------------------------------------------------------------------- //
523 // FusedBatchNormEx[is_training=false].
524 
// Inference-mode tests are compiled for any CUDA build (no cuDNN version
// requirement, unlike the training tests).
#if defined(GOOGLE_CUDA)
// `T` is the x/y element type; scale/offset/mean/variance are always float.
template <typename T>
using FusedBatchNormExOpInferenceTest =
    FusedBatchNormExOpTestBase<T, float>;  // scale is always float

TYPED_TEST_SUITE_P(FusedBatchNormExOpInferenceTest);

// Each test compares fused vs. unfused results for batch=4, height=28,
// width=28, channels=256 in NHWC layout.

// No side input, "Identity" activation.
TYPED_TEST_P(FusedBatchNormExOpInferenceTest, InferenceInNHWCTest) {
  this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInInference,
                               kNoSideInput, "Identity");
}

// No side input, "Relu" activation.
TYPED_TEST_P(FusedBatchNormExOpInferenceTest, InferenceWithReluInNHWCTest) {
  this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInInference,
                               kNoSideInput, "Relu");
}

// Side input added before the "Relu" activation.
TYPED_TEST_P(FusedBatchNormExOpInferenceTest,
             InferenceWithSideInputAndReluInNHWCTest) {
  this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInInference,
                               kWithSideInput, "Relu");
}

REGISTER_TYPED_TEST_SUITE_P(FusedBatchNormExOpInferenceTest,  //
                            InferenceInNHWCTest,              //
                            InferenceWithReluInNHWCTest,      //
                            InferenceWithSideInputAndReluInNHWCTest);

// Inference mode is instantiated for both half and float inputs.
using FusedBatchNormExInferenceDataTypes = ::testing::Types<Eigen::half, float>;
INSTANTIATE_TYPED_TEST_SUITE_P(Test, FusedBatchNormExOpInferenceTest,
                               FusedBatchNormExInferenceDataTypes);
#endif  // defined(GOOGLE_CUDA)
557 
558 // -------------------------------------------------------------------------- //
559 // Performance benchmarks are below.                                          //
560 // -------------------------------------------------------------------------- //
561 
// Short element-type aliases; they are passed as the `T` argument of the
// BM_FusedBatchNorm macro and pasted into the generated benchmark names.
using fp16 = Eigen::half;
using fp32 = float;
// (has_side_input, activation_mode) describing which extra ops the benchmarked
// _FusedBatchNormEx node fuses.
using SideInputAndActivation = std::pair<bool, string>;
565 
Identity()566 SideInputAndActivation Identity() { return {false, "Identity"}; }
Relu()567 SideInputAndActivation Relu() { return {false, "Relu"}; }
AddAndRelu()568 SideInputAndActivation AddAndRelu() { return {true, "Relu"}; }
569 
570 template <typename T>
FusedBatchNormEx(int n,int h,int w,int c,TensorFormat data_format,bool is_training,std::function<SideInputAndActivation ()> fn)571 static Graph* FusedBatchNormEx(int n, int h, int w, int c,
572                                TensorFormat data_format, bool is_training,
573                                std::function<SideInputAndActivation()> fn) {
574   Graph* g = new Graph(OpRegistry::Global());
575 
576   DataType dtype = DataTypeToEnum<T>::value;
577   Tensor x_t(dtype, data_format == FORMAT_NHWC ? TensorShape({n, h, w, c})
578                                                : TensorShape({n, c, h, w}));
579   x_t.flat<T>().setRandom();
580 
581   Tensor other_t(DT_FLOAT, TensorShape({c}));
582   other_t.flat<float>().setRandom();
583 
584   Node* x = test::graph::Constant(g, x_t, "x");
585   Node* other = test::graph::Constant(g, other_t, "other");
586   Node* empty = test::graph::Constant(g, Tensor(DT_FLOAT, {0}), "empty");
587 
588   int num_side_inputs = 0;
589   std::vector<NodeBuilder::NodeOut> side_inputs;
590 
591   SideInputAndActivation side_input_and_activation = fn();
592   bool has_side_input = side_input_and_activation.first;
593   string activation_mode = side_input_and_activation.second;
594 
595   if (has_side_input) {
596     num_side_inputs = 1;
597     side_inputs.push_back({x});
598   }
599 
600   Node* fused_batch_norm;
601   TF_CHECK_OK(NodeBuilder(g->NewName("fused_batch_norm"), "_FusedBatchNormEx")
602                   .Input(x)
603                   .Input(other)                        // scale
604                   .Input(other)                        // offset
605                   .Input(is_training ? empty : other)  // mean
606                   .Input(is_training ? empty : other)  // variance
607                   .Input(side_inputs)                  // side_input
608                   .Attr("T", dtype)
609                   .Attr("U", DT_FLOAT)
610                   .Attr("epsilon", 0.001)
611                   .Attr("data_format", ToString(data_format))
612                   .Attr("activation_mode", activation_mode)
613                   .Attr("num_side_inputs", num_side_inputs)
614                   .Attr("is_training", is_training)
615                   .Finalize(g, &fused_batch_norm));
616 
617   return g;
618 }
619 
// Pastes two tokens together with an underscore between them: a_b.
#define BM_CONCAT(a, b) a##_##b

// Builds a unique benchmark function name from all benchmark parameters, e.g.
// BM_FusedBatchNorm_gpu_fp16_64_14_14_256_NHWC_true_Relu.
#define BM_NAME(N, H, W, C, T, FORMAT, IS_TRAINING, A, DEVICE)          \
  BM_CONCAT(BM_FusedBatchNorm##_##DEVICE##_##T##_##N##_##H##_##W##_##C, \
            FORMAT##_##IS_TRAINING##_##A)

// Defines and registers one benchmark of the graph built by
// FusedBatchNormEx<T>. FORMAT is NHWC or NCHW (pasted onto `FORMAT_`),
// ACTIVATION names one of the SideInputAndActivation preset functions defined
// earlier in this file (Identity, Relu, AddAndRelu), and DEVICE is
// stringified and passed to test::Benchmark. Items processed are reported as
// iterations * N * H * W * C (the number of input elements per run).
#define BM_FusedBatchNorm(N, H, W, C, T, FORMAT, IS_TRAINING, ACTIVATION,    \
                          DEVICE)                                            \
  static void BM_NAME(N, H, W, C, T, FORMAT, IS_TRAINING, ACTIVATION,        \
                      DEVICE)(::testing::benchmark::State & state) {         \
    test::Benchmark(#DEVICE,                                                 \
                    FusedBatchNormEx<T>(N, H, W, C, FORMAT_##FORMAT,         \
                                        IS_TRAINING, {ACTIVATION}),          \
                    /*old_benchmark_api*/ false)                             \
        .Run(state);                                                         \
    state.SetItemsProcessed(state.iterations() * N * H * W * C);             \
  }                                                                          \
  BENCHMARK(BM_NAME(N, H, W, C, T, FORMAT, IS_TRAINING, ACTIVATION, DEVICE)) \
      ->UseRealTime();
639 
// Training-mode GPU benchmarks (fp16, NHWC), behind the same CUDA/cuDNN
// version guard as the training tests.
#if defined(GOOGLE_CUDA) && (CUDNN_VERSION >= 7402)
BM_FusedBatchNorm(64, 14, 14, 256, fp16, NHWC, true, Identity, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp16, NHWC, true, Relu, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp16, NHWC, true, AddAndRelu, gpu);
#endif  // defined(GOOGLE_CUDA) && (CUDNN_VERSION >= 7402)
645 
// Inference-mode GPU benchmarks for every combination of element type
// (fp16, fp32), layout (NHWC, NCHW) and fused-op preset.
#if defined(GOOGLE_CUDA)
BM_FusedBatchNorm(64, 14, 14, 256, fp16, NHWC, false, Identity, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp16, NHWC, false, Relu, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp16, NHWC, false, AddAndRelu, gpu);

BM_FusedBatchNorm(64, 14, 14, 256, fp16, NCHW, false, Identity, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp16, NCHW, false, Relu, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp16, NCHW, false, AddAndRelu, gpu);

BM_FusedBatchNorm(64, 14, 14, 256, fp32, NHWC, false, Identity, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp32, NHWC, false, Relu, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp32, NHWC, false, AddAndRelu, gpu);

BM_FusedBatchNorm(64, 14, 14, 256, fp32, NCHW, false, Identity, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp32, NCHW, false, Relu, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp32, NCHW, false, AddAndRelu, gpu);
#endif  // defined(GOOGLE_CUDA)
663 
664 }  // namespace tensorflow
665