/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/algorithm/container.h"
#include "absl/strings/match.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/nn_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
#include "tensorflow/core/public/session.h"

#if GOOGLE_CUDA
#include "third_party/gpus/cudnn/cudnn.h"
#endif  // GOOGLE_CUDA
namespace tensorflow {

template <typename T, typename U>
class FusedBatchNormExOpTestBase : public OpsTestBase {
 public:
  FusedBatchNormExOpTestBase() {
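    // Allow cuDNN to use its persistent batch normalization algorithm
    // (CUDNN_BATCHNORM_SPATIAL_PERSISTENT), which the fused GPU kernels can
    // take advantage of; this is also why the training tests below are
    // guarded on cuDNN >= 7.4.2.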
    setenv("TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT", "1", 1 /* replace */);
  }

 protected:
  struct FusedBatchNormOutputs {
    Tensor y;
    Tensor batch_mean;
    Tensor batch_variance;
    Tensor reserve_space_1;
    Tensor reserve_space_2;
    Tensor reserve_space_3;
  };

  struct FusedBatchNormGradOutputs {
    Tensor y_backprop;
    Tensor x_backprop;
    Tensor scale_backprop;
    Tensor offset_backprop;
    Tensor reserve_space_4;
    Tensor reserve_space_5;
  };

  using GraphRunner = std::function<void(
      const Tensor& y_backprop, const Tensor& input_data,
      const Tensor& scale_data, const Tensor& offset_data,
      const Tensor& mean_data, const Tensor& var_data,
      const Tensor& side_input_data, FusedBatchNormOutputs* forward,
      FusedBatchNormGradOutputs* backward)>;

  // Runs a TensorFlow graph defined by the root scope, and fetches the result
  // of each 'fetch' node into the outputs. The optional `add_nodes` parameter
  // allows defining nodes directly with a NodeDef for ops that are not
  // supported by the C++ API.
  // TODO(ezhulenev): RunAndFetch is also defined in the FusedConv2D and
  // FusedMatMul tests. Add a base class for all FusedABC kernels and remove
  // the code duplication.
  void RunAndFetch(const tensorflow::Scope& root,
                   const std::vector<string>& fetch,
                   std::vector<Tensor>* outputs, bool allow_gpu_device,
                   const std::vector<const NodeDef*> add_nodes = {}) {
    tensorflow::GraphDef graph;
    TF_ASSERT_OK(root.ToGraphDef(&graph));

    for (const NodeDef* add_node : add_nodes) {
      *graph.add_node() = *add_node;
    }

    // We really want to make sure that the graph is executed exactly as we
    // passed it to the session, so we disable various optimizations.
    tensorflow::SessionOptions session_options;

    // Disable common runtime constant folding.
    session_options.config.mutable_graph_options()
        ->mutable_optimizer_options()
        ->set_opt_level(OptimizerOptions::L0);

    // Disable Grappler optimizations for tests.
    tensorflow::RewriterConfig* cfg =
        session_options.config.mutable_graph_options()
            ->mutable_rewrite_options();
    cfg->set_constant_folding(tensorflow::RewriterConfig::OFF);
    cfg->set_layout_optimizer(tensorflow::RewriterConfig::OFF);
    cfg->set_remapping(tensorflow::RewriterConfig::OFF);

    std::unique_ptr<tensorflow::Session> session(
        tensorflow::NewSession(session_options));

    std::vector<DeviceAttributes> available_devices;
    TF_ASSERT_OK(session->ListDevices(&available_devices))
        << "Failed to get available session devices";

    // Check if the session has an available GPU device.
    const bool has_gpu_device =
        absl::c_any_of(available_devices, [](const DeviceAttributes& device) {
          return device.device_type() == DEVICE_GPU;
        });

    // Some of the `FusedABC` ops are implemented only for CPU, and in this
    // test we don't want to compare GPU vs CPU numbers, so in that case we
    // place all nodes on CPU.
    const bool place_all_on_gpu = allow_gpu_device && has_gpu_device;

    const string device = place_all_on_gpu ? "/device:GPU:0" : "/device:CPU:0";
    for (NodeDef& mutable_node : *graph.mutable_node()) {
      mutable_node.set_device(device);
    }

    TF_ASSERT_OK(session->Create(graph));
    TF_ASSERT_OK(session->Run({}, fetch, {}, outputs));
  }

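  // Builds and runs a reference graph with the non-fused ops:
  // FusedBatchNormV3, followed by an optional Add of the side input and an
  // optional Relu, together with the matching backward pass (ReluGrad +
  // FusedBatchNormGradV3). Results are returned in `forward` and `backward`.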
  void RunFusedBatchNorm(const Tensor& y_backprop_data,
                         const Tensor& input_data, const Tensor& scale_data,
                         const Tensor& offset_data, const Tensor& mean_data,
                         const Tensor& var_data, const Tensor& side_input_data,
                         const TensorFormat data_format, bool is_training,
                         bool has_side_input, const string& activation_mode,
                         FusedBatchNormOutputs* forward,
                         FusedBatchNormGradOutputs* backward,
                         float epsilon = 0.1f) {
    Scope root = tensorflow::Scope::NewRootScope();

    Output y_backprop = ops::Const(root.WithOpName("y_backprop"),
                                   Input::Initializer(y_backprop_data));
    Output input =
        ops::Const(root.WithOpName("input"), Input::Initializer(input_data));
    Output scale =
        ops::Const(root.WithOpName("scale"), Input::Initializer(scale_data));
    Output offset =
        ops::Const(root.WithOpName("offset"), Input::Initializer(offset_data));
    Output mean =
        ops::Const(root.WithOpName("mean"), Input::Initializer(mean_data));
    Output var =
        ops::Const(root.WithOpName("var"), Input::Initializer(var_data));
    Output side_input = ops::Const(root.WithOpName("side_input"),
                                   Input::Initializer(side_input_data));

    ops::FusedBatchNormV3 fwd = ops::FusedBatchNormV3(
        root.WithOpName("fused_batch_norm"), input, scale, offset, mean, var,
        ops::FusedBatchNormV3::IsTraining(is_training)
            .Epsilon(epsilon)
            .DataFormat(ToString(data_format)));

    Output with_side_input;
    if (has_side_input) {
      with_side_input =
          ops::Add(root.WithOpName("with_side_input"), fwd.y, side_input);
    } else {
      with_side_input =
          ops::Identity(root.WithOpName("with_side_input"), fwd.y);
    }

    Output activation;
    if (activation_mode == "Relu") {
      activation =
          ops::Relu(root.WithOpName("with_activation"), with_side_input);
    } else {
      activation =
          ops::Identity(root.WithOpName("with_activation"), with_side_input);
    }

    Output activation_grad;
    if (activation_mode == "Relu") {
      activation_grad = ops::internal::ReluGrad(
          root.WithOpName("activation_grad"), y_backprop, activation);
    } else {
      activation_grad =
          ops::Identity(root.WithOpName("activation_grad"), y_backprop);
    }

    ops::FusedBatchNormGradV3 bwd = ops::FusedBatchNormGradV3(
        root.WithOpName("fused_batch_norm_grad"), activation_grad, input,
        scale, fwd.reserve_space_1, fwd.reserve_space_2, fwd.reserve_space_3,
        ops::FusedBatchNormGradV3::IsTraining(is_training)
            .Epsilon(epsilon)
            .DataFormat(ToString(data_format)));

    std::vector<Tensor> out_tensors;
    RunAndFetch(
        root,
        {"with_activation:0", "fused_batch_norm:1", "fused_batch_norm:2",
         "fused_batch_norm:3", "fused_batch_norm:4", "fused_batch_norm:5",
         "activation_grad:0", "fused_batch_norm_grad:0",
         "fused_batch_norm_grad:1", "fused_batch_norm_grad:2"},
        &out_tensors, /*allow_gpu_device=*/true);

    forward->y = out_tensors[0];
    forward->batch_mean = out_tensors[1];
    forward->batch_variance = out_tensors[2];
    forward->reserve_space_1 = out_tensors[3];
    forward->reserve_space_2 = out_tensors[4];
    forward->reserve_space_3 = out_tensors[5];

    backward->y_backprop = out_tensors[6];
    backward->x_backprop = out_tensors[7];
    backward->scale_backprop = out_tensors[8];
    backward->offset_backprop = out_tensors[9];
  }

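  // Builds and runs the same computation with the fused `_FusedBatchNormEx`
  // kernel. The op has no C++ API wrapper, so its NodeDef (and the NodeDefs
  // of the backward ops that read its outputs) are constructed by hand and
  // passed to RunAndFetch through `add_nodes`.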
  void RunFusedBatchNormEx(const Tensor& y_backprop_data,
                           const Tensor& input_data, const Tensor& scale_data,
                           const Tensor& offset_data, const Tensor& mean_data,
                           const Tensor& var_data,
                           const Tensor& side_input_data,
                           const TensorFormat data_format, bool is_training,
                           bool has_side_input, const string& activation_mode,
                           FusedBatchNormOutputs* forward,
                           FusedBatchNormGradOutputs* backward,
                           float epsilon = 0.1f) {
    Scope root = tensorflow::Scope::NewRootScope();

    DataType t_dtype = DataTypeToEnum<T>::v();
    DataType u_dtype = DataTypeToEnum<U>::v();

    Output y_backprop = ops::Const(root.WithOpName("y_backprop"),
                                   Input::Initializer(y_backprop_data));
    Output input =
        ops::Const(root.WithOpName("input"), Input::Initializer(input_data));
    Output scale =
        ops::Const(root.WithOpName("scale"), Input::Initializer(scale_data));
    Output offset =
        ops::Const(root.WithOpName("offset"), Input::Initializer(offset_data));
    Output mean =
        ops::Const(root.WithOpName("mean"), Input::Initializer(mean_data));
    Output var =
        ops::Const(root.WithOpName("var"), Input::Initializer(var_data));
    Output side_input = ops::Const(root.WithOpName("side_input"),
                                   Input::Initializer(side_input_data));
    Output empty =
        ops::Const(root.WithOpName("empty"),
                   Input::Initializer(Tensor(DataTypeToEnum<U>::value, {0})));

    int num_side_inputs = 0;
    std::vector<NodeDefBuilder::NodeOut> side_inputs;

    if (has_side_input) {
      num_side_inputs = 1;
      side_inputs.push_back({side_input.name(), 0, t_dtype});
    }

    NodeDef fused_batch_norm_ex;
    TF_EXPECT_OK(NodeDefBuilder("fused_batch_norm_ex", "_FusedBatchNormEx")
                     .Input({input.name(), 0, t_dtype})
                     .Input({scale.name(), 0, u_dtype})
                     .Input({offset.name(), 0, u_dtype})
                     .Input({mean.name(), 0, u_dtype})
                     .Input({var.name(), 0, u_dtype})
                     .Input(side_inputs)
                     .Attr("T", t_dtype)
                     .Attr("U", u_dtype)
                     .Attr("data_format", ToString(data_format))
                     .Attr("epsilon", epsilon)
                     .Attr("activation_mode", activation_mode)
                     .Attr("num_side_inputs", num_side_inputs)
                     .Attr("is_training", is_training)
                     .Finalize(&fused_batch_norm_ex));

    NodeDef activation_grad;
    if (activation_mode == "Relu") {
      TF_EXPECT_OK(NodeDefBuilder("activation_grad", "ReluGrad")
                       .Input({y_backprop.name(), 0, t_dtype})
                       .Input({fused_batch_norm_ex.name(), 0, t_dtype})
                       .Attr("T", t_dtype)
                       .Finalize(&activation_grad));
    } else {
      TF_EXPECT_OK(NodeDefBuilder("activation_grad", "Identity")
                       .Input({y_backprop.name(), 0, t_dtype})
                       .Attr("T", t_dtype)
                       .Finalize(&activation_grad));
    }

    NodeDef fused_batch_norm_grad;
    TF_EXPECT_OK(
        NodeDefBuilder("fused_batch_norm_grad", "FusedBatchNormGradV3")
            .Input({activation_grad.name(), 0, t_dtype})
            .Input({input.name(), 0, t_dtype})
            .Input({scale.name(), 0, u_dtype})
            .Input({fused_batch_norm_ex.name(), 3, u_dtype})
            .Input({fused_batch_norm_ex.name(), 4, u_dtype})
            .Input({fused_batch_norm_ex.name(), 5, u_dtype})
            .Attr("T", t_dtype)
            .Attr("U", u_dtype)
            .Attr("data_format", ToString(data_format))
            .Attr("epsilon", epsilon)
            .Attr("is_training", is_training)
            .Finalize(&fused_batch_norm_grad));

    std::vector<Tensor> out_tensors;
    RunAndFetch(
        root,
        {"fused_batch_norm_ex:0", "fused_batch_norm_ex:1",
         "fused_batch_norm_ex:2", "fused_batch_norm_ex:3",
         "fused_batch_norm_ex:4", "fused_batch_norm_ex:5", "activation_grad:0",
         "fused_batch_norm_grad:0", "fused_batch_norm_grad:1",
         "fused_batch_norm_grad:2"},
        &out_tensors,
        /*allow_gpu_device=*/true,
        {&fused_batch_norm_ex, &activation_grad, &fused_batch_norm_grad});

    forward->y = out_tensors[0];
    forward->batch_mean = out_tensors[1];
    forward->batch_variance = out_tensors[2];
    forward->reserve_space_1 = out_tensors[3];
    forward->reserve_space_2 = out_tensors[4];
    forward->reserve_space_3 = out_tensors[5];

    backward->y_backprop = out_tensors[6];
    backward->x_backprop = out_tensors[7];
    backward->scale_backprop = out_tensors[8];
    backward->offset_backprop = out_tensors[9];
  }

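  // Runs the default (non-fused) and the fused graphs on identical random
  // inputs and verifies that the results match within a 1e-2 tolerance: all
  // forward outputs and gradients in training mode, and just `y` in inference
  // mode.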
  void VerifyTensorsNear(int batch, int height, int width, int channels,
                         TensorFormat data_format, bool is_training,
                         const GraphRunner& run_default,
                         const GraphRunner& run_fused) {
    DataType t_dtype = DataTypeToEnum<T>::v();
    DataType u_dtype = DataTypeToEnum<U>::v();

    TensorShape input_shape =
        data_format == FORMAT_NHWC
            ? TensorShape({batch, height, width, channels})
            : TensorShape({batch, channels, height, width});

    Tensor input(t_dtype, input_shape);
    input.flat<T>().setRandom();
    input.flat<T>() -= input.flat<T>().constant(static_cast<T>(0.5));

    Tensor scale(u_dtype, {channels});
    scale.flat<U>().setRandom();

    Tensor offset(u_dtype, {channels});
    offset.flat<U>().setRandom();

    Tensor mean(u_dtype, {channels});
    mean.flat<U>().setRandom();

    Tensor var(u_dtype, {channels});
    var.flat<U>().setRandom();

    Tensor side_input(t_dtype, input_shape);
    side_input.flat<T>().setRandom();
    side_input.flat<T>() += side_input.flat<T>().constant(static_cast<T>(5.0));

    Tensor y_backprop(t_dtype, input_shape);
    y_backprop.flat<T>().setRandom();
    y_backprop.flat<T>() -= y_backprop.flat<T>().constant(static_cast<T>(0.5));

    Tensor empty(u_dtype, {0});

    FusedBatchNormOutputs fbn_forward;
    FusedBatchNormOutputs fbn_ex_forward;

    FusedBatchNormGradOutputs fbn_backward;
    FusedBatchNormGradOutputs fbn_ex_backward;

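    // In training mode the kernels compute the batch statistics themselves,
    // so the mean and variance inputs are passed as empty tensors.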
    run_default(y_backprop, input, scale, offset, is_training ? empty : mean,
                is_training ? empty : var, side_input, &fbn_forward,
                &fbn_backward);

    // Write garbage into `fbn_ex_forward` and `fbn_ex_backward` first (by
    // running the default graph with shuffled inputs) to make sure that the
    // fused kernel actually overwrites the output memory with correct
    // results.
    run_default(y_backprop, side_input, scale, offset,
                is_training ? empty : mean, is_training ? empty : var, input,
                &fbn_ex_forward, &fbn_ex_backward);

    run_fused(y_backprop, input, scale, offset, is_training ? empty : mean,
              is_training ? empty : var, side_input, &fbn_ex_forward,
              &fbn_ex_backward);

    std::vector<std::pair<Tensor, Tensor>> tensor_pairs;
    if (is_training) {
      tensor_pairs = {
          {fbn_forward.y, fbn_ex_forward.y},
          {fbn_forward.batch_mean, fbn_ex_forward.batch_mean},
          {fbn_forward.batch_variance, fbn_ex_forward.batch_variance},
          {fbn_forward.reserve_space_1, fbn_ex_forward.reserve_space_1},
          {fbn_forward.reserve_space_2, fbn_ex_forward.reserve_space_2},
          // NOTE(ezhulenev): We deliberately do not check `reserve_space_3`
          // because `_FusedBatchNormEx` with a fused side input keeps
          // different data in it, but we make sure that the final gradients
          // are the same.
          {fbn_backward.y_backprop, fbn_ex_backward.y_backprop},
          {fbn_backward.x_backprop, fbn_ex_backward.x_backprop},
          {fbn_backward.scale_backprop, fbn_ex_backward.scale_backprop},
          {fbn_backward.offset_backprop, fbn_ex_backward.offset_backprop},
      };
    } else {
      tensor_pairs = {{fbn_forward.y, fbn_ex_forward.y}};
    }

    for (auto& pair : tensor_pairs) {
      const Tensor& fbn = pair.first;
      const Tensor& fbn_ex = pair.second;

      ASSERT_EQ(fbn.dtype(), fbn_ex.dtype());
      ASSERT_EQ(fbn.shape(), fbn_ex.shape());

      test::ExpectClose(fbn, fbn_ex, 1e-2);
    }
  }

  // Verifies that computing FusedBatchNormOp+{SideInput}+{Activation} is
  // identical to FusedBatchNormExOp[fused_ops={SideInput, Activation}].
  void VerifyFusedBatchNormEx(int batch, int height, int width, int channels,
                              TensorFormat data_format, bool is_training,
                              bool has_side_input,
                              const string& activation_mode) {
    const GraphRunner run_default =
        [&](const Tensor& y_backprop, const Tensor& input_data,
            const Tensor& scale_data, const Tensor& offset_data,
            const Tensor& mean_data, const Tensor& var_data,
            const Tensor& side_input_data, FusedBatchNormOutputs* fwd,
            FusedBatchNormGradOutputs* bwd) {
          this->RunFusedBatchNorm(y_backprop, input_data, scale_data,
                                  offset_data, mean_data, var_data,
                                  side_input_data, data_format, is_training,
                                  has_side_input, activation_mode, fwd, bwd);
        };

    const GraphRunner run_fused =
        [&](const Tensor& y_backprop, const Tensor& input_data,
            const Tensor& scale_data, const Tensor& offset_data,
            const Tensor& mean_data, const Tensor& var_data,
            const Tensor& side_input_data, FusedBatchNormOutputs* fwd,
            FusedBatchNormGradOutputs* bwd) {
          this->RunFusedBatchNormEx(y_backprop, input_data, scale_data,
                                    offset_data, mean_data, var_data,
                                    side_input_data, data_format, is_training,
                                    has_side_input, activation_mode, fwd, bwd);
        };

    VerifyTensorsNear(batch, height, width, channels, data_format, is_training,
                      run_default, run_fused);
  }
};

constexpr bool kInTraining = true;      // is_training == true
constexpr bool kInInference = false;    // is_training == false
constexpr bool kNoSideInput = false;    // has_side_input == false
constexpr bool kWithSideInput = true;   // has_side_input == true

// -------------------------------------------------------------------------- //
// FusedBatchNormEx[is_training=true].

#if defined(GOOGLE_CUDA) && (CUDNN_VERSION >= 7402)
template <typename T>
using FusedBatchNormExOpTrainingTest =
    FusedBatchNormExOpTestBase<T, float>;  // scale is always float

TYPED_TEST_SUITE_P(FusedBatchNormExOpTrainingTest);

TYPED_TEST_P(FusedBatchNormExOpTrainingTest, TrainingInNHWCTest) {
  this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInTraining,
                               kNoSideInput, "Identity");
}

TYPED_TEST_P(FusedBatchNormExOpTrainingTest, TrainingWithReluInNHWCTest) {
  this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInTraining,
                               kNoSideInput, "Relu");
}

TYPED_TEST_P(FusedBatchNormExOpTrainingTest,
             TrainingWithSideInputAndReluInNHWCTest) {
  this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInTraining,
                               kWithSideInput, "Relu");
}

REGISTER_TYPED_TEST_SUITE_P(FusedBatchNormExOpTrainingTest,  //
                            TrainingInNHWCTest,              //
                            TrainingWithReluInNHWCTest,      //
                            TrainingWithSideInputAndReluInNHWCTest);

using FusedBatchNormExTrainingDataTypes = ::testing::Types<Eigen::half>;
INSTANTIATE_TYPED_TEST_SUITE_P(Test, FusedBatchNormExOpTrainingTest,
                               FusedBatchNormExTrainingDataTypes);
#endif  // defined(GOOGLE_CUDA) && (CUDNN_VERSION >= 7402)

// -------------------------------------------------------------------------- //
// FusedBatchNormEx[is_training=false].

#if defined(GOOGLE_CUDA)
template <typename T>
using FusedBatchNormExOpInferenceTest =
    FusedBatchNormExOpTestBase<T, float>;  // scale is always float

TYPED_TEST_SUITE_P(FusedBatchNormExOpInferenceTest);

TYPED_TEST_P(FusedBatchNormExOpInferenceTest, InferenceInNHWCTest) {
  this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInInference,
                               kNoSideInput, "Identity");
}

TYPED_TEST_P(FusedBatchNormExOpInferenceTest, InferenceWithReluInNHWCTest) {
  this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInInference,
                               kNoSideInput, "Relu");
}

TYPED_TEST_P(FusedBatchNormExOpInferenceTest,
             InferenceWithSideInputAndReluInNHWCTest) {
  this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInInference,
                               kWithSideInput, "Relu");
}

REGISTER_TYPED_TEST_SUITE_P(FusedBatchNormExOpInferenceTest,  //
                            InferenceInNHWCTest,              //
                            InferenceWithReluInNHWCTest,      //
                            InferenceWithSideInputAndReluInNHWCTest);

using FusedBatchNormExInferenceDataTypes = ::testing::Types<Eigen::half, float>;
INSTANTIATE_TYPED_TEST_SUITE_P(Test, FusedBatchNormExOpInferenceTest,
                               FusedBatchNormExInferenceDataTypes);
#endif  // defined(GOOGLE_CUDA)

// -------------------------------------------------------------------------- //
// Performance benchmarks are below.                                          //
// -------------------------------------------------------------------------- //

using fp16 = Eigen::half;
using fp32 = float;
using SideInputAndActivation = std::pair<bool, string>;

SideInputAndActivation Identity() { return {false, "Identity"}; }
SideInputAndActivation Relu() { return {false, "Relu"}; }
SideInputAndActivation AddAndRelu() { return {true, "Relu"}; }

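// Builds a benchmark graph with a single `_FusedBatchNormEx` node; `fn`
// selects whether a side input is fed and which activation is fused.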
template <typename T>
static Graph* FusedBatchNormEx(int n, int h, int w, int c,
                               TensorFormat data_format, bool is_training,
                               std::function<SideInputAndActivation()> fn) {
  Graph* g = new Graph(OpRegistry::Global());

  DataType dtype = DataTypeToEnum<T>::value;
  Tensor x_t(dtype, data_format == FORMAT_NHWC ? TensorShape({n, h, w, c})
                                               : TensorShape({n, c, h, w}));
  x_t.flat<T>().setRandom();

  Tensor other_t(DT_FLOAT, TensorShape({c}));
  other_t.flat<float>().setRandom();

  Node* x = test::graph::Constant(g, x_t, "x");
  Node* other = test::graph::Constant(g, other_t, "other");
  Node* empty = test::graph::Constant(g, Tensor(DT_FLOAT, {0}), "empty");

  int num_side_inputs = 0;
  std::vector<NodeBuilder::NodeOut> side_inputs;

  SideInputAndActivation side_input_and_activation = fn();
  bool has_side_input = side_input_and_activation.first;
  string activation_mode = side_input_and_activation.second;

  if (has_side_input) {
    num_side_inputs = 1;
    side_inputs.push_back({x});
  }

  Node* fused_batch_norm;
  TF_CHECK_OK(NodeBuilder(g->NewName("fused_batch_norm"), "_FusedBatchNormEx")
                  .Input(x)
                  .Input(other)                        // scale
                  .Input(other)                        // offset
                  .Input(is_training ? empty : other)  // mean
                  .Input(is_training ? empty : other)  // variance
                  .Input(side_inputs)                  // side_input
                  .Attr("T", dtype)
                  .Attr("U", DT_FLOAT)
                  .Attr("epsilon", 0.001)
                  .Attr("data_format", ToString(data_format))
                  .Attr("activation_mode", activation_mode)
                  .Attr("num_side_inputs", num_side_inputs)
                  .Attr("is_training", is_training)
                  .Finalize(g, &fused_batch_norm));

  return g;
}

#define BM_CONCAT(a, b) a##_##b

#define BM_NAME(N, H, W, C, T, FORMAT, IS_TRAINING, A, DEVICE)          \
  BM_CONCAT(BM_FusedBatchNorm##_##DEVICE##_##T##_##N##_##H##_##W##_##C, \
            FORMAT##_##IS_TRAINING##_##A)

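// For example, BM_NAME(64, 14, 14, 256, fp16, NHWC, true, Relu, gpu) expands
// to BM_FusedBatchNorm_gpu_fp16_64_14_14_256_NHWC_true_Relu.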
#define BM_FusedBatchNorm(N, H, W, C, T, FORMAT, IS_TRAINING, ACTIVATION,    \
                          DEVICE)                                            \
  static void BM_NAME(N, H, W, C, T, FORMAT, IS_TRAINING, ACTIVATION,        \
                      DEVICE)(::testing::benchmark::State & state) {         \
    test::Benchmark(#DEVICE,                                                 \
                    FusedBatchNormEx<T>(N, H, W, C, FORMAT_##FORMAT,         \
                                        IS_TRAINING, {ACTIVATION}),          \
                    /*old_benchmark_api=*/false)                             \
        .Run(state);                                                         \
    state.SetItemsProcessed(state.iterations() * N * H * W * C);             \
  }                                                                          \
  BENCHMARK(BM_NAME(N, H, W, C, T, FORMAT, IS_TRAINING, ACTIVATION, DEVICE)) \
      ->UseRealTime();

#if defined(GOOGLE_CUDA) && (CUDNN_VERSION >= 7402)
BM_FusedBatchNorm(64, 14, 14, 256, fp16, NHWC, true, Identity, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp16, NHWC, true, Relu, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp16, NHWC, true, AddAndRelu, gpu);
#endif  // defined(GOOGLE_CUDA) && (CUDNN_VERSION >= 7402)

#if defined(GOOGLE_CUDA)
BM_FusedBatchNorm(64, 14, 14, 256, fp16, NHWC, false, Identity, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp16, NHWC, false, Relu, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp16, NHWC, false, AddAndRelu, gpu);

BM_FusedBatchNorm(64, 14, 14, 256, fp16, NCHW, false, Identity, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp16, NCHW, false, Relu, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp16, NCHW, false, AddAndRelu, gpu);

BM_FusedBatchNorm(64, 14, 14, 256, fp32, NHWC, false, Identity, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp32, NHWC, false, Relu, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp32, NHWC, false, AddAndRelu, gpu);

BM_FusedBatchNorm(64, 14, 14, 256, fp32, NCHW, false, Identity, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp32, NCHW, false, Relu, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp32, NCHW, false, AddAndRelu, gpu);
#endif  // defined(GOOGLE_CUDA)

}  // namespace tensorflow