/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/algorithm/container.h"
#include "absl/strings/match.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/nn_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
#include "tensorflow/core/public/session.h"

#if GOOGLE_CUDA
#include "third_party/gpus/cudnn/cudnn.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

template <typename T, typename U>
class FusedBatchNormExOpTestBase : public OpsTestBase {
 public:
  FusedBatchNormExOpTestBase() {
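    // Opt into cuDNN's CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode, which the
    // fused batch norm kernels can use on GPU when it is available.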
    setenv("TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT", "1", 1 /* replace */);
  }

 protected:
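  // Fetched outputs of the forward pass (FusedBatchNormV3 or
  // _FusedBatchNormEx).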
  struct FusedBatchNormOutputs {
    Tensor y;
    Tensor batch_mean;
    Tensor batch_variance;
    Tensor reserve_space_1;
    Tensor reserve_space_2;
    Tensor reserve_space_3;
  };

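  // Fetched outputs of the backward pass (FusedBatchNormGradV3 or
  // _FusedBatchNormGradEx).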
  struct FusedBatchNormGradOutputs {
    Tensor y_backprop;
    Tensor x_backprop;
    Tensor scale_backprop;
    Tensor offset_backprop;
    Tensor reserve_space_4;
    Tensor reserve_space_5;
  };

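  // Runs either the reference graph or the fused graph and collects the
  // forward and backward outputs for comparison.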
  using GraphRunner = std::function<void(
      const Tensor& y_backprop, const Tensor& input_data,
      const Tensor& scale_data, const Tensor& offset_data,
      const Tensor& mean_data, const Tensor& var_data,
      const Tensor& side_input_data, FusedBatchNormOutputs* forward,
      FusedBatchNormGradOutputs* backward)>;

  // Runs a TensorFlow graph defined by the root scope, and fetches the result
  // of the 'fetch' nodes into the outputs. The optional `add_nodes` parameter
  // allows defining nodes directly from a NodeDef for ops that are not
  // supported by the C++ API.
  // TODO(ezhulenev): RunAndFetch is also defined in the FusedConv2D and
  // FusedMatMul tests. Add a base class for all FusedABC kernels and remove
  // the code duplication.
  void RunAndFetch(const tensorflow::Scope& root,
                   const std::vector<string>& fetch,
                   std::vector<Tensor>* outputs, bool allow_gpu_device,
                   const std::vector<const NodeDef*> add_nodes = {}) {
    tensorflow::GraphDef graph;
    TF_ASSERT_OK(root.ToGraphDef(&graph));

    for (const NodeDef* add_node : add_nodes) {
      *graph.add_node() = *add_node;
    }

    // We really want to make sure that the graph is executed exactly as we
    // passed it to the session, so we disable various optimizations.
    tensorflow::SessionOptions session_options;

    // Disable common runtime constant folding.
    session_options.config.mutable_graph_options()
        ->mutable_optimizer_options()
        ->set_opt_level(OptimizerOptions::L0);

    // Disable Grappler optimizations for tests.
    tensorflow::RewriterConfig* cfg =
        session_options.config.mutable_graph_options()
            ->mutable_rewrite_options();
    cfg->set_constant_folding(tensorflow::RewriterConfig::OFF);
    cfg->set_layout_optimizer(tensorflow::RewriterConfig::OFF);
    cfg->set_remapping(tensorflow::RewriterConfig::OFF);

    std::unique_ptr<tensorflow::Session> session(
        tensorflow::NewSession(session_options));

    std::vector<DeviceAttributes> available_devices;
    TF_ASSERT_OK(session->ListDevices(&available_devices))
        << "Failed to get available session devices";

    // Check if session has an available GPU device.
    const bool has_gpu_device =
        absl::c_any_of(available_devices, [](const DeviceAttributes& device) {
          return device.device_type() == DEVICE_GPU;
        });

    // Some of the `FusedABC` ops are implemented only for CPU, and in this
    // test we don't want to compare GPU vs CPU numbers, so in that case we
    // place all nodes on the CPU.
    const bool place_all_on_gpu = allow_gpu_device && has_gpu_device;

    const string device = place_all_on_gpu ? "/device:GPU:0" : "/device:CPU:0";
    for (NodeDef& mutable_node : *graph.mutable_node()) {
      mutable_node.set_device(device);
    }

    TF_ASSERT_OK(session->Create(graph));
    TF_ASSERT_OK(session->Run({}, fetch, {}, outputs));
  }

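  // Runs the reference (non-fused) graph: FusedBatchNormV3 followed by an
  // optional side input addition and activation in the forward pass, and the
  // matching activation gradient plus FusedBatchNormGradV3 in the backward
  // pass.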
  void RunFusedBatchNorm(const Tensor& y_backprop_data,
                         const Tensor& input_data, const Tensor& scale_data,
                         const Tensor& offset_data, const Tensor& mean_data,
                         const Tensor& var_data, const Tensor& side_input_data,
                         const TensorFormat data_format, bool is_training,
                         bool has_side_input, const string& activation_mode,
                         FusedBatchNormOutputs* forward,
                         FusedBatchNormGradOutputs* backward,
                         float epsilon = 0.1f) {
    Scope root = tensorflow::Scope::NewRootScope();

    Output y_backprop = ops::Const(root.WithOpName("y_backprop"),
                                   Input::Initializer(y_backprop_data));
    Output input =
        ops::Const(root.WithOpName("input"), Input::Initializer(input_data));
    Output scale =
        ops::Const(root.WithOpName("scale"), Input::Initializer(scale_data));
    Output offset =
        ops::Const(root.WithOpName("offset"), Input::Initializer(offset_data));
    Output mean =
        ops::Const(root.WithOpName("mean"), Input::Initializer(mean_data));
    Output var =
        ops::Const(root.WithOpName("var"), Input::Initializer(var_data));
    Output side_input = ops::Const(root.WithOpName("side_input"),
                                   Input::Initializer(side_input_data));

    ops::FusedBatchNormV3 fwd = ops::FusedBatchNormV3(
        root.WithOpName("fused_batch_norm"), input, scale, offset, mean, var,
        ops::FusedBatchNormV3::IsTraining(is_training)
            .Epsilon(epsilon)
            .DataFormat(ToString(data_format)));

    Output with_side_input;
    if (has_side_input) {
      with_side_input =
          ops::Add(root.WithOpName("with_side_input"), fwd.y, side_input);
    } else {
      with_side_input =
          ops::Identity(root.WithOpName("with_side_input"), fwd.y);
    }

    Output activation;
    if (activation_mode == "Relu") {
      activation =
          ops::Relu(root.WithOpName("with_activation"), with_side_input);
    } else {
      activation =
          ops::Identity(root.WithOpName("with_activation"), with_side_input);
    }

    Output activation_grad;
    if (activation_mode == "Relu") {
      activation_grad = ops::internal::ReluGrad(
          root.WithOpName("activation_grad"), y_backprop, activation);
    } else {
      activation_grad =
          ops::Identity(root.WithOpName("activation_grad"), y_backprop);
    }

    ops::FusedBatchNormGradV3 bwd = ops::FusedBatchNormGradV3(
        root.WithOpName("fused_batch_norm_grad"), activation_grad, input,
        scale, fwd.reserve_space_1, fwd.reserve_space_2, fwd.reserve_space_3,
        ops::FusedBatchNormGradV3::IsTraining(is_training)
            .Epsilon(epsilon)
            .DataFormat(ToString(data_format)));

    std::vector<Tensor> out_tensors;
    RunAndFetch(
        root,
        {"with_activation:0", "fused_batch_norm:1", "fused_batch_norm:2",
         "fused_batch_norm:3", "fused_batch_norm:4", "fused_batch_norm:5",
         "fused_batch_norm_grad:0", "fused_batch_norm_grad:1",
         "fused_batch_norm_grad:2"},
        &out_tensors, /*allow_gpu_device=*/true);

    forward->y = out_tensors[0];
    forward->batch_mean = out_tensors[1];
    forward->batch_variance = out_tensors[2];
    forward->reserve_space_1 = out_tensors[3];
    forward->reserve_space_2 = out_tensors[4];
    forward->reserve_space_3 = out_tensors[5];

    backward->x_backprop = out_tensors[6];
    backward->scale_backprop = out_tensors[7];
    backward->offset_backprop = out_tensors[8];
  }

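  // Builds and runs the fused graph using the `_FusedBatchNormEx` op (and
  // `_FusedBatchNormGradEx` when training). These ops are normally created by
  // the Grappler remapper and have no C++ op wrappers, so they are added to
  // the graph as raw NodeDefs through `add_nodes`.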
  void RunFusedBatchNormEx(const Tensor& y_backprop_data,
                           const Tensor& input_data, const Tensor& scale_data,
                           const Tensor& offset_data, const Tensor& mean_data,
                           const Tensor& var_data,
                           const Tensor& side_input_data,
                           const TensorFormat data_format, bool is_training,
                           bool has_side_input, const string& activation_mode,
                           FusedBatchNormOutputs* forward,
                           FusedBatchNormGradOutputs* backward,
                           float epsilon = 0.1f) {
    Scope root = tensorflow::Scope::NewRootScope();

    DataType t_dtype = DataTypeToEnum<T>::v();
    DataType u_dtype = DataTypeToEnum<U>::v();

    Output y_backprop = ops::Const(root.WithOpName("y_backprop"),
                                   Input::Initializer(y_backprop_data));
    Output input =
        ops::Const(root.WithOpName("input"), Input::Initializer(input_data));
    Output scale =
        ops::Const(root.WithOpName("scale"), Input::Initializer(scale_data));
    Output offset =
        ops::Const(root.WithOpName("offset"), Input::Initializer(offset_data));
    Output mean =
        ops::Const(root.WithOpName("mean"), Input::Initializer(mean_data));
    Output var =
        ops::Const(root.WithOpName("var"), Input::Initializer(var_data));
    Output side_input = ops::Const(root.WithOpName("side_input"),
                                   Input::Initializer(side_input_data));
    Output empty =
        ops::Const(root.WithOpName("empty"),
                   Input::Initializer(Tensor(DataTypeToEnum<U>::value, {0})));

    int num_side_inputs = 0;
    std::vector<NodeDefBuilder::NodeOut> side_inputs;

    if (has_side_input) {
      num_side_inputs = 1;
      side_inputs.push_back({side_input.name(), 0, t_dtype});
    }

    NodeDef fused_batch_norm_ex;
    TF_EXPECT_OK(NodeDefBuilder("fused_batch_norm_ex", "_FusedBatchNormEx")
                     .Input({input.name(), 0, t_dtype})
                     .Input({scale.name(), 0, u_dtype})
                     .Input({offset.name(), 0, u_dtype})
                     .Input({mean.name(), 0, u_dtype})
                     .Input({var.name(), 0, u_dtype})
                     .Input(side_inputs)
                     .Attr("T", t_dtype)
                     .Attr("U", u_dtype)
                     .Attr("data_format", ToString(data_format))
                     .Attr("epsilon", epsilon)
                     .Attr("activation_mode", activation_mode)
                     .Attr("num_side_inputs", num_side_inputs)
                     .Attr("is_training", is_training)
                     .Finalize(&fused_batch_norm_ex));

    NodeDef fused_batch_norm_grad;
    NodeDef activation_grad;
    std::vector<Tensor> out_tensors;
    std::vector<const NodeDef*> add_nodes;
    if (is_training) {
      TF_EXPECT_OK(
          NodeDefBuilder("fused_batch_norm_grad", "_FusedBatchNormGradEx")
              .Input({y_backprop.name(), 0, t_dtype})
              .Input({input.name(), 0, t_dtype})
              .Input({scale.name(), 0, u_dtype})
              .Input({fused_batch_norm_ex.name(), 3, u_dtype})
              .Input({fused_batch_norm_ex.name(), 4, u_dtype})
              .Input({fused_batch_norm_ex.name(), 5, u_dtype})
              .Input({offset.name(), 0, u_dtype})
              .Input({fused_batch_norm_ex.name(), 0, t_dtype})
              .Attr("T", t_dtype)
              .Attr("U", u_dtype)
              .Attr("data_format", ToString(data_format))
              .Attr("epsilon", epsilon)
              .Attr("activation_mode", activation_mode)
              .Attr("num_side_inputs", num_side_inputs)
              .Attr("is_training", is_training)
              .Finalize(&fused_batch_norm_grad));
      add_nodes = {&fused_batch_norm_ex, &fused_batch_norm_grad};
    } else {
      if (activation_mode == "Relu") {
        TF_EXPECT_OK(NodeDefBuilder("activation_grad", "ReluGrad")
                         .Input({y_backprop.name(), 0, t_dtype})
                         .Input({fused_batch_norm_ex.name(), 0, t_dtype})
                         .Attr("T", t_dtype)
                         .Finalize(&activation_grad));
      } else {
        TF_EXPECT_OK(NodeDefBuilder("activation_grad", "Identity")
                         .Input({y_backprop.name(), 0, t_dtype})
                         .Attr("T", t_dtype)
                         .Finalize(&activation_grad));
      }
      TF_EXPECT_OK(
          NodeDefBuilder("fused_batch_norm_grad", "FusedBatchNormGradV3")
              .Input({activation_grad.name(), 0, t_dtype})
              .Input({input.name(), 0, t_dtype})
              .Input({scale.name(), 0, u_dtype})
              .Input({fused_batch_norm_ex.name(), 3, u_dtype})
              .Input({fused_batch_norm_ex.name(), 4, u_dtype})
              .Input({fused_batch_norm_ex.name(), 5, u_dtype})
              .Attr("T", t_dtype)
              .Attr("U", u_dtype)
              .Attr("data_format", ToString(data_format))
              .Attr("epsilon", epsilon)
              .Attr("is_training", is_training)
              .Finalize(&fused_batch_norm_grad));
      add_nodes = {&fused_batch_norm_ex, &activation_grad,
                   &fused_batch_norm_grad};
    }

    RunAndFetch(root,
                {"fused_batch_norm_ex:0", "fused_batch_norm_ex:1",
                 "fused_batch_norm_ex:2", "fused_batch_norm_ex:3",
                 "fused_batch_norm_ex:4", "fused_batch_norm_ex:5",
                 "fused_batch_norm_grad:0", "fused_batch_norm_grad:1",
                 "fused_batch_norm_grad:2"},
                &out_tensors,
                /*allow_gpu_device=*/true, add_nodes);

    forward->y = out_tensors[0];
    forward->batch_mean = out_tensors[1];
    forward->batch_variance = out_tensors[2];
    forward->reserve_space_1 = out_tensors[3];
    forward->reserve_space_2 = out_tensors[4];
    forward->reserve_space_3 = out_tensors[5];

    backward->x_backprop = out_tensors[6];
    backward->scale_backprop = out_tensors[7];
    backward->offset_backprop = out_tensors[8];
  }

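  // Runs the reference graph and the fused graph on identical random inputs
  // and checks that all fetched forward outputs and gradients are close.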
  void VerifyTensorsNear(int batch, int height, int width, int channels,
                         TensorFormat data_format, bool is_training,
                         const GraphRunner& run_default,
                         const GraphRunner& run_fused) {
    DataType t_dtype = DataTypeToEnum<T>::v();
    DataType u_dtype = DataTypeToEnum<U>::v();

    TensorShape input_shape =
        data_format == FORMAT_NHWC
            ? TensorShape({batch, height, width, channels})
            : TensorShape({batch, channels, height, width});

    Tensor input(t_dtype, input_shape);
    input.flat<T>().setRandom();
    input.flat<T>() -= input.flat<T>().constant(static_cast<T>(0.5));

    Tensor scale(u_dtype, {channels});
    scale.flat<U>().setRandom();

    Tensor offset(u_dtype, {channels});
    offset.flat<U>().setRandom();

    Tensor mean(u_dtype, {channels});
    mean.flat<U>().setRandom();

    Tensor var(u_dtype, {channels});
    var.flat<U>().setRandom();

    Tensor side_input(t_dtype, input_shape);
    side_input.flat<T>().setRandom();
    side_input.flat<T>() += side_input.flat<T>().constant(static_cast<T>(5.0));

    Tensor y_backprop(t_dtype, input_shape);
    y_backprop.flat<T>().setRandom();
    y_backprop.flat<T>() -= y_backprop.flat<T>().constant(static_cast<T>(0.5));

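    // In training mode the mean and variance inputs must be empty tensors;
    // the op computes the batch statistics itself.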
    Tensor empty(u_dtype, {0});

    FusedBatchNormOutputs fbn_forward;
    FusedBatchNormOutputs fbn_ex_forward;

    FusedBatchNormGradOutputs fbn_backward;
    FusedBatchNormGradOutputs fbn_ex_backward;

    run_default(y_backprop, input, scale, offset, is_training ? empty : mean,
                is_training ? empty : var, side_input, &fbn_forward,
                &fbn_backward);

    // Write some garbage to `fbn_ex_forward` and `fbn_ex_backward` first to
    // make sure that the fused kernel actually writes correct results to
    // memory.
    run_default(y_backprop, side_input, scale, offset,
                is_training ? empty : mean, is_training ? empty : var, input,
                &fbn_ex_forward, &fbn_ex_backward);

    run_fused(y_backprop, input, scale, offset, is_training ? empty : mean,
              is_training ? empty : var, side_input, &fbn_ex_forward,
              &fbn_ex_backward);

    std::vector<std::pair<Tensor, Tensor>> tensor_pairs;
    if (is_training) {
      tensor_pairs = {
          {fbn_forward.y, fbn_ex_forward.y},
          {fbn_forward.batch_mean, fbn_ex_forward.batch_mean},
          {fbn_forward.batch_variance, fbn_ex_forward.batch_variance},
          {fbn_forward.reserve_space_1, fbn_ex_forward.reserve_space_1},
          {fbn_forward.reserve_space_2, fbn_ex_forward.reserve_space_2},
          // NOTE(ezhulenev): We deliberately do not check `reserve_space_3`,
          // because FusedBatchNormEx with a fused side input has different
          // data in it, but we make sure that the final gradients are the
          // same.
          {fbn_backward.y_backprop, fbn_ex_backward.y_backprop},
          {fbn_backward.x_backprop, fbn_ex_backward.x_backprop},
          {fbn_backward.scale_backprop, fbn_ex_backward.scale_backprop},
          {fbn_backward.offset_backprop, fbn_ex_backward.offset_backprop},
      };
    } else {
      tensor_pairs = {{fbn_forward.y, fbn_ex_forward.y}};
    }

    for (auto& pair : tensor_pairs) {
      const Tensor& fbn = pair.first;
      const Tensor& fbn_ex = pair.second;

      ASSERT_EQ(fbn.dtype(), fbn_ex.dtype());
      ASSERT_EQ(fbn.shape(), fbn_ex.shape());

      test::ExpectClose(fbn, fbn_ex, 1e-2);
    }
  }

  // Verifies that computing FusedBatchNormOp+{SideInput}+{Activation} is
  // identical to FusedBatchNormExOp[fused_ops={SideInput, Activation}].
  void VerifyFusedBatchNormEx(int batch, int height, int width, int channels,
                              TensorFormat data_format, bool is_training,
                              bool has_side_input,
                              const string& activation_mode) {
    const GraphRunner run_default =
        [&](const Tensor& y_backprop, const Tensor& input_data,
            const Tensor& scale_data, const Tensor& offset_data,
            const Tensor& mean_data, const Tensor& var_data,
            const Tensor& side_input_data, FusedBatchNormOutputs* fwd,
            FusedBatchNormGradOutputs* bwd) {
          this->RunFusedBatchNorm(y_backprop, input_data, scale_data,
                                  offset_data, mean_data, var_data,
                                  side_input_data, data_format, is_training,
                                  has_side_input, activation_mode, fwd, bwd);
        };

    const GraphRunner run_fused =
        [&](const Tensor& y_backprop, const Tensor& input_data,
            const Tensor& scale_data, const Tensor& offset_data,
            const Tensor& mean_data, const Tensor& var_data,
            const Tensor& side_input_data, FusedBatchNormOutputs* fwd,
            FusedBatchNormGradOutputs* bwd) {
          this->RunFusedBatchNormEx(y_backprop, input_data, scale_data,
                                    offset_data, mean_data, var_data,
                                    side_input_data, data_format, is_training,
                                    has_side_input, activation_mode, fwd, bwd);
        };

    VerifyTensorsNear(batch, height, width, channels, data_format, is_training,
                      run_default, run_fused);
  }
};

constexpr bool kInTraining = true;     // is_training == true
constexpr bool kInInference = false;   // is_training == false
constexpr bool kNoSideInput = false;   // side_input == false
constexpr bool kWithSideInput = true;  // side_input == true

// -------------------------------------------------------------------------- //
// FusedBatchNormEx[is_training=true].

#if defined(GOOGLE_CUDA) && (CUDNN_VERSION >= 7402)
template <typename T>
using FusedBatchNormExOpTrainingTest =
    FusedBatchNormExOpTestBase<T, float>;  // scale is always float

TYPED_TEST_SUITE_P(FusedBatchNormExOpTrainingTest);

TYPED_TEST_P(FusedBatchNormExOpTrainingTest, TrainingInNHWCTest) {
  this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInTraining,
                               kNoSideInput, "Identity");
}

TYPED_TEST_P(FusedBatchNormExOpTrainingTest, TrainingWithReluInNHWCTest) {
  this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInTraining,
                               kNoSideInput, "Relu");
}

TYPED_TEST_P(FusedBatchNormExOpTrainingTest,
             TrainingWithSideInputAndReluInNHWCTest) {
  this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInTraining,
                               kWithSideInput, "Relu");
}

REGISTER_TYPED_TEST_SUITE_P(FusedBatchNormExOpTrainingTest,  //
                            TrainingInNHWCTest,              //
                            TrainingWithReluInNHWCTest,      //
                            TrainingWithSideInputAndReluInNHWCTest);

using FusedBatchNormExTrainingDataTypes = ::testing::Types<Eigen::half>;
INSTANTIATE_TYPED_TEST_SUITE_P(Test, FusedBatchNormExOpTrainingTest,
                               FusedBatchNormExTrainingDataTypes);
#endif  // defined(GOOGLE_CUDA) && (CUDNN_VERSION >= 7402)

// -------------------------------------------------------------------------- //
// FusedBatchNormEx[is_training=false].

#if defined(GOOGLE_CUDA)
template <typename T>
using FusedBatchNormExOpInferenceTest =
    FusedBatchNormExOpTestBase<T, float>;  // scale is always float

TYPED_TEST_SUITE_P(FusedBatchNormExOpInferenceTest);

TYPED_TEST_P(FusedBatchNormExOpInferenceTest, InferenceInNHWCTest) {
  this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInInference,
                               kNoSideInput, "Identity");
}

TYPED_TEST_P(FusedBatchNormExOpInferenceTest, InferenceWithReluInNHWCTest) {
  this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInInference,
                               kNoSideInput, "Relu");
}

TYPED_TEST_P(FusedBatchNormExOpInferenceTest,
             InferenceWithSideInputAndReluInNHWCTest) {
  this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInInference,
                               kWithSideInput, "Relu");
}

REGISTER_TYPED_TEST_SUITE_P(FusedBatchNormExOpInferenceTest,  //
                            InferenceInNHWCTest,              //
                            InferenceWithReluInNHWCTest,      //
                            InferenceWithSideInputAndReluInNHWCTest);

using FusedBatchNormExInferenceDataTypes = ::testing::Types<Eigen::half, float>;
INSTANTIATE_TYPED_TEST_SUITE_P(Test, FusedBatchNormExOpInferenceTest,
                               FusedBatchNormExInferenceDataTypes);
#endif  // defined(GOOGLE_CUDA)

// -------------------------------------------------------------------------- //
// Performance benchmarks are below.                                          //
// -------------------------------------------------------------------------- //

using fp16 = Eigen::half;
using fp32 = float;
using SideInputAndActivation = std::pair<bool, string>;

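// Helpers returning a (has_side_input, activation_mode) pair used to
// parameterize the benchmark graph builder below.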
SideInputAndActivation Identity() { return {false, "Identity"}; }
SideInputAndActivation Relu() { return {false, "Relu"}; }
SideInputAndActivation AddAndRelu() { return {true, "Relu"}; }

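// Builds a graph containing a single _FusedBatchNormEx node for benchmarking.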
template <typename T>
static Graph* FusedBatchNormEx(int n, int h, int w, int c,
                               TensorFormat data_format, bool is_training,
                               std::function<SideInputAndActivation()> fn) {
  Graph* g = new Graph(OpRegistry::Global());

  DataType dtype = DataTypeToEnum<T>::value;
  Tensor x_t(dtype, data_format == FORMAT_NHWC ? TensorShape({n, h, w, c})
                                               : TensorShape({n, c, h, w}));
  x_t.flat<T>().setRandom();

  Tensor other_t(DT_FLOAT, TensorShape({c}));
  other_t.flat<float>().setRandom();

  Node* x = test::graph::Constant(g, x_t, "x");
  Node* other = test::graph::Constant(g, other_t, "other");
  Node* empty = test::graph::Constant(g, Tensor(DT_FLOAT, {0}), "empty");

  int num_side_inputs = 0;
  std::vector<NodeBuilder::NodeOut> side_inputs;

  SideInputAndActivation side_input_and_activation = fn();
  bool has_side_input = side_input_and_activation.first;
  string activation_mode = side_input_and_activation.second;

  if (has_side_input) {
    num_side_inputs = 1;
    side_inputs.push_back({x});
  }

  Node* fused_batch_norm;
  TF_CHECK_OK(NodeBuilder(g->NewName("fused_batch_norm"), "_FusedBatchNormEx")
                  .Input(x)
                  .Input(other)                        // scale
                  .Input(other)                        // offset
                  .Input(is_training ? empty : other)  // mean
                  .Input(is_training ? empty : other)  // variance
                  .Input(side_inputs)                  // side_input
                  .Attr("T", dtype)
                  .Attr("U", DT_FLOAT)
                  .Attr("epsilon", 0.001)
                  .Attr("data_format", ToString(data_format))
                  .Attr("activation_mode", activation_mode)
                  .Attr("num_side_inputs", num_side_inputs)
                  .Attr("is_training", is_training)
                  .Finalize(g, &fused_batch_norm));

  return g;
}

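// Benchmarks are named
// BM_FusedBatchNorm_<device>_<type>_<N>_<H>_<W>_<C>_<format>_<is_training>_<activation>.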
#define BM_CONCAT(a, b) a##_##b

#define BM_NAME(N, H, W, C, T, FORMAT, IS_TRAINING, A, DEVICE)           \
  BM_CONCAT(BM_FusedBatchNorm##_##DEVICE##_##T##_##N##_##H##_##W##_##C,  \
            FORMAT##_##IS_TRAINING##_##A)

#define BM_FusedBatchNorm(N, H, W, C, T, FORMAT, IS_TRAINING, ACTIVATION,    \
                          DEVICE)                                            \
  static void BM_NAME(N, H, W, C, T, FORMAT, IS_TRAINING, ACTIVATION,        \
                      DEVICE)(::testing::benchmark::State & state) {         \
    test::Benchmark(#DEVICE,                                                  \
                    FusedBatchNormEx<T>(N, H, W, C, FORMAT_##FORMAT,         \
                                        IS_TRAINING, {ACTIVATION}),          \
                    /*old_benchmark_api=*/false)                             \
        .Run(state);                                                          \
    state.SetItemsProcessed(state.iterations() * N * H * W * C);             \
  }                                                                           \
  BENCHMARK(BM_NAME(N, H, W, C, T, FORMAT, IS_TRAINING, ACTIVATION, DEVICE)) \
      ->UseRealTime();

#if defined(GOOGLE_CUDA) && (CUDNN_VERSION >= 7402)
BM_FusedBatchNorm(64, 14, 14, 256, fp16, NHWC, true, Identity, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp16, NHWC, true, Relu, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp16, NHWC, true, AddAndRelu, gpu);
#endif  // defined(GOOGLE_CUDA) && (CUDNN_VERSION >= 7402)

#if defined(GOOGLE_CUDA)
BM_FusedBatchNorm(64, 14, 14, 256, fp16, NHWC, false, Identity, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp16, NHWC, false, Relu, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp16, NHWC, false, AddAndRelu, gpu);

BM_FusedBatchNorm(64, 14, 14, 256, fp16, NCHW, false, Identity, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp16, NCHW, false, Relu, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp16, NCHW, false, AddAndRelu, gpu);

BM_FusedBatchNorm(64, 14, 14, 256, fp32, NHWC, false, Identity, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp32, NHWC, false, Relu, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp32, NHWC, false, AddAndRelu, gpu);

BM_FusedBatchNorm(64, 14, 14, 256, fp32, NCHW, false, Identity, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp32, NCHW, false, Relu, gpu);
BM_FusedBatchNorm(64, 14, 14, 256, fp32, NCHW, false, AddAndRelu, gpu);
#endif  // defined(GOOGLE_CUDA)

}  // namespace tensorflow