/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements convolution operations with other kernels baked into the
// processing, to optimize latency and memory usage:
//  - Conv2D + BiasAdd + <Activation>
//  - Conv2D + FusedBatchNorm + <Activation>
//
// Activation: Relu, Relu6, Elu, etc...
//
// Kernels for convolutions fused with image transformations (resize and mirror
// padding) are defined in `conv_ops_fused_image_transform.cc`.
//
// For the CPU device we implement fusion with an Eigen tensor contraction
// output kernel. For the GPU device we rely on cuDNN primitives.
//
// NOTE: GPU only supports fusion of Conv2D + BiasAdd + <optional Relu>.
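//
// For illustration only (a sketch, not code in this file): a remapped node
// with fused_ops = ["BiasAdd", "Relu"] is expected to compute, in one pass
// over the output,
//
//   Relu(BiasAdd(Conv2D(input, filter), bias))
//
// instead of materializing the intermediate Conv2D and BiasAdd tensors.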

#ifndef TENSORFLOW_CORE_KERNELS_CONV_OPS_FUSED_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_CONV_OPS_FUSED_IMPL_H_

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/substitute.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/conv_ops.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"

#if GOOGLE_CUDA
#include "cuda/include/cudnn.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/util/proto/proto_utils.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

class AutotuneResult;

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Supported Conv2D fusions. Not all of them are supported on all types of
// devices.
enum class FusedComputationType {
  // NOTE(ezhulenev): cuDNN `cudnnConvolutionBiasActivationForward` supports an
  // identity activation function, which in theory should allow fusing a plain
  // convolution with BiasAdd. In practice this does not work: cuDNN ignores
  // the parameter and always applies Relu.
  kBiasAdd,                // CPU
  kBiasAddWithRelu,        // CPU and GPU
  kFusedBatchNorm,         // CPU only
  kFusedBatchNormWithRelu  // CPU only
};

// We have to pass around additional arguments for all possible fusion types.
struct FusedComputationArgs {
  float epsilon = 0.0;  // Used by `FusedBatchNorm` fusion only
};

template <typename Device, typename T>
struct LaunchFusedConv2DOp {
  void operator()(OpKernelContext* context, bool use_cudnn,
                  bool cudnn_use_autotune, const Tensor& input,
                  const Tensor& filter, FusedComputationType fusion,
                  const FusedComputationArgs& fusion_args,
                  const Conv2DParameters& params,
                  const Conv2DDimensions& dimensions, Tensor* output);
};

// Type aliases for the unaligned tensors (tensor maps) used in output kernels.
template <typename T>
struct Unaligned {
  // There is no guarantee that the output block passed to the output kernel
  // will be aligned.

  using Tensor =
      Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, Eigen::DenseIndex>,
                       Eigen::Unaligned>;

  using ConstTensor = Eigen::TensorMap<
      Eigen::Tensor<const T, 1, Eigen::RowMajor, Eigen::DenseIndex>,
      Eigen::Unaligned>;
};

// Type alias for the tensor contraction output mapper.
template <typename Scalar, typename Index>
using ContractionOutputMapper =
    Eigen::internal::blas_data_mapper<Scalar, Index, Eigen::ColMajor>;

// Returns the input expression without any transformations.
struct Identity {
  template <typename XprType>
  static auto apply(XprType expr) -> XprType {
    return expr;
  };
};

// Applies `Relu` to the passed input expression.
struct Relu {
  template <typename XprType>
  static auto apply(XprType expr)
      -> decltype(expr.cwiseMax(std::declval<typename XprType::Scalar>())) {
    return expr.cwiseMax(static_cast<typename XprType::Scalar>(0));
  };
};
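
// A minimal sketch of how an additional activation could follow the same
// pattern, assuming Eigen's `cwiseMax`/`cwiseMin` tensor expressions; `Relu6`
// here is only an illustration and is not used elsewhere in this file:
//
//   struct Relu6 {
//     template <typename XprType>
//     static auto apply(XprType expr)
//         -> decltype(expr.cwiseMax(std::declval<typename XprType::Scalar>())
//                         .cwiseMin(std::declval<typename XprType::Scalar>())) {
//       using Scalar = typename XprType::Scalar;
//       return expr.cwiseMax(static_cast<Scalar>(0))
//           .cwiseMin(static_cast<Scalar>(6));
//     };
//   };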

// TensorContraction swaps lhs with rhs, changes the layout from RowMajor
// (the default in TensorFlow) to ColMajor (preferred in Eigen), and computes
// the matmul using these tensors.
//
// The TensorContraction output matrix (before reshape) has a ColMajor layout,
// and has dimensions:
//  - rows: output_channels
//  - cols: all other dimensions
//
// The first element in every column is:
//   [batch ??, height ??, width ??, out_channel = i]
//
// We do not know the values of 'batch', 'height', and 'width' here (if the
// original dimensions were known, they could be computed from 'j').
//
// Each column of an output block is a contiguous slice along the output
// channel dimension, so we can use it to efficiently compute any
// transformation that depends only on a channel value (e.g. add a channel
// bias).
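//
// As an illustrative example: for a convolution with 64 output channels, a
// column of an output block that spans all rows holds the 64 channel values
// of a single output pixel (one [batch, height, width] location), so adding a
// length-64 bias vector to every column is exactly a per-channel BiasAdd. In
// general a block covers only rows [i, i + num_rows), i.e. a contiguous
// sub-range of the channels, which is why the kernels below offset their
// per-channel arguments by `i`.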

// Output kernel that fuses BiasAdd operation into the output of tensor
// contraction + activation function defined by Activation.
template <typename T, typename Activation = Identity>
struct BiasAddOutputKernel {
  explicit BiasAddOutputKernel(const T* bias_data) : bias_data(bias_data) {}

  template <typename Index, typename Scalar>
  EIGEN_ALWAYS_INLINE void operator()(
      const ContractionOutputMapper<Scalar, Index>& output_mapper,
      const Eigen::TensorContractionParams& params, Index i, Index j,
      Index num_rows, Index num_cols) const {
    DCHECK(params.swapped_arguments);

    const T* bias_base = bias_data + i;
    typename Unaligned<T>::ConstTensor bias(bias_base, num_rows);

    for (int col = 0; col < num_cols; ++col) {
      T* output_base = &output_mapper(0, col);
      typename Unaligned<T>::Tensor output(output_base, num_rows);
      const auto expr = output + bias;
      output = Activation::template apply<decltype(expr)>(expr);
    }
  }

 private:
  const T* bias_data;
};

// Output kernel that fuses FusedBatchNorm operation into the output of tensor
// contraction + activation function defined by Activation.
template <typename T, typename Activation = Identity>
struct FusedBatchNormOutputKernel {
  FusedBatchNormOutputKernel(T epsilon, const T* scaling_factor_data,
                             const T* offset_data, const T* estimated_mean_data)
      : epsilon(epsilon),
        scaling_factor_data(scaling_factor_data),
        offset_data(offset_data),
        estimated_mean_data(estimated_mean_data) {}

  template <typename Index, typename Scalar>
  EIGEN_ALWAYS_INLINE void operator()(
      const ContractionOutputMapper<Scalar, Index>& output_mapper,
      const Eigen::TensorContractionParams& params, Index i, Index j,
      Index num_rows, Index num_cols) const {
    DCHECK(params.swapped_arguments);

    const T* scaling_factor_base = scaling_factor_data + i;
    const T* offset_base = offset_data + i;
    const T* mean_base = estimated_mean_data + i;

    typename Unaligned<T>::ConstTensor scaling_factor(scaling_factor_base,
                                                      num_rows);
    typename Unaligned<T>::ConstTensor offset(offset_base, num_rows);
    typename Unaligned<T>::ConstTensor mean(mean_base, num_rows);

    for (int col = 0; col < num_cols; ++col) {
      T* output_base = &output_mapper(0, col);
      typename Unaligned<T>::Tensor output(output_base, num_rows);

      auto scaled = (output - mean) * scaling_factor;
      auto shifted = scaled + offset;

      output = Activation::template apply<decltype(shifted)>(shifted);
    }
  }

 private:
  T epsilon;
  const T* scaling_factor_data;
  const T* offset_data;
  const T* estimated_mean_data;
};
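
// For reference: with `scaling_factor` precomputed as
//   scaling_factor = scale * (estimated_variance + epsilon).rsqrt()
// (see InitFusedBatchNormArgs below), the per-column computation above is the
// standard batch normalization inference formula followed by the activation:
//   output = Activation((conv - mean) * scale / sqrt(variance + epsilon) + offset)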

// Type aliases for the output kernels, purely for the sake of better launch
// dispatching code readability.
template <typename T>
using WithBiasAdd = BiasAddOutputKernel<T>;
template <typename T>
using WithBiasAddAndRelu = BiasAddOutputKernel<T, Relu>;
template <typename T>
using WithFusedBatchNorm = FusedBatchNormOutputKernel<T>;
template <typename T>
using WithFusedBatchNormAndRelu = FusedBatchNormOutputKernel<T, Relu>;

// This is a CPU-only implementation that uses Eigen contraction output
// kernels.
//
// Dispatches the 2D convolution to the appropriate primitive operation:
//   (1) MatMul for the case of 1x1 convolution.
//   (2) MatMul for the case when the filter size equals the input size.
//   (3) General spatial 2D convolution for all other cases.
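//
// For example (shapes for illustration only): with an NHWC input of shape
// [N, H, W, C_in], a 1x1 filter of shape [1, 1, C_in, C_out] and strides of 1,
// case (1) is a single [N * H * W, C_in] x [C_in, C_out] matrix
// multiplication, which is what the first branch below sets up through
// `output->shaped<T, 2>(...)` and `input.shaped<T, 2>(...)`.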
template <typename T>
class LaunchFusedConv2DWithOutputKernel {
 public:
  LaunchFusedConv2DWithOutputKernel(int row_stride, int col_stride,      //
                                    int row_dilation, int col_dilation,  //
                                    Padding padding)
      : row_stride_(row_stride),
        col_stride_(col_stride),
        row_dilation_(row_dilation),
        col_dilation_(col_dilation),
        padding_(padding) {}

  template <typename OutputKernel>
  void operator()(const OutputKernel& output_kernel, OpKernelContext* ctx,
                  const Tensor& input, const Tensor& filter, Tensor* output) {
    if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 &&
        row_stride_ == 1 && col_stride_ == 1) {
      int conv_width = 1;  // Width for the convolution step.
      for (int i = 0; i < 3; ++i) {
        conv_width *= output->dim_size(i);
      }

      Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
      dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
      functor::MatMulConvFunctor<CPUDevice, T, OutputKernel>()(
          ctx->eigen_device<CPUDevice>(),
          output->shaped<T, 2>({conv_width, filter.dim_size(3)}),
          input.shaped<T, 2>({conv_width, filter.dim_size(2)}),
          filter.shaped<T, 2>({filter.dim_size(2), filter.dim_size(3)}),
          dim_pair, output_kernel);

    } else if (filter.dim_size(0) == input.dim_size(1) &&
               filter.dim_size(1) == input.dim_size(2) && row_dilation_ == 1 &&
               col_dilation_ == 1 && padding_ == VALID) {
      // If the input data and filter have the same height/width,
      // reduce the 2D convolution to matrix multiplication.
      const auto k =  // Length of reduction dimension.
          filter.dim_size(0) * filter.dim_size(1) * filter.dim_size(2);

      Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
      dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
      functor::MatMulConvFunctor<CPUDevice, T, OutputKernel>()(
          ctx->eigen_device<CPUDevice>(),
          output->shaped<T, 2>({input.dim_size(0), filter.dim_size(3)}),
          input.shaped<T, 2>({input.dim_size(0), k}),
          filter.shaped<T, 2>({k, filter.dim_size(3)}), dim_pair,
          output_kernel);

    } else {
      functor::SpatialConvolution<CPUDevice, T, OutputKernel>()(
          ctx->eigen_device<CPUDevice>(), output->tensor<T, 4>(),
          input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride_, col_stride_,
          row_dilation_, col_dilation_, BrainPadding2EigenPadding(padding_),
          output_kernel);
    }
  }

 private:
  int row_stride_;
  int col_stride_;
  int row_dilation_;
  int col_dilation_;
  const Padding padding_;
};
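
// A minimal usage sketch, mirroring the CPU dispatching code below (the
// variable names here are placeholders, not definitions from this file):
//
//   LaunchFusedConv2DWithOutputKernel<float> conv2d(
//       stride_rows, stride_cols, dilation_rows, dilation_cols, padding);
//   conv2d(WithBiasAddAndRelu<float>(bias_ptr), context, input, filter,
//          output);
//
// where `bias_ptr` points to a bias vector of length `out_channels`.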

template <typename T>
struct LaunchFusedConv2DOp<CPUDevice, T> {
  void operator()(OpKernelContext* context, bool use_cudnn,
                  bool cudnn_use_autotune, const Tensor& input,
                  const Tensor& filter, const FusedComputationType fusion,
                  const FusedComputationArgs& fusion_args,
                  const Conv2DParameters& params,
                  const Conv2DDimensions& dimensions, Tensor* output) {
    OP_REQUIRES(context, dimensions.in_depth == filter.dim_size(2),
                errors::Unimplemented("Fused conv implementation does not "
                                      "support grouped convolutions for now."));
    OP_REQUIRES(context, params.data_format == FORMAT_NHWC,
                errors::Unimplemented("Fused conv implementation only supports "
                                      "NHWC tensor format for now."));

    BiasAddArgs bias_add;
    FusedBatchNormArgs fused_batch_norm;

    LaunchFusedConv2DWithOutputKernel<T> conv2d(
        dimensions.stride_rows, dimensions.stride_cols,
        dimensions.dilation_rows, dimensions.dilation_cols, params.padding);

    switch (fusion) {
      case FusedComputationType::kBiasAdd:
        OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add));
        conv2d(WithBiasAdd<T>(bias_add.bias_add_data), context, input, filter,
               output);
        break;

      case FusedComputationType::kBiasAddWithRelu:
        OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add));
        conv2d(WithBiasAddAndRelu<T>(bias_add.bias_add_data), context, input,
               filter, output);
        break;

      case FusedComputationType::kFusedBatchNorm:
        OP_REQUIRES_OK(context,
                       InitFusedBatchNormArgs(context, fusion_args.epsilon,
                                              &fused_batch_norm));
        conv2d(WithFusedBatchNorm<T>(fusion_args.epsilon,
                                     fused_batch_norm.scaling_factor.data(),
                                     fused_batch_norm.offset_data,
                                     fused_batch_norm.estimated_mean_data),
               context, input, filter, output);
        break;

      case FusedComputationType::kFusedBatchNormWithRelu:
        OP_REQUIRES_OK(context,
                       InitFusedBatchNormArgs(context, fusion_args.epsilon,
                                              &fused_batch_norm));
        conv2d(WithFusedBatchNormAndRelu<T>(
                   fusion_args.epsilon, fused_batch_norm.scaling_factor.data(),
                   fused_batch_norm.offset_data,
                   fused_batch_norm.estimated_mean_data),
               context, input, filter, output);
        break;
    }
  }

 private:
  struct BiasAddArgs {
    const T* bias_add_data = nullptr;
  };

  struct FusedBatchNormArgs {
    const T* scale_data = nullptr;
    const T* offset_data = nullptr;
    const T* estimated_mean_data = nullptr;
    const T* estimated_variance_data = nullptr;

    // Precomputed expression:
    //   scaling_factor = (estimated_variance + epsilon).rsqrt() * scale
    Eigen::Tensor<T, 1, Eigen::RowMajor> scaling_factor;
  };

#define TF_REQUIRES(EXP, STATUS) \
  if (!TF_PREDICT_TRUE(EXP)) return (STATUS)

  void InitDataPtr(const Tensor& tensor, const T** ptr) const {
    *ptr = reinterpret_cast<const T*>(tensor.tensor_data().data());
  }

  Status InitBiasAddArgs(OpKernelContext* context, BiasAddArgs* args) const {
    // Bias of the following dimensions: [ output_depth ]
    const Tensor& bias = context->input(2);

    TF_REQUIRES(bias.dims() == 1,
                errors::InvalidArgument("bias must be 1-dimensional",
                                        bias.shape().DebugString()));

    InitDataPtr(bias, &args->bias_add_data);

    return Status::OK();
  }

  Status InitFusedBatchNormArgs(OpKernelContext* context, float epsilon,
                                FusedBatchNormArgs* args) const {
    const Tensor& scale = context->input(2);
    const Tensor& offset = context->input(3);
    const Tensor& estimated_mean = context->input(4);
    const Tensor& estimated_variance = context->input(5);

    TF_REQUIRES(scale.dims() == 1,
                errors::InvalidArgument("scale must be 1-dimensional",
                                        scale.shape().DebugString()));
    TF_REQUIRES(offset.dims() == 1,
                errors::InvalidArgument("offset must be 1-dimensional",
                                        offset.shape().DebugString()));
    TF_REQUIRES(estimated_mean.dims() == 1,
                errors::InvalidArgument("estimated_mean must be 1-dimensional",
                                        estimated_mean.shape().DebugString()));
    TF_REQUIRES(
        estimated_variance.dims() == 1,
        errors::InvalidArgument("estimated_variance must be 1-dimensional",
                                estimated_variance.shape().DebugString()));

    InitDataPtr(scale, &args->scale_data);
    InitDataPtr(offset, &args->offset_data);
    InitDataPtr(estimated_mean, &args->estimated_mean_data);
    InitDataPtr(estimated_variance, &args->estimated_variance_data);

    // Precompute scaling factor once for all output blocks (kernels).
    args->scaling_factor =
        (estimated_variance.flat<T>() + static_cast<T>(epsilon)).rsqrt() *
        scale.flat<T>();

    return Status::OK();
  }

#undef TF_REQUIRES
};

#if GOOGLE_CUDA

// Encapsulates the default shape information that is used by the convolution
// operation, and adds an activation mode for the fusion.
class FusedConvParameters : public ConvParameters {
 public:
  FusedConvParameters(const ConvParameters& base,
                      const se::dnn::ActivationMode activation_mode)
      : ConvParameters(base), activation_mode_(activation_mode) {}

  string ToString() const {
    return absl::StrCat(ConvParameters::ToString(), ", ", activation_mode_);
  }

 private:
  friend bool operator==(const FusedConvParameters& lhs,
                         const FusedConvParameters& rhs);

  using ParameterDataType =
      std::tuple<ConvParameters::ParameterDataType, se::dnn::ActivationMode>;

  ParameterDataType get_data_as_tuple() const {
    return std::make_tuple(ConvParameters::get_data_as_tuple(),
                           activation_mode_);
  }

  se::dnn::ActivationMode activation_mode_;
};

inline bool operator==(const FusedConvParameters& lhs,
                       const FusedConvParameters& rhs) {
  return lhs.get_data_as_tuple() == rhs.get_data_as_tuple();
}

inline bool operator!=(const FusedConvParameters& lhs,
                       const FusedConvParameters& rhs) {
  return !(lhs == rhs);
}

// A dummy type to group forward convolution autotune results together.
struct FusedConvAutoTuneGroup {
  static string name() { return "FusedConv"; }
};

using AutoTuneFusedConv =
    AutoTuneSingleton<FusedConvAutoTuneGroup, FusedConvParameters,
                      se::dnn::AlgorithmConfig>;

inline int64 ConvolveScratchSize() {
  static int64 convolve_scratch_size = GetDnnWorkspaceLimit(
      // default value is in bytes despite the name of the environment variable
      "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB
  );
  return convolve_scratch_size;
}

// Finds the best convolution algorithm for the given ConvLaunch (CUDA
// convolution on the stream) and parameters, by running all possible
// algorithms and measuring execution time.
// TODO(ezhulenev): Move it to conv_ops_gpu.h and share with conv_ops.cc.
template <typename T, typename ConvLaunch, typename LogFunc>
Status FindBestConvolveAlgorithm(const FusedConvParameters& params,
                                 const ConvLaunch launch,
                                 OpKernelContext* context, se::Stream* stream,
                                 const LogFunc& log,
                                 se::dnn::AlgorithmConfig* algorithm_config) {
  // Check if we already have an algorithm selected for the given parameters.
  if (AutoTuneFusedConv::GetInstance()->Find(params, algorithm_config)) {
    return Status::OK();
  }

  // Find all candidate algorithms.
  std::vector<se::dnn::AlgorithmDesc> algorithms;
  if (!stream->parent()->GetConvolveAlgorithms(
          params.ShouldIncludeWinogradNonfusedAlgo<T>(stream->parent()),
          &algorithms)) {
    return errors::Unknown(
        "Failed to get convolution algorithm. This is probably "
        "because cuDNN failed to initialize, so try looking to "
        "see if a warning log message was printed above.");
  }

  std::vector<tensorflow::AutotuneResult> results;
  for (auto profile_algorithm : algorithms) {
    DnnScratchAllocator scratch_allocator(ConvolveScratchSize(), context);
    se::dnn::ProfileResult profile_result;

    bool cudnn_launch_status =
        launch(se::dnn::AlgorithmConfig(profile_algorithm), &scratch_allocator,
               &profile_result);

    if (cudnn_launch_status && profile_result.is_valid()) {
      results.emplace_back();
      auto& result = results.back();
      result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
      result.mutable_conv()->set_tensor_ops_enabled(
          profile_algorithm.tensor_ops_enabled());
      result.mutable_success()->set_scratch_bytes(
          scratch_allocator.TotalByteSize());
      *result.mutable_success()->mutable_run_time() =
          proto_utils::ToDurationProto(
              absl::Milliseconds(profile_result.elapsed_time_in_ms()));
    }
  }
  // Only log on an AutoTuneFusedConv cache miss.
  log(results);
  TF_RETURN_IF_ERROR(BestCudnnConvAlgorithm(results, algorithm_config));
  AutoTuneFusedConv::GetInstance()->Insert(params, *algorithm_config);
  return Status::OK();
}

template <typename T>
struct LaunchFusedConv2DOp<GPUDevice, T> {
  void operator()(OpKernelContext* context, bool use_cudnn,
                  bool cudnn_use_autotune, const Tensor& input_param,
                  const Tensor& filter, FusedComputationType fusion,
                  const FusedComputationArgs& fusion_args,
                  const Conv2DParameters& params,
                  const Conv2DDimensions& dimensions, Tensor* output) {
    OP_REQUIRES(
        context,
        params.data_format == FORMAT_NHWC || params.data_format == FORMAT_NCHW,
        errors::Unimplemented("Fused conv implementation only supports "
                              "NHWC and NCHW tensor formats for now."));

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
    OP_REQUIRES(
        context, use_cudnn,
        errors::Unimplemented("FusedConv2D for GPU is not currently supported "
                              "without cudnn"));

    OP_REQUIRES(
        context, fusion == FusedComputationType::kBiasAddWithRelu,
        errors::Unimplemented("FusedConv2D implementation only supports "
                              "fusing with `BiasAdd + Relu` for now."));

    Tensor input = input_param;

    const int64 in_batch = GetTensorDim(input, params.data_format, 'N');
    int64 in_rows = GetTensorDim(input, params.data_format, 'H');
    int64 in_cols = GetTensorDim(input, params.data_format, 'W');
    const int64 in_depths = GetTensorDim(input, params.data_format, 'C');

    const int64 patch_rows = filter.dim_size(0);
    const int64 patch_cols = filter.dim_size(1);
    const int64 patch_depths = filter.dim_size(2);

    int64 padding_rows = 0;
    int64 padding_cols = 0;
    const int64 out_batch = GetTensorDim(*output, params.data_format, 'N');
    const int64 out_rows = GetTensorDim(*output, params.data_format, 'H');
    const int64 out_cols = GetTensorDim(*output, params.data_format, 'W');
    const int64 out_depths = GetTensorDim(*output, params.data_format, 'C');

    // Bias of the following dimensions: [ output_depth ]
    const Tensor& bias = context->input(2);
    OP_REQUIRES(context, bias.dims() == 1,
                errors::InvalidArgument("bias must be 1-dimensional",
                                        bias.shape().DebugString()));
    OP_REQUIRES(context, bias.dim_size(0) == out_depths,
                errors::InvalidArgument("bias depth must be equal to out depth",
                                        bias.shape().DebugString()));

    if (params.padding == SAME) {
      // Total padding on rows and cols is
      // Pr = (R' - 1) * S + (Kr - 1) * Dr + 1 - R
      // Pc = (C' - 1) * S + (Kc - 1) * Dc + 1 - C
      // where (R', C') are output dimensions, (R, C) are input dimensions, S
      // is stride, (Dr, Dc) are dilations, (Kr, Kc) are filter dimensions.
      // We pad Pr/2 on the top and Pr - Pr/2 on the bottom, and Pc/2 on the
      // left and Pc - Pc/2 on the right. When Pr or Pc is odd, this means we
      // pad more on the bottom and right than on the top and left.
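      //
      // For example (illustrative numbers): with R = 5, Kr = 3, S = 2, Dr = 1
      // and SAME padding, R' = ceil(R / S) = 3, so
      // Pr = (3 - 1) * 2 + (3 - 1) * 1 + 1 - 5 = 2, i.e. one row of padding
      // on the top and one on the bottom.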
      padding_rows = std::max<int>(
          0, (out_rows - 1) * dimensions.stride_rows +
                 (patch_rows - 1) * dimensions.dilation_rows + 1 - in_rows);
      padding_cols = std::max<int>(
          0, (out_cols - 1) * dimensions.stride_cols +
                 (patch_cols - 1) * dimensions.dilation_cols + 1 - in_cols);
      const bool rows_odd = (padding_rows % 2 != 0);
      const bool cols_odd = (padding_cols % 2 != 0);
      if (rows_odd || cols_odd) {
        Tensor transformed_input;
        int64 new_in_rows = in_rows + rows_odd;
        int64 new_in_cols = in_cols + cols_odd;
        OP_REQUIRES_OK(context,
                       context->allocate_temp(
                           DataTypeToEnum<T>::value,
                           ShapeFromFormat(params.data_format, in_batch,
                                           new_in_rows, new_in_cols, in_depths),
                           &transformed_input));

        functor::PadInput<GPUDevice, T, int, 4>()(
            context->eigen_device<GPUDevice>(),
            To32Bit(input_param.tensor<T, 4>()), {{0, 0}},
            {{rows_odd, cols_odd}}, To32Bit(transformed_input.tensor<T, 4>()),
            params.data_format);

        input = transformed_input;
        in_rows = new_in_rows;
        in_cols = new_in_cols;
      }
    }

    if (params.data_format == FORMAT_NHWC) {
      // Convert the input tensor from NHWC to NCHW.
      TensorShape nchw_shape =
          ShapeFromFormat(FORMAT_NCHW, in_batch, in_rows, in_cols, in_depths);
      if (in_depths > 1) {
        Tensor transformed_input;
        OP_REQUIRES_OK(context,
                       context->allocate_temp(DataTypeToEnum<T>::value,
                                              nchw_shape, &transformed_input));
        functor::NHWCToNCHW<GPUDevice, T, 4>()(
            context->eigen_device<GPUDevice>(),
            const_cast<const Tensor&>(input).tensor<T, 4>(),
            transformed_input.tensor<T, 4>());
        input = transformed_input;
      } else {
        // If depth <= 1, then just reshape.
        CHECK(input.CopyFrom(input, nchw_shape));  // Crash OK
      }
    }

    CHECK(padding_rows >= 0) << "Negative padding rows";  // Crash OK
    CHECK(padding_cols >= 0) << "Negative padding cols";  // Crash OK

    se::dnn::ActivationMode dnn_activation_mode;
    switch (fusion) {
      case FusedComputationType::kBiasAddWithRelu:
        dnn_activation_mode = se::dnn::ActivationMode::kRelu;
        break;
      default:
        LOG(FATAL) << "Unsupported fusion type";  // Crash OK
    }

    se::dnn::BatchDescriptor input_desc;
    input_desc.set_count(in_batch)
        .set_feature_map_count(in_depths)
        .set_height(in_rows)
        .set_width(in_cols)
        .set_layout(se::dnn::DataLayout::kBatchDepthYX);
    se::dnn::FilterDescriptor filter_desc;
    filter_desc.set_input_filter_height(patch_rows)
        .set_input_filter_width(patch_cols)
        .set_input_feature_map_count(patch_depths)
        .set_output_feature_map_count(filter.dim_size(3));
    se::dnn::BatchDescriptor bias_desc;
    bias_desc.set_count(1)
        .set_height(1)
        .set_width(1)
        .set_feature_map_count(out_depths)
        .set_layout(se::dnn::DataLayout::kBatchDepthYX);
    se::dnn::ConvolutionDescriptor conv_desc;
    conv_desc.set_vertical_dilation_rate(dimensions.dilation_rows)
        .set_horizontal_dilation_rate(dimensions.dilation_cols)
        .set_vertical_filter_stride(dimensions.stride_rows)
        .set_horizontal_filter_stride(dimensions.stride_cols)
        .set_zero_padding_height(padding_rows / 2)
        .set_zero_padding_width(padding_cols / 2)
        .set_group_count(in_depths / patch_depths);
    se::dnn::BatchDescriptor output_desc;
    output_desc.set_count(out_batch)
        .set_height(out_rows)
        .set_width(out_cols)
        .set_feature_map_count(out_depths)
        .set_layout(se::dnn::DataLayout::kBatchDepthYX);

    Tensor transformed_filter;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(
                       DataTypeToEnum<T>::value,
                       TensorShape({filter.dim_size(3), filter.dim_size(2),
                                    filter.dim_size(0), filter.dim_size(1)}),
                       &transformed_filter));
    functor::TransformFilter<GPUDevice, T, int, 4>()(
        context->eigen_device<GPUDevice>(), FORMAT_OIHW,
        To32Bit(filter.tensor<T, 4>()),
        To32Bit(transformed_filter.tensor<T, 4>()));

    Tensor transformed_output;
    if (params.data_format == FORMAT_NHWC) {
      // Only allocate temporary memory when a layout transformation is needed.
      OP_REQUIRES_OK(context,
                     context->allocate_temp(
                         DataTypeToEnum<T>::value,
                         ShapeFromFormat(FORMAT_NCHW, out_batch, out_rows,
                                         out_cols, out_depths),
                         &transformed_output));
    } else {
      transformed_output = *output;
    }

    const auto tensor_on_device = [](const Tensor& t) -> se::DeviceMemory<T> {
      return AsDeviceMemory(t.template flat<T>().data(),
                            t.template flat<T>().size());
    };

    se::DeviceMemory<T> input_ptr = tensor_on_device(input);
    se::DeviceMemory<T> filter_ptr = tensor_on_device(transformed_filter);
    se::DeviceMemory<T> bias_ptr = tensor_on_device(bias);
    se::DeviceMemory<T> output_ptr = tensor_on_device(transformed_output);

    // We do not use side inputs, so we can safely pass nullptr.
    se::DeviceMemory<T> side_input_ptr =
        AsDeviceMemory(static_cast<T*>(nullptr), 0);

    int device_id = stream->parent()->device_ordinal();
    DataType dtype = input.dtype();
    FusedConvParameters conv_parameters = {
        {
            in_batch,                      // batch
            in_depths,                     // in_depths
            {{in_rows,                     // in_rows
              in_cols}},                   // in_cols
            FORMAT_NCHW,                   // compute_data_format
            out_depths,                    // out_depths
            {{patch_rows,                  // filter_rows
              patch_cols,                  // filter_cols
              patch_depths}},              // filter_depths
            {{dimensions.dilation_rows,    // dilation_rows
              dimensions.dilation_cols}},  // dilation_cols
            {{dimensions.stride_rows,      // stride_rows
              dimensions.stride_cols}},    // stride_cols
            {{padding_rows,                // padding_rows
              padding_cols}},              // padding_cols
            dtype,                         // tensor datatype
            device_id,                     // device_id
        },
        dnn_activation_mode  // activation_mode
    };

    // Launch fused convolution with given parameters and scratch allocator.
    // Record profile result into `profile_result` if it's not nullptr.
    const auto launch = [&](se::dnn::AlgorithmConfig algorithm_config,
                            DnnScratchAllocator* scratch_allocator,
                            se::dnn::ProfileResult* profile_result) -> bool {
      return stream
          ->ThenFusedConvolveWithAlgorithm(
              input_desc, input_ptr,                     // input
              /*conv_input_scale=*/1.0,                  // input_scale
              filter_desc, filter_ptr,                   // filter
              conv_desc,                                 // conv
              side_input_ptr, /*side_input_scale=*/0.0,  // side_input
              bias_desc, bias_ptr,                       // bias
              dnn_activation_mode,                       // activation
              output_desc, &output_ptr,                  // output
              scratch_allocator, algorithm_config, profile_result)
          .ok();
    };

    se::dnn::AlgorithmConfig algorithm_config;
    if (cudnn_use_autotune) {
      auto status = FindBestConvolveAlgorithm<T>(
          conv_parameters, launch, context, stream,
          [&](absl::Span<const tensorflow::AutotuneResult> results) {
            LogFusedConvAutotuneResults(
                context->op_kernel().def(), input, transformed_filter,
                transformed_output, bias, nullptr, stream->parent(), results);
          },
          &algorithm_config);
      OP_REQUIRES_OK(context, status);
    }

    DnnScratchAllocator scratch_allocator(ConvolveScratchSize(), context);
    bool cudnn_launch_status = launch(algorithm_config, &scratch_allocator,
                                      /*profile_result=*/nullptr);
    OP_REQUIRES(
        context, cudnn_launch_status,
        errors::Internal(absl::Substitute(
            "cuDNN launch failure: input shape($0) filter shape($1)",
            input.shape().DebugString(), filter.shape().DebugString())));

    // Convert the output tensor back from NCHW to NHWC.
    if (params.data_format == FORMAT_NHWC) {
      functor::NCHWToNHWC<GPUDevice, T, 4>()(
          context->eigen_device<GPUDevice>(),
          const_cast<const Tensor&>(transformed_output).tensor<T, 4>(),
          output->tensor<T, 4>());
    }
  }
};

#endif  // GOOGLE_CUDA

template <typename Device, typename T>
class FusedConv2DOp : public OpKernel {
 public:
  explicit FusedConv2DOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, InitConv2DParameters(context, &params_));

    OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
    use_cudnn_ &= CanUseCudnn();
    cudnn_use_autotune_ = CudnnUseAutotune();

    // 'fused_ops' and 'num_args' attributes are specified by the Grappler
    // Remapper optimizer (see grappler/optimizers/remapper.cc).
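    //
    // For example (illustrative values, not a literal NodeDef): a remapped
    // Conv2D + BiasAdd + Relu pattern arrives here with attributes roughly
    // like
    //   fused_ops = ["BiasAdd", "Relu"]
    //   num_args  = 1   // the single extra input is the bias vector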

    std::vector<string> fused_ops;
    OP_REQUIRES_OK(context, context->GetAttr("fused_ops", &fused_ops));
    OP_REQUIRES(context, !fused_ops.empty(),
                errors::InvalidArgument(
                    "Fused Conv2D must have at least one fused op."));

    int num_args;
    OP_REQUIRES_OK(context, context->GetAttr("num_args", &num_args));

    // TODO(ezhulenev): Add support for fusing element-wise op chains defined
    // at runtime, e.g. Relu+Sqrt+Tanh+etc.

    // Match the combination of fused ops to one of the supported fusions.
    if (FusedOpsMatchAndSupportedOnDevice(fused_ops, {"BiasAdd"},
                                          /*cpu_only=*/true)) {
      fused_computation_ = FusedComputationType::kBiasAdd;
    } else if (FusedOpsMatchAndSupportedOnDevice(fused_ops, {"BiasAdd", "Relu"},
                                                 /*cpu_only=*/false)) {
      fused_computation_ = FusedComputationType::kBiasAddWithRelu;
    } else if (FusedOpsMatchAndSupportedOnDevice(fused_ops, {"FusedBatchNorm"},
                                                 /*cpu_only=*/true)) {
      fused_computation_ = FusedComputationType::kFusedBatchNorm;
    } else if (FusedOpsMatchAndSupportedOnDevice(fused_ops,
                                                 {"FusedBatchNorm", "Relu"},
                                                 /*cpu_only=*/true)) {
      fused_computation_ = FusedComputationType::kFusedBatchNormWithRelu;
    } else {
      OP_REQUIRES(context, false,
                  errors::Unimplemented("Fusion is not implemented: [",
                                        absl::StrJoin(fused_ops, ","), "]"));
    }

    // Depending on the picked fusion type, validate fusion-specific arguments.

    if (fused_computation_ == FusedComputationType::kBiasAdd ||
        fused_computation_ == FusedComputationType::kBiasAddWithRelu) {
      OP_REQUIRES(context, num_args == 1,
                  errors::InvalidArgument(
                      "Fused Conv2D must have one extra argument: bias."));
    }

    if (fused_computation_ == FusedComputationType::kFusedBatchNorm ||
        fused_computation_ == FusedComputationType::kFusedBatchNormWithRelu) {
      OP_REQUIRES(
          context, num_args == 4,
          errors::InvalidArgument("Fused FusedBatchNorm must have four extra "
                                  "arguments: scale, offset, mean, variance."));
      OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon_));
    }
  }

  void Compute(OpKernelContext* context) override {
    // Input tensor is of the following dimensions:
    // [ batch, in_rows, in_cols, in_depth ]
    const Tensor& input = context->input(0);

    // Input filter is of the following dimensions:
    // [ filter_rows, filter_cols, in_depth, out_depth ]
    const Tensor& filter = context->input(1);

    Conv2DDimensions dimensions;
    OP_REQUIRES_OK(context,
                   ComputeConv2DDimension(params_, input, filter, &dimensions));

    TensorShape out_shape = ShapeFromFormat(
        params_.data_format, dimensions.batch, dimensions.out_rows,
        dimensions.out_cols, dimensions.out_depth);

    // Output tensor is of the following dimensions:
    // [ in_batch, out_rows, out_cols, out_depth ]
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));

    VLOG(2) << "FusedConv2D: in_depth = " << dimensions.in_depth
            << ", patch_depth = " << dimensions.patch_depth
            << ", input_cols = " << dimensions.input_cols
            << ", filter_cols = " << dimensions.filter_cols
            << ", input_rows = " << dimensions.input_rows
            << ", filter_rows = " << dimensions.filter_rows
            << ", stride_rows = " << dimensions.stride_rows
            << ", stride_cols = " << dimensions.stride_cols
            << ", dilation_rows = " << dimensions.dilation_rows
            << ", dilation_cols = " << dimensions.dilation_cols
            << ", out_depth = " << dimensions.out_depth;

    // If there is nothing to compute, return.
    if (out_shape.num_elements() == 0) {
      return;
    }

    FusedComputationArgs args;
    args.epsilon = epsilon_;

    LaunchFusedConv2DOp<Device, T>()(context, use_cudnn_, cudnn_use_autotune_,
                                     input, filter, fused_computation_, args,
                                     params_, dimensions, output);
  }

 private:
  bool FusedOpsMatchAndSupportedOnDevice(const std::vector<string>& fused_ops,
                                         const std::vector<string>& expected,
                                         bool cpu_only) const {
    if (std::is_same<Device, GPUDevice>::value && cpu_only) {
      return false;
    }
    return fused_ops == expected;
  }

  Conv2DParameters params_;
  bool use_cudnn_;
  bool cudnn_use_autotune_;

  FusedComputationType fused_computation_;

  float epsilon_;  // Used only in FusedBatchNorm fusion

  TF_DISALLOW_COPY_AND_ASSIGN(FusedConv2DOp);
};

// Registration of the CPU implementations.
#define REGISTER_FUSED_CPU_CONV2D(T)                                  \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("_FusedConv2D").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      FusedConv2DOp<CPUDevice, T>);
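
// A corresponding kernel instantiation file would be expected to invoke the
// macro once per supported type, for example (assuming float is registered):
//
//   REGISTER_FUSED_CPU_CONV2D(float);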

#if GOOGLE_CUDA

#define DECLARE_FUNCTOR_GPU_SPEC(T)                                      \
  template <>                                                            \
  void TransformFilter<GPUDevice, T, int, 4>::operator()(                \
      const GPUDevice& d, FilterTensorFormat dst_filter_format,          \
      typename TTypes<T, 4, int>::ConstTensor in,                        \
      typename TTypes<T, 4, int>::Tensor out);                           \
  extern template struct TransformFilter<GPUDevice, T, int, 4>;          \
  template <>                                                            \
  void PadInput<GPUDevice, T, int, 4>::operator()(                       \
      const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,    \
      const std::array<int, 2>& padding_left,                            \
      const std::array<int, 2>& padding_right,                           \
      typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
  extern template struct PadInput<GPUDevice, T, int, 4>

// Registration of the GPU implementations.
#define REGISTER_FUSED_GPU_CONV2D(T)                                  \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("_FusedConv2D").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      FusedConv2DOp<GPUDevice, T>);

#endif  // GOOGLE_CUDA

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_CONV_OPS_FUSED_IMPL_H_