/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements convolution operations with other kernels baked into the
// processing, to optimize latency and memory usage:
//  - Conv2D + BiasAdd + <Activation>
//  - Conv2D + FusedBatchNorm + <Activation>
//
// Activation: Relu, Relu6, Elu, etc...
//
// Kernels for convolutions fused with image transformations (resize and mirror
// padding) are defined in `conv_ops_fused_image_transform.cc`.
//
// For the CPU device we implement fusion with an Eigen tensor contraction
// output kernel. For the GPU device we rely on CuDNN primitives.
//
// NOTE: GPU only supports fusion of Conv2D + BiasAdd + <optional Relu>.
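//
// Conceptually (a rough sketch, not code from this file), a fused kernel
// computes in a single pass what would otherwise take three separate ops:
//
//   output = Activation(BiasAdd(Conv2D(input, filter), bias))
//
// so the bias and activation are applied to each output block while it is
// still in cache, without materializing the intermediate tensors.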

#ifndef TENSORFLOW_CORE_KERNELS_CONV_OPS_FUSED_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_CONV_OPS_FUSED_IMPL_H_

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/substitute.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/conv_ops.h"
#include "tensorflow/core/kernels/fused_eigen_output_kernels.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/profiler/lib/scoped_annotation.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"

#if GOOGLE_CUDA
#include "third_party/gpus/cudnn/cudnn.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/util/autotune_maps/conv_autotune_maps.h"
#include "tensorflow/core/util/autotune_maps/conv_parameters.h"
#include "tensorflow/core/util/proto/proto_utils.h"
#include "tensorflow/stream_executor/gpu/gpu_asm_opts.h"
#include "tensorflow/stream_executor/gpu/redzone_allocator.h"
#include "tensorflow/stream_executor/tf_allocator_adapter.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

class AutotuneResult;

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
struct LaunchFusedConv2DOp {
  void operator()(OpKernelContext* context, bool use_cudnn,
                  bool cudnn_use_autotune, const Tensor& input,
                  const Tensor& filter, FusedComputationType fusion,
                  const FusedComputationArgs& fusion_args,
                  const Conv2DParameters& params,
                  const Conv2DDimensions& dimensions, Tensor* output);
};

// This is a CPU-only implementation that uses Eigen contraction output
// kernels.
//
// Dispatch 2D convolution to the appropriate primitive operation:
//   (1) MatMul for the case of 1x1 convolution.
//   (2) MatMul for the case when the filter size equals the input size.
//   (3) General spatial 2D convolution for all other cases.
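//
// For example (illustrative shapes, not taken from this file): a 1x1
// convolution of an NHWC input [N, H, W, C] with a [1, 1, C, F] filter is
// exactly the matrix product of the input reshaped to [N * H * W, C] with
// the filter reshaped to [C, F], which is what case (1) below computes.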
template <typename T>
class LaunchFusedConv2DWithOutputKernel {
 public:
  LaunchFusedConv2DWithOutputKernel(int row_stride, int col_stride,      //
                                    int row_dilation, int col_dilation,  //
                                    Padding padding,
                                    const std::vector<int64>& explicit_paddings)
      : row_stride_(row_stride),
        col_stride_(col_stride),
        row_dilation_(row_dilation),
        col_dilation_(col_dilation),
        padding_(padding),
        explicit_paddings_(explicit_paddings) {}

  template <typename OutputKernel>
  void operator()(const OutputKernel& output_kernel, OpKernelContext* ctx,
                  const Tensor& input, const Tensor& filter, Tensor* output) {
    // Wrap output_kernel into type erased wrapper to reduce the number of
    // unique template instantiations for Eigen Tensor contraction expressions.
    OutputKernelWrapper output_kernel_wrapper(
        [&output_kernel](
            const ContractionOutputMapper<T, Eigen::Index>& output_mapper,
            const Eigen::TensorContractionParams& params, Eigen::Index i,
            Eigen::Index j, Eigen::Index num_rows, Eigen::Index num_cols) {
          output_kernel(output_mapper, params, i, j, num_rows, num_cols);
        });

    if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 &&
        row_stride_ == 1 && col_stride_ == 1 && padding_ != EXPLICIT) {
      int conv_width = 1;  // Width for the convolution step.
      for (int i = 0; i < 3; ++i) {
        conv_width *= output->dim_size(i);
      }

      Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
      dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
      functor::MatMulConvFunctor<CPUDevice, T, OutputKernelWrapper>()(
          ctx->eigen_device<CPUDevice>(),
          output->shaped<T, 2>({conv_width, filter.dim_size(3)}),
          input.shaped<T, 2>({conv_width, filter.dim_size(2)}),
          filter.shaped<T, 2>({filter.dim_size(2), filter.dim_size(3)}),
          dim_pair, std::move(output_kernel_wrapper));

    } else if (filter.dim_size(0) == input.dim_size(1) &&
               filter.dim_size(1) == input.dim_size(2) && row_dilation_ == 1 &&
               col_dilation_ == 1 && padding_ == VALID) {
      // If the input data and filter have the same height/width,
      // reduce the 2D convolution to matrix multiplication.
      const auto k =  // Length of reduction dimension.
          filter.dim_size(0) * filter.dim_size(1) * filter.dim_size(2);

      Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
      dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
      functor::MatMulConvFunctor<CPUDevice, T, OutputKernelWrapper>()(
          ctx->eigen_device<CPUDevice>(),
          output->shaped<T, 2>({input.dim_size(0), filter.dim_size(3)}),
          input.shaped<T, 2>({input.dim_size(0), k}),
          filter.shaped<T, 2>({k, filter.dim_size(3)}), dim_pair,
          std::move(output_kernel_wrapper));

    } else {
      if (padding_ == EXPLICIT) {
        functor::SpatialConvolution<CPUDevice, T, OutputKernelWrapper>()(
            ctx->eigen_device<CPUDevice>(), output->tensor<T, 4>(),
            input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride_,
            col_stride_, row_dilation_, col_dilation_,
            static_cast<int>(explicit_paddings_[2]),
            static_cast<int>(explicit_paddings_[3]),
            static_cast<int>(explicit_paddings_[4]),
            static_cast<int>(explicit_paddings_[5]),
            std::move(output_kernel_wrapper));
      } else {
        functor::SpatialConvolution<CPUDevice, T, OutputKernelWrapper>()(
            ctx->eigen_device<CPUDevice>(), output->tensor<T, 4>(),
            input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride_,
            col_stride_, row_dilation_, col_dilation_,
            BrainPadding2EigenPadding(padding_),
            std::move(output_kernel_wrapper));
      }
    }
  }

 private:
  // Wrap output_kernel into type erased struct to reduce the number of unique
  // template instantiations for Eigen Tensor contraction expressions.
  //
  // We do not pass std::function directly as an output kernel because it blows
  // up the binary size in debug mode with super long symbol names.
  struct OutputKernelWrapper {
    using OutputKernelFn =
        std::function<void(const ContractionOutputMapper<T, Eigen::Index>&,
                           const Eigen::TensorContractionParams&, Eigen::Index,
                           Eigen::Index, Eigen::Index, Eigen::Index)>;

    explicit OutputKernelWrapper(OutputKernelFn fn)
        : output_kernel_fn(std::move(fn)) {}

    void operator()(
        const ContractionOutputMapper<T, Eigen::Index>& output_mapper,
        const Eigen::TensorContractionParams& params, Eigen::Index i,
        Eigen::Index j, Eigen::Index num_rows, Eigen::Index num_cols) const {
      output_kernel_fn(output_mapper, params, i, j, num_rows, num_cols);
    }

    OutputKernelFn output_kernel_fn;
  };

  int row_stride_;
  int col_stride_;
  int row_dilation_;
  int col_dilation_;
  const Padding padding_;
  const std::vector<int64>& explicit_paddings_;
};

template <typename T>
struct LaunchFusedConv2DOp<CPUDevice, T> {
  void operator()(OpKernelContext* context, bool use_cudnn,
                  bool cudnn_use_autotune, const Tensor& input,
                  const Tensor& filter, const FusedComputationType fusion,
                  const FusedComputationArgs& fusion_args,
                  const Conv2DParameters& params,
                  const Conv2DDimensions& dimensions, Tensor* output) {
    OP_REQUIRES(context, dimensions.in_depth == filter.dim_size(2),
                errors::Unimplemented("Fused conv implementation does not "
                                      "support grouped convolutions for now."));
    OP_REQUIRES(context, params.data_format == FORMAT_NHWC,
                errors::Unimplemented("Fused conv implementation only supports "
                                      "NHWC tensor format for now."));

    BiasAddArgs<T> bias_add_args;
    if (BiasAddArgs<T>::IsSupported(fusion)) {
      if (fusion == FusedComputationType::kBiasAddWithLeakyRelu) {
        OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args,
                                                &fusion_args.leakyrelu_alpha));
      } else {
        OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args));
      }
    }

    FusedBatchNormArgs<T> fused_batch_norm_args;
    if (FusedBatchNormArgs<T>::IsSupported(fusion)) {
      if (fusion == FusedComputationType::kFusedBatchNormWithLeakyRelu) {
        OP_REQUIRES_OK(context,
                       InitFusedBatchNormArgs(context, fusion_args.epsilon,
                                              &fused_batch_norm_args,
                                              &fusion_args.leakyrelu_alpha));
      } else {
        OP_REQUIRES_OK(context,
                       InitFusedBatchNormArgs(context, fusion_args.epsilon,
                                              &fused_batch_norm_args));
      }
    }

    LaunchFusedConv2DWithOutputKernel<T> conv2d(
        dimensions.stride_rows, dimensions.stride_cols,
        dimensions.dilation_rows, dimensions.dilation_cols, params.padding,
        params.explicit_paddings);

    switch (fusion) {
      case FusedComputationType::kUndefined:
        OP_REQUIRES_OK(context, errors::Internal("Fusion type is undefined"));
        break;
      case FusedComputationType::kBiasAdd:
        conv2d(WithBiasAdd<T>(bias_add_args), context, input, filter, output);
        break;
      case FusedComputationType::kBiasAddWithRelu:
        conv2d(WithBiasAddAndRelu<T>(bias_add_args), context, input, filter,
               output);
        break;
      case FusedComputationType::kBiasAddWithRelu6:
        conv2d(WithBiasAddAndRelu6<T>(bias_add_args), context, input, filter,
               output);
        break;
      case FusedComputationType::kBiasAddWithLeakyRelu:
        conv2d(WithBiasAddAndLeakyRelu<T>(bias_add_args), context, input,
               filter, output);
        break;
      case FusedComputationType::kBiasAddWithElu:
        conv2d(WithBiasAddAndElu<T>(bias_add_args), context, input, filter,
               output);
        break;
      case FusedComputationType::kFusedBatchNorm:
        conv2d(
            WithFusedBatchNorm<T>(fusion_args.epsilon, fused_batch_norm_args),
            context, input, filter, output);
        break;
      case FusedComputationType::kFusedBatchNormWithRelu:
        conv2d(WithFusedBatchNormAndRelu<T>(fusion_args.epsilon,
                                            fused_batch_norm_args),
               context, input, filter, output);
        break;
      case FusedComputationType::kFusedBatchNormWithRelu6:
        conv2d(WithFusedBatchNormAndRelu6<T>(fusion_args.epsilon,
                                             fused_batch_norm_args),
               context, input, filter, output);
        break;
      case FusedComputationType::kFusedBatchNormWithLeakyRelu:
        conv2d(WithFusedBatchNormAndLeakyRelu<T>(fusion_args.epsilon,
                                                 fused_batch_norm_args),
               context, input, filter, output);
        break;
      case FusedComputationType::kFusedBatchNormWithElu:
        conv2d(WithFusedBatchNormAndElu<T>(fusion_args.epsilon,
                                           fused_batch_norm_args),
               context, input, filter, output);
        break;
    }
  }
};

#if GOOGLE_CUDA


inline int64 ConvolveScratchSize() {
  static int64_t convolve_scratch_size = GetDnnWorkspaceLimit(
      // default value is in bytes despite the name of the environment variable
      "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB
  );
  return convolve_scratch_size;
}

// Finds the best convolution algorithm for the given ConvLaunch (cuda
// convolution on the stream) and parameters, by running all possible
// algorithms and measuring execution time.
// TODO(ezhulenev): Move it to conv_ops_gpu.h and share with conv_ops.cc.
template <typename T, typename ConvLaunch, typename LogFunc>
Status FindBestConvolveAlgorithm(
    const ConvParameters& params, const se::dnn::BatchDescriptor& input_desc,
    const se::dnn::FilterDescriptor& filter_desc,
    const se::dnn::BatchDescriptor& bias_desc,
    const se::dnn::BatchDescriptor& output_desc,
    const se::dnn::ConvolutionDescriptor& conv_desc,
    const se::dnn::ActivationMode activation_mode, double conv_input_scale,
    double side_input_scale, const ConvLaunch launch, OpKernelContext* context,
    se::Stream* stream, se::DeviceMemory<T> output_ptr, const LogFunc& log,
    se::dnn::AlgorithmConfig* algorithm_config) {
  // Check if we already have an algorithm selected for the given parameters.
  if (AutotuneConv::GetInstance()->Find(params, algorithm_config)) {
    return Status::OK();
  }
  profiler::ScopedAnnotation trace("cudnn_autotuning");

  // Find all candidate algorithms or execution plans (for CuDNN frontend APIs).
  std::vector<std::unique_ptr<se::dnn::ConvolveExecutionPlan>> plans;
  std::vector<se::dnn::AlgorithmDesc> algorithms;
  std::vector<se::dnn::AlgorithmConfig> configs;
  if (CudnnUseFrontend()) {
    if (!stream->parent()
             ->GetFusedConvolveExecutionPlans(
                 se::dnn::ConvolutionKind::FORWARD,
                 se::dnn::ToDataType<T>::value, conv_input_scale,
                 side_input_scale, stream, input_desc, filter_desc, bias_desc,
                 output_desc, conv_desc, activation_mode, &plans)
             .ok()) {
      return errors::Unknown(
          "Failed to get convolution plans. This is probably because cuDNN "
          "failed to initialize, so try looking to see if a warning log "
          "message was printed above.");
    }
    for (const auto& plan : plans) {
      configs.push_back(se::dnn::AlgorithmConfig(
          se::dnn::AlgorithmDesc{plan->getTag(), plan->get_raw_desc()},
          plan->getWorkspaceSize()));
    }
  } else {
    if (!stream->parent()->GetConvolveAlgorithms(&algorithms)) {
      return errors::Unknown(
          "Failed to get convolution algorithm. This is probably because cuDNN "
          "failed to initialize, so try looking to see if a warning log "
          "message was printed above.");
    }
    for (const auto& algorithm : algorithms) {
      configs.push_back(se::dnn::AlgorithmConfig(algorithm));
    }
  }

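  // Candidate algorithms are profiled against a redzone allocator: it pads
  // each allocation with guard bytes so that out-of-bounds writes by a buggy
  // algorithm are caught by the CheckRedzones() calls below (see
  // redzone_allocator.h).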
  se::TfAllocatorAdapter tf_allocator_adapter(
      context->device()->GetAllocator({}), stream);
  se::RedzoneAllocator rz_allocator(stream, &tf_allocator_adapter,
                                    se::GpuAsmOpts());
  se::DeviceMemory<T> output_ptr_rz(
      WrapRedzoneBestEffort(&rz_allocator, output_ptr));

  std::vector<tensorflow::AutotuneResult> results;
  for (const auto& profile_config : configs) {
    DnnScratchAllocator scratch_allocator(ConvolveScratchSize(), context);
    se::RedzoneAllocator rz_scratch_allocator(
        stream, &tf_allocator_adapter, se::GpuAsmOpts(),
        /*memory_limit=*/ConvolveScratchSize());
    se::ScratchAllocator* allocator_used =
        !RedzoneCheckDisabled()
            ? static_cast<se::ScratchAllocator*>(&rz_scratch_allocator)
            : static_cast<se::ScratchAllocator*>(&scratch_allocator);
    se::dnn::ProfileResult profile_result;

    Status cudnn_launch_status =
        launch(profile_config, allocator_used, output_ptr_rz, &profile_result);

    if (cudnn_launch_status.ok() && profile_result.is_valid()) {
      results.emplace_back();
      auto& result = results.back();
      if (CudnnUseFrontend()) {
        result.mutable_cuda_conv_plan()->set_exec_plan_id(
            profile_config.algorithm()->exec_plan_id());
      } else {
        result.mutable_conv()->set_algorithm(
            profile_config.algorithm()->algo_id());
        result.mutable_conv()->set_tensor_ops_enabled(
            profile_config.algorithm()->tensor_ops_enabled());
      }
      result.set_scratch_bytes(
          !RedzoneCheckDisabled()
              ? rz_scratch_allocator.TotalAllocatedBytesExcludingRedzones()
              : scratch_allocator.TotalByteSize());
      *result.mutable_run_time() = proto_utils::ToDurationProto(
          absl::Milliseconds(profile_result.elapsed_time_in_ms()));
      CheckRedzones(rz_scratch_allocator, &result);
      CheckRedzones(rz_allocator, &result);
    } else if (CudnnUseFrontend()) {
      // When CuDNN frontend APIs are used, we need to make sure the profiling
      // results are in one-to-one mapping with the "plans". So, we insert
      // dummy results when the execution fails.
      results.emplace_back();
      auto& result = results.back();
      result.mutable_failure()->set_kind(AutotuneResult::UNKNOWN);
      result.mutable_failure()->set_msg(
          absl::StrCat("Profiling failure on CUDNN engine: ",
                       profile_config.algorithm()->exec_plan_id()));
    }
  }
  // Only log on an AutotuneConv cache miss.
  log(results);
  if (CudnnUseFrontend()) {
    TF_RETURN_IF_ERROR(
        BestCudnnConvAlgorithm(results, &plans, algorithm_config));
  } else {
    TF_RETURN_IF_ERROR(
        BestCudnnConvAlgorithm(results, nullptr, algorithm_config));
  }
  AutotuneConv::GetInstance()->Insert(params, *algorithm_config);
  return Status::OK();
}

template <typename T>
struct LaunchFusedConv2DOp<GPUDevice, T> {
  void operator()(OpKernelContext* context, bool use_cudnn,
                  bool cudnn_use_autotune, const Tensor& input_param,
                  const Tensor& filter, FusedComputationType fusion,
                  const FusedComputationArgs& fusion_args,
                  const Conv2DParameters& params,
                  const Conv2DDimensions& dimensions, Tensor* output) {
    OP_REQUIRES(
        context,
        params.data_format == FORMAT_NHWC || params.data_format == FORMAT_NCHW,
        errors::Unimplemented("Fused conv implementation only supports "
                              "NHWC and NCHW tensor formats for now."));

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
    OP_REQUIRES(
        context, use_cudnn,
        errors::Unimplemented("FusedConv2D for GPU is not currently supported "
                              "without cudnn"));

    OP_REQUIRES(
        context, fusion == FusedComputationType::kBiasAddWithRelu,
        errors::Unimplemented("FusedConv2D implementation only supports "
                              "fusing with `BiasAdd + Relu` for now."));

    Tensor input = input_param;

    const int64_t in_batch = GetTensorDim(input, params.data_format, 'N');
    int64_t in_rows = GetTensorDim(input, params.data_format, 'H');
    int64_t in_cols = GetTensorDim(input, params.data_format, 'W');
    const int64_t in_depths = GetTensorDim(input, params.data_format, 'C');

    const int64_t patch_rows = filter.dim_size(0);
    const int64_t patch_cols = filter.dim_size(1);
    const int64_t patch_depths = filter.dim_size(2);

    const int64_t out_batch = GetTensorDim(*output, params.data_format, 'N');
    const int64_t out_rows = GetTensorDim(*output, params.data_format, 'H');
    const int64_t out_cols = GetTensorDim(*output, params.data_format, 'W');
    const int64_t out_depths = GetTensorDim(*output, params.data_format, 'C');

    // Bias of the following dimensions: [ output_depth ]
    const Tensor& bias = context->input(2);
    OP_REQUIRES(context, bias.dims() == 1,
                errors::InvalidArgument("bias must be 1-dimensional",
                                        bias.shape().DebugString()));
    OP_REQUIRES(context, bias.dim_size(0) == out_depths,
                errors::InvalidArgument("bias depth must be equal to out depth",
                                        bias.shape().DebugString()));

    const int64_t common_padding_rows =
        std::min(dimensions.pad_rows_before, dimensions.pad_rows_after);
    const int64_t common_padding_cols =
        std::min(dimensions.pad_cols_before, dimensions.pad_cols_after);
    if (dimensions.pad_rows_before != dimensions.pad_rows_after ||
        dimensions.pad_cols_before != dimensions.pad_cols_after) {
      // cuDNN only supports padding the same amount on the left and right
      // sides, and on the top and bottom sides. So we manually create a new
      // padded input tensor such that we can pass it to cuDNN.

      // TODO(reedwm): In some cases, we can avoid an allocation even if the two
      // padding sides are different. For example, if the input is 2x2, the
      // filter is 1x1, the stride is 2, and the padding is (1, 0, 1, 0), the
      // result is equivalent to using a padding of (1, 1, 1, 1). Changing the
      // padding in such a way would allow us to avoid the allocation.
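      //
      // Concretely (an illustrative example, not from the original source):
      // with pad_rows_before = 1 and pad_rows_after = 2, common_padding_rows
      // is 1, so the input below is explicitly padded by 0 extra rows at the
      // top and 1 at the bottom, and cuDNN is then asked to pad symmetrically
      // by 1 on both sides.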
      Tensor transformed_input;
      const int64_t padding_rows_diff =
          std::abs(dimensions.pad_rows_after - dimensions.pad_rows_before);
      const int64_t padding_cols_diff =
          std::abs(dimensions.pad_cols_after - dimensions.pad_cols_before);
      const int64_t new_in_rows = in_rows + padding_rows_diff;
      const int64_t new_in_cols = in_cols + padding_cols_diff;
      OP_REQUIRES_OK(context,
                     context->allocate_temp(
                         DataTypeToEnum<T>::value,
                         ShapeFromFormat(params.data_format, in_batch,
                                         new_in_rows, new_in_cols, in_depths),
                         &transformed_input));
      const int64_t input_pad_top =
          dimensions.pad_rows_before - common_padding_rows;
      const int64_t input_pad_bottom =
          dimensions.pad_rows_after - common_padding_rows;
      const int64_t input_pad_left =
          dimensions.pad_cols_before - common_padding_cols;
      const int64_t input_pad_right =
          dimensions.pad_cols_after - common_padding_cols;
      bool in_bounds =
          FastBoundsCheck(input_pad_top, std::numeric_limits<int>::max()) &&
          FastBoundsCheck(input_pad_bottom, std::numeric_limits<int>::max()) &&
          FastBoundsCheck(input_pad_left, std::numeric_limits<int>::max()) &&
          FastBoundsCheck(input_pad_right, std::numeric_limits<int>::max());
      if (!in_bounds) {
        context->SetStatus(errors::InvalidArgument("Padding is too large."));
        return;
      }
      functor::PadInput<GPUDevice, T, int, 4>()(
          context->eigen_device<GPUDevice>(),
          To32Bit(input_param.tensor<T, 4>()),
          {{static_cast<int>(input_pad_top), static_cast<int>(input_pad_left)}},
          {{static_cast<int>(input_pad_bottom),
            static_cast<int>(input_pad_right)}},
          To32Bit(transformed_input.tensor<T, 4>()), params.data_format, T{});
      input = transformed_input;
      in_rows = new_in_rows;
      in_cols = new_in_cols;
    }

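    // The cuDNN descriptors below use the kBatchDepthYX (NCHW) layout, so
    // NHWC data is transposed to NCHW here and the result is transposed back
    // to NHWC after the convolution completes.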
    if (params.data_format == FORMAT_NHWC) {
      // Convert the input tensor from NHWC to NCHW.
      TensorShape nchw_shape =
          ShapeFromFormat(FORMAT_NCHW, in_batch, in_rows, in_cols, in_depths);
      if (in_depths > 1) {
        Tensor transformed_input;
        OP_REQUIRES_OK(context,
                       context->allocate_temp(DataTypeToEnum<T>::value,
                                              nchw_shape, &transformed_input));
        functor::NHWCToNCHW<GPUDevice, T, 4>()(
            context->eigen_device<GPUDevice>(),
            const_cast<const Tensor&>(input).tensor<T, 4>(),
            transformed_input.tensor<T, 4>());
        input = transformed_input;
      } else {
        // If depth <= 1, then just reshape.
        CHECK(input.CopyFrom(input, nchw_shape));  // Crash OK
      }
    }

    CHECK(common_padding_rows >= 0) << "Negative padding rows";  // Crash OK
    CHECK(common_padding_cols >= 0) << "Negative padding cols";  // Crash OK

    se::dnn::ActivationMode dnn_activation_mode;
    switch (fusion) {
      case FusedComputationType::kBiasAddWithRelu:
        dnn_activation_mode = se::dnn::ActivationMode::kRelu;
        break;
      default:
        LOG(FATAL) << "Unsupported fusion type";  // Crash OK
    }

    se::dnn::BatchDescriptor input_desc;
    input_desc.set_count(in_batch)
        .set_feature_map_count(in_depths)
        .set_height(in_rows)
        .set_width(in_cols)
        .set_layout(se::dnn::DataLayout::kBatchDepthYX);
    se::dnn::FilterDescriptor filter_desc;
    filter_desc.set_input_filter_height(patch_rows)
        .set_input_filter_width(patch_cols)
        .set_input_feature_map_count(patch_depths)
        .set_output_feature_map_count(filter.dim_size(3));
    se::dnn::BatchDescriptor bias_desc;
    bias_desc.set_count(1)
        .set_height(1)
        .set_width(1)
        .set_feature_map_count(out_depths)
        .set_layout(se::dnn::DataLayout::kBatchDepthYX);
    se::dnn::ConvolutionDescriptor conv_desc;
    conv_desc.set_vertical_dilation_rate(dimensions.dilation_rows)
        .set_horizontal_dilation_rate(dimensions.dilation_cols)
        .set_vertical_filter_stride(dimensions.stride_rows)
        .set_horizontal_filter_stride(dimensions.stride_cols)
        .set_zero_padding_height(common_padding_rows)
        .set_zero_padding_width(common_padding_cols)
        .set_group_count(in_depths / patch_depths);
    se::dnn::BatchDescriptor output_desc;
    output_desc.set_count(out_batch)
        .set_height(out_rows)
        .set_width(out_cols)
        .set_feature_map_count(out_depths)
        .set_layout(se::dnn::DataLayout::kBatchDepthYX);

    Tensor transformed_filter;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(
                       DataTypeToEnum<T>::value,
                       TensorShape({filter.dim_size(3), filter.dim_size(2),
                                    filter.dim_size(0), filter.dim_size(1)}),
                       &transformed_filter));
    functor::TransformFilter<GPUDevice, T, int, 4>()(
        context->eigen_device<GPUDevice>(), FORMAT_OIHW,
        To32Bit(filter.tensor<T, 4>()),
        To32Bit(transformed_filter.tensor<T, 4>()));

    Tensor transformed_output;
    if (params.data_format == FORMAT_NHWC) {
      // Only allocate temporary memory when a layout transformation is needed.
      OP_REQUIRES_OK(context,
                     context->allocate_temp(
                         DataTypeToEnum<T>::value,
                         ShapeFromFormat(FORMAT_NCHW, out_batch, out_rows,
                                         out_cols, out_depths),
                         &transformed_output));
    } else {
      transformed_output = *output;
    }

    const auto tensor_on_device = [](const Tensor& t) -> se::DeviceMemory<T> {
      return AsDeviceMemory(t.template flat<T>().data(),
                            t.template flat<T>().size());
    };

    se::DeviceMemory<T> input_ptr = tensor_on_device(input);
    se::DeviceMemory<T> filter_ptr = tensor_on_device(transformed_filter);
    se::DeviceMemory<T> bias_ptr = tensor_on_device(bias);
    se::DeviceMemory<T> output_ptr = tensor_on_device(transformed_output);

    // We do not use side inputs, so we can safely pass nullptr.
    se::DeviceMemory<T> side_input_ptr =
        AsDeviceMemory(static_cast<T*>(nullptr), 0);

    int device_id = stream->parent()->device_ordinal();
    DataType dtype = input.dtype();
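    // These parameters form the key under which the chosen algorithm is
    // cached in AutotuneConv (see FindBestConvolveAlgorithm above), so
    // identical convolutions on the same device skip re-autotuning.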
    ConvParameters conv_parameters = {
        in_batch,                      // batch
        in_depths,                     // in_depths
        {{in_rows,                     // in_rows
          in_cols}},                   // in_cols
        FORMAT_NCHW,                   // compute_data_format
        out_depths,                    // out_depths
        {{patch_rows,                  // filter_rows
          patch_cols,                  // filter_cols
          patch_depths}},              // filter_depths
        {{dimensions.dilation_rows,    // dilation_rows
          dimensions.dilation_cols}},  // dilation_cols
        {{dimensions.stride_rows,      // stride_rows
          dimensions.stride_cols}},    // stride_cols
        {{common_padding_rows,         // padding_rows
          common_padding_cols}},       // padding_cols
        dtype,                         // tensor datatype
        device_id,                     // device_id
        conv_desc.group_count(),
        ConvParameters::FusionInfo{
            /*has_side_input=*/false,  // this op doesn't support side inputs.
            dnn_activation_mode,       // activation_mode
            /*is_contrib=*/false}};

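    // Roughly, the fused cuDNN call computes
    //   act(conv_input_scale * conv(input) + side_input_scale * side_input +
    //       bias),
    // so with no side input the conv result is scaled by 1.0 and the (null)
    // side input by 0.0.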
    constexpr double kConvInputScale = 1.0;
    constexpr double kSideInputScale = 0.0;
    // Launch fused convolution with given parameters and scratch allocator.
    // Record profile result into `profile_result` if it's not nullptr.
    const auto launch = [&](se::dnn::AlgorithmConfig algorithm_config,
                            se::ScratchAllocator* scratch_allocator,
                            se::DeviceMemory<T> output_ptr_to_use,
                            se::dnn::ProfileResult* profile_result) -> Status {
      if (CudnnUseFrontend()) {
        return stream->FusedConvolveWithExecutionPlan(
            input_desc, input_ptr,            // input
            kConvInputScale,                  // input_scale
            filter_desc, filter_ptr,          // filter
            conv_desc,                        // conv
            side_input_ptr, kSideInputScale,  // side_input
            bias_desc, bias_ptr,              // bias
            dnn_activation_mode,              // activation
            output_desc, &output_ptr_to_use,  // output
            scratch_allocator, algorithm_config, profile_result);
      } else {
        return stream->FusedConvolveWithAlgorithm(
            input_desc, input_ptr,            // input
            kConvInputScale,                  // input_scale
            filter_desc, filter_ptr,          // filter
            conv_desc,                        // conv
            side_input_ptr, kSideInputScale,  // side_input
            bias_desc, bias_ptr,              // bias
            dnn_activation_mode,              // activation
            output_desc, &output_ptr_to_use,  // output
            scratch_allocator, algorithm_config, profile_result);
      }
    };

    se::dnn::AlgorithmConfig algorithm_config;
    if (cudnn_use_autotune) {
      auto status = FindBestConvolveAlgorithm<T>(
          conv_parameters, input_desc, filter_desc, bias_desc, output_desc,
          conv_desc, dnn_activation_mode, kConvInputScale, kSideInputScale,
          launch, context, stream, output_ptr,
          [&](absl::Span<const tensorflow::AutotuneResult> results) {
            LogFusedConvForwardAutotuneResults(
                se::dnn::ToDataType<T>::value, input_ptr, filter_ptr,
                output_ptr, bias_ptr, side_input_ptr, input_desc, filter_desc,
                output_desc, conv_desc, kConvInputScale, kSideInputScale,
                dnn_activation_mode, stream->parent(), results);
          },
          &algorithm_config);
      OP_REQUIRES_OK(context, status);
    }

    DnnScratchAllocator scratch_allocator(ConvolveScratchSize(), context);
    Status cudnn_launch_status = launch(algorithm_config, &scratch_allocator,
                                        output_ptr, /*profile_result=*/nullptr);
    OP_REQUIRES_OK(context, cudnn_launch_status);

    // Convert the output tensor back from NCHW to NHWC.
    if (params.data_format == FORMAT_NHWC) {
      functor::NCHWToNHWC<GPUDevice, T, 4>()(
          context->eigen_device<GPUDevice>(),
          const_cast<const Tensor&>(transformed_output).tensor<T, 4>(),
          output->tensor<T, 4>());
    }
  }
};

#endif  // GOOGLE_CUDA

template <typename Device, typename T>
class FusedConv2DOp : public OpKernel {
 public:
  explicit FusedConv2DOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, InitConv2DParameters(context, &params_));

    OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
    cudnn_use_autotune_ = CudnnUseAutotune();

    using FCT = FusedComputationType;

    std::vector<FusedComputationPattern> patterns;
    if (std::is_same<Device, CPUDevice>::value) {
      patterns = {
          {FCT::kBiasAdd, {"BiasAdd"}},
          {FCT::kBiasAddWithRelu, {"BiasAdd", "Relu"}},
          {FCT::kBiasAddWithRelu6, {"BiasAdd", "Relu6"}},
          {FCT::kBiasAddWithElu, {"BiasAdd", "Elu"}},
          {FCT::kBiasAddWithLeakyRelu, {"BiasAdd", "LeakyRelu"}},
          {FCT::kFusedBatchNorm, {"FusedBatchNorm"}},
          {FCT::kFusedBatchNormWithRelu, {"FusedBatchNorm", "Relu"}},
          {FCT::kFusedBatchNormWithRelu6, {"FusedBatchNorm", "Relu6"}},
          {FCT::kFusedBatchNormWithElu, {"FusedBatchNorm", "Elu"}},
          {FCT::kFusedBatchNormWithLeakyRelu, {"FusedBatchNorm", "LeakyRelu"}},
      };
    }

    // NOTE(ezhulenev): CuDNN `cudnnConvolutionBiasActivationForward` supports
    // the identity activation function, which in theory should allow fusing a
    // convolution with just BiasAdd, but in practice it doesn't work: cuDNN
    // ignores this parameter and always applies Relu activation.
    if (std::is_same<Device, GPUDevice>::value) {
      patterns = {{FCT::kBiasAddWithRelu, {"BiasAdd", "Relu"}}};
    }

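    // InitializeFusedComputation() matches the node's `fused_ops` attribute
    // (e.g. {"BiasAdd", "Relu"}) against these patterns to select the fusion
    // type and read its arguments.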
    OP_REQUIRES_OK(context, InitializeFusedComputation(
                                context, "Conv2D", patterns,
                                &fused_computation_, &fused_computation_args_));
  }

  void Compute(OpKernelContext* context) override {
    // Input tensor is of the following dimensions:
    // [ batch, in_rows, in_cols, in_depth ]
    const Tensor& input = context->input(0);

    // Input filter is of the following dimensions:
    // [ filter_rows, filter_cols, in_depth, out_depth ]
    const Tensor& filter = context->input(1);

    Conv2DDimensions dimensions;
    OP_REQUIRES_OK(context,
                   ComputeConv2DDimension(params_, input, filter, &dimensions));

    TensorShape out_shape = ShapeFromFormat(
        params_.data_format, dimensions.batch, dimensions.out_rows,
        dimensions.out_cols, dimensions.out_depth);

    // Output tensor is of the following dimensions:
    // [ in_batch, out_rows, out_cols, out_depth ]
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));

    VLOG(2) << "FusedConv2D: in_depth = " << dimensions.in_depth
            << ", patch_depth = " << dimensions.patch_depth
            << ", input_cols = " << dimensions.input_cols
            << ", filter_cols = " << dimensions.filter_cols
            << ", input_rows = " << dimensions.input_rows
            << ", filter_rows = " << dimensions.filter_rows
            << ", stride_rows = " << dimensions.stride_rows
            << ", stride_cols = " << dimensions.stride_cols
            << ", dilation_rows = " << dimensions.dilation_rows
            << ", dilation_cols = " << dimensions.dilation_cols
            << ", out_depth = " << dimensions.out_depth;

    // If there is nothing to compute, return.
    if (out_shape.num_elements() == 0) {
      return;
    }

    LaunchFusedConv2DOp<Device, T>()(context, use_cudnn_, cudnn_use_autotune_,
                                     input, filter, fused_computation_,
                                     fused_computation_args_, params_,
                                     dimensions, output);
  }

 private:
  Conv2DParameters params_;
  bool use_cudnn_;
  bool cudnn_use_autotune_;

  FusedComputationType fused_computation_ = FusedComputationType::kUndefined;
  FusedComputationArgs fused_computation_args_;

  TF_DISALLOW_COPY_AND_ASSIGN(FusedConv2DOp);
};

// Registration of the CPU implementations.
#define REGISTER_FUSED_CPU_CONV2D(T)                                  \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("_FusedConv2D").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      FusedConv2DOp<CPUDevice, T>);
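
// Note (an assumption, not stated in this file): these registration macros are
// expected to be invoked from per-type translation units (e.g. something like
// conv_ops_fused_float.cc calling REGISTER_FUSED_CPU_CONV2D(float)), which is
// why no kernels are registered directly in this header.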

#if GOOGLE_CUDA

#define DECLARE_FUNCTOR_GPU_SPEC(T)                                     \
  template <>                                                           \
  void TransformFilter<GPUDevice, T, int, 4>::operator()(               \
      const GPUDevice& d, FilterTensorFormat dst_filter_format,         \
      typename TTypes<T, 4, int>::ConstTensor in,                       \
      typename TTypes<T, 4, int>::Tensor out);                          \
  extern template struct TransformFilter<GPUDevice, T, int, 4>;         \
  template <>                                                           \
  void PadInput<GPUDevice, T, int, 4>::operator()(                      \
      const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,   \
      const std::array<int, 2>& padding_left,                           \
      const std::array<int, 2>& padding_right,                          \
      typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format, \
      const T& padding_value);                                          \
  extern template struct PadInput<GPUDevice, T, int, 4>

// Registration of the GPU implementations.
#define REGISTER_FUSED_GPU_CONV2D(T)                                  \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("_FusedConv2D").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      FusedConv2DOp<GPUDevice, T>);

#endif  // GOOGLE_CUDA

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_CONV_OPS_FUSED_IMPL_H_