1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/nn_ops.cc.
17 
18 #define USE_EIGEN_TENSOR
19 #define EIGEN_USE_THREADS
20 
21 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
22 #define EIGEN_USE_GPU
23 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
24 
25 #include "tensorflow/core/kernels/conv_ops.h"
26 
27 #include <string.h>
28 
29 #include <atomic>
30 #include <map>
31 #include <vector>
32 
33 #include "absl/synchronization/blocking_counter.h"
34 #include "tensorflow/core/framework/allocator.h"
35 #include "tensorflow/core/framework/bounds_check.h"
36 #include "tensorflow/core/framework/kernel_shape_util.h"
37 #include "tensorflow/core/framework/numeric_op.h"
38 #include "tensorflow/core/framework/op_kernel.h"
39 #include "tensorflow/core/framework/register_types.h"
40 #include "tensorflow/core/framework/tensor.h"
41 #include "tensorflow/core/framework/tensor_shape.h"
42 #include "tensorflow/core/framework/tensor_slice.h"
43 #include "tensorflow/core/framework/types.h"
44 #include "tensorflow/core/kernels/conv_2d.h"
45 #include "tensorflow/core/kernels/deep_conv2d.h"
46 #include "tensorflow/core/kernels/ops_util.h"
47 #include "tensorflow/core/lib/core/errors.h"
48 #include "tensorflow/core/lib/gtl/array_slice.h"
49 #include "tensorflow/core/lib/strings/numbers.h"
50 #include "tensorflow/core/lib/strings/str_util.h"
51 #include "tensorflow/core/platform/logging.h"
52 #include "tensorflow/core/platform/macros.h"
53 #include "tensorflow/core/profiler/lib/scoped_annotation.h"
54 #include "tensorflow/core/util/padding.h"
55 #include "tensorflow/core/util/tensor_format.h"
56 #include "tensorflow/core/util/use_cudnn.h"
57 
58 #ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
59 #include "tensorflow/core/kernels/xsmm_conv2d.h"
60 #endif
61 
62 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
63 #include "tensorflow/core/kernels/conv_ops_gpu.h"
64 #include "tensorflow/core/platform/stream_executor.h"
65 #include "tensorflow/core/protobuf/autotuning.pb.h"
66 #include "tensorflow/core/util/autotune_maps/conv_autotune_maps.h"
67 #include "tensorflow/core/util/autotune_maps/conv_parameters.h"
68 #include "tensorflow/core/util/proto/proto_utils.h"
69 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
70 #if GOOGLE_CUDA
71 #include "tensorflow/stream_executor/gpu/gpu_asm_opts.h"
72 #include "tensorflow/stream_executor/gpu/redzone_allocator.h"
73 #include "tensorflow/stream_executor/tf_allocator_adapter.h"
74 #endif  // GOOGLE_CUDA
75 
76 namespace tensorflow {
77 
78 typedef Eigen::ThreadPoolDevice CPUDevice;
79 typedef Eigen::GpuDevice GPUDevice;
80 
81 namespace {
82 template <typename Device, typename T>
83 struct LaunchGeneric {
84   void operator()(OpKernelContext* ctx, const Tensor& input,
85                   const Tensor& filter, int row_stride, int col_stride,
86                   int row_dilation, int col_dilation, const Padding& padding,
87                   const std::vector<int64>& explicit_paddings, Tensor* output,
88                   TensorFormat data_format) {
89     CHECK(data_format == FORMAT_NHWC) << "Generic conv implementation only "
90                                          "supports NHWC tensor format for now.";
91     if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_stride == 1 &&
92         col_stride == 1 && (padding == SAME || padding == VALID)) {
93       // For 1x1 kernel, the 2D convolution is reduced to matrix
94       // multiplication.
95       //
96       // TODO(vrv): We should be able to call SpatialConvolution
97       // and it will produce the same result, but doing so
98       // led to NaNs during training.  Using matmul instead for now.
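      // Illustrative shape walk-through (assuming NHWC): an input of shape
      // [N, H, W, C_in] convolved with a [1, 1, C_in, C_out] filter is the
      // same computation as the matmul
      // [N*H*W, C_in] x [C_in, C_out] = [N*H*W, C_out],
      // which is what the MatMulConvFunctor call below performs.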
99       int conv_width = 1;  // Width for the convolution step.
100       for (int i = 0; i < 3; ++i) {
101         conv_width *= output->dim_size(i);
102       }
103 
104       Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
105       dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
106       functor::MatMulConvFunctor<Device, T>()(
107           ctx->eigen_device<Device>(),
108           output->shaped<T, 2>({conv_width, filter.dim_size(3)}),
109           input.shaped<T, 2>({conv_width, filter.dim_size(2)}),
110           filter.shaped<T, 2>({filter.dim_size(2), filter.dim_size(3)}),
111           dim_pair);
112     } else if (filter.dim_size(0) == input.dim_size(1) &&
113                filter.dim_size(1) == input.dim_size(2) && row_dilation == 1 &&
114                col_dilation == 1 && padding == VALID) {
115       // If the input data and filter have the same height/width,
116       // the 2D convolution is reduced to matrix multiplication.
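      // Illustrative shapes (assuming NHWC): an [H, W, C_in, C_out] filter
      // applied to an [N, H, W, C_in] input with VALID padding produces one
      // output pixel per example, i.e. the matmul
      // [N, H*W*C_in] x [H*W*C_in, C_out] = [N, C_out].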
117       const int k =  // Length of reduction dimension.
118           filter.dim_size(0) * filter.dim_size(1) * filter.dim_size(2);
119 
120       Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
121       dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
122       functor::MatMulConvFunctor<Device, T>()(
123           ctx->eigen_device<Device>(),
124           output->shaped<T, 2>({input.dim_size(0), filter.dim_size(3)}),
125           input.shaped<T, 2>({input.dim_size(0), k}),
126           filter.shaped<T, 2>({k, filter.dim_size(3)}), dim_pair);
127     } else {
128       if (padding == EXPLICIT) {
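        // For NHWC (the only format supported here), explicit_paddings is laid
        // out as [N_before, N_after, H_before, H_after, W_before, W_after,
        // C_before, C_after], so indices 2..5 below are the top/bottom/left/
        // right spatial paddings.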
129         functor::SpatialConvolution<Device, T>()(
130             ctx->eigen_device<Device>(), output->tensor<T, 4>(),
131             input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride, col_stride,
132             row_dilation, col_dilation, static_cast<int>(explicit_paddings[2]),
133             static_cast<int>(explicit_paddings[3]),
134             static_cast<int>(explicit_paddings[4]),
135             static_cast<int>(explicit_paddings[5]));
136       } else {
137         functor::SpatialConvolution<Device, T>()(
138             ctx->eigen_device<Device>(), output->tensor<T, 4>(),
139             input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride, col_stride,
140             row_dilation, col_dilation, BrainPadding2EigenPadding(padding));
141       }
142     }
143   }
144 };
145 
146 // Compute grouped 2D convolutions on CPU. Unlike the grouped convolution
147 // implementation in cuDNN, this is far from optimal and needs more work
148 // to deliver competitive performance. Currently it exists to close the feature
149 // parity gap between convolution operations on different devices.
150 template <typename T>
151 struct LaunchGrouped {
152   void operator()(OpKernelContext* ctx, const Tensor& input,
153                   const Tensor& filter, int row_stride, int col_stride,
154                   int row_dilation, int col_dilation, const Padding& padding,
155                   const std::vector<int64>& explicit_paddings, Tensor* output,
156                   TensorFormat data_format) {
157     DCHECK(data_format == FORMAT_NHWC)
158         << "Grouped conv implementation only "
159            "supports NHWC tensor format for now.";
160 
161     const int64_t in_depth = input.dim_size(3);
162     const int64_t patch_depth = filter.dim_size(2);
163     const int64_t num_groups = in_depth / patch_depth;
164 
165     // Shuffle input/filter tensors to have group as a leading dimension.
166     std::array<int64, 5> shuffle({3, 0, 1, 2, 4});
167 
168     // Compute pre-shuffle dimensions.
169     auto pre_shuffle = [&](const Tensor& tensor) -> std::array<int64, 5> {
170       return {tensor.dim_size(0), tensor.dim_size(1), tensor.dim_size(2),
171               num_groups, tensor.dim_size(3) / num_groups};
172     };
173 
174     // Compute post-shuffle dimensions.
175     auto post_shuffle = [&](const Tensor& tensor) -> std::array<int64, 5> {
176       return {num_groups, tensor.dim_size(0), tensor.dim_size(1),
177               tensor.dim_size(2), tensor.dim_size(3) / num_groups};
178     };
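    // Illustrative walk-through: an [N, H, W, C] tensor split into G groups is
    // first viewed as [N, H, W, G, C/G] (pre_shuffle), and the {3, 0, 1, 2, 4}
    // shuffle then moves the group axis to the front, giving [G, N, H, W, C/G]
    // (post_shuffle), so that each group can be convolved independently below.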
179 
180     auto& device = ctx->eigen_device<CPUDevice>();
181 
182     absl::BlockingCounter shuffles_completed(2);
183     auto on_shuffled = [&]() { shuffles_completed.DecrementCount(); };
184 
185     // Shuffle input into temporary tensor.
186     Tensor input_shuffled(input.dtype(), TensorShape(post_shuffle(input)));
187     input_shuffled.tensor<T, 5>().device(device, on_shuffled) =
188         input.shaped<T, 5>(pre_shuffle(input)).shuffle(shuffle);
189 
190     // Shuffle filter into temporary tensor.
191     Tensor filter_shuffled(filter.dtype(), TensorShape(post_shuffle(filter)));
192     filter_shuffled.tensor<T, 5>().device(device, on_shuffled) =
193         filter.shaped<T, 5>(pre_shuffle(filter)).shuffle(shuffle);
194 
195     // Wait for the completion of input/filter shuffles.
196     shuffles_completed.Wait();
197 
198     // Write group convolution results into temporary output tensor.
199     Tensor output_shuffled(output->dtype(), TensorShape(post_shuffle(*output)));
200 
201     for (int64_t i = 0; i < num_groups; ++i) {
202       // TODO(ezhulenev): Run this loop using `parallelFor` (regular parallelFor
203       // will lead to deadlock, SpatialConvolution has to use async Eigen
204       // assignment). This requires small changes to Eigen to support async
205       // execution for the tensor chipping operation.
206 
207       // TODO(ezhulenev): Grouped convolution should also support 1x1 filter
208       // optimization.
209 
210       auto input_slice = input_shuffled.tensor<T, 5>().template chip<0>(i);
211       auto filter_slice = filter_shuffled.tensor<T, 5>().template chip<0>(i);
212       auto output_slice = output_shuffled.tensor<T, 5>().template chip<0>(i);
213 
214       if (padding == EXPLICIT) {
215         functor::SpatialConvolution<CPUDevice, T>()(
216             ctx->eigen_device<CPUDevice>(), output_slice, input_slice,
217             filter_slice, row_stride, col_stride, row_dilation, col_dilation,
218             static_cast<int>(explicit_paddings[2]),
219             static_cast<int>(explicit_paddings[3]),
220             static_cast<int>(explicit_paddings[4]),
221             static_cast<int>(explicit_paddings[5]));
222       } else {
223         functor::SpatialConvolution<CPUDevice, T>()(
224             ctx->eigen_device<CPUDevice>(), output_slice, input_slice,
225             filter_slice, row_stride, col_stride, row_dilation, col_dilation,
226             BrainPadding2EigenPadding(padding));
227       }
228     }
229 
230     // Shuffle temporary output back into pre-shuffled shape.
231     std::array<int64, 5> rev_shuffle({1, 2, 3, 0, 4});
232     output->shaped<T, 5>(pre_shuffle(*output)).device(device) =
233         output_shuffled.tensor<T, 5>().shuffle(rev_shuffle);
234   }
235 };
236 
237 }  // namespace
238 
239 template <typename T>
240 struct LaunchConv2DOp<CPUDevice, T> {
241   void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
242                   const Tensor& input, const Tensor& filter, int row_dilation,
243                   int col_dilation, int row_stride, int col_stride,
244                   const Padding& padding,
245                   const std::vector<int64>& explicit_paddings, Tensor* output,
246                   TensorFormat data_format) {
247     if (data_format != FORMAT_NHWC) {
248       ctx->SetStatus(errors::Unimplemented(
249           "The Conv2D op currently only supports the NHWC tensor format on the "
250           "CPU. The op was given the format: ",
251           ToString(data_format)));
252       return;
253     }
254 
255     for (int64_t explicit_padding : explicit_paddings) {
256       if (!FastBoundsCheck(explicit_padding, std::numeric_limits<int>::max())) {
257         ctx->SetStatus(errors::InvalidArgument("filter too large"));
258         return;
259       }
260     }
261 
262     const int64_t in_depth = input.dim_size(3);
263     const int64_t out_depth = output->dim_size(3);
264     const int64_t patch_depth = filter.dim_size(2);
265 
266     if (patch_depth <= 0) {
267       ctx->SetStatus(errors::InvalidArgument(
268           "filter depth must be stricly positive, got ", patch_depth));
269       return;
270     }
271     if (in_depth % patch_depth != 0) {
272       ctx->SetStatus(errors::InvalidArgument(
273           "input depth must be evenly divisible by filter depth: ", in_depth,
274           " vs ", patch_depth));
275       return;
276     }
277     if (filter.NumElements() <= 0) {
278       ctx->SetStatus(
279           errors::InvalidArgument("filter must not have zero elements "
280                                   "(i.e. all dimensions must be non-zero)"));
281       return;
282     }
283 
284     const int64_t num_groups = in_depth / patch_depth;
285     if (num_groups <= 0) {
286       ctx->SetStatus(errors::InvalidArgument(
287           "number of groups must be stricly positive, got ", num_groups));
288       return;
289     }
290     if (out_depth % num_groups != 0 || out_depth < num_groups) {
291       ctx->SetStatus(errors::InvalidArgument(
292           "output depth must be evenly divisible by number of groups: ",
293           out_depth, " vs ", num_groups));
294       return;
295     }
296 
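    // Dispatch: a filter depth smaller than the input depth means a grouped
    // convolution, handled by the CPU-specific LaunchGrouped path; otherwise
    // use the generic Eigen-based implementation.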
297     if (in_depth != patch_depth) {
298       LaunchGrouped<T>()(ctx, input, filter, row_stride, col_stride,
299                          row_dilation, col_dilation, padding, explicit_paddings,
300                          output, data_format);
301     } else {
302       LaunchGeneric<CPUDevice, T>()(ctx, input, filter, row_stride, col_stride,
303                                     row_dilation, col_dilation, padding,
304                                     explicit_paddings, output, data_format);
305     }
306   }
307 };
308 
309 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
310 template <>
311 struct LaunchConv2DOp<GPUDevice, int32> {
312   void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
313                   const Tensor& input, const Tensor& filter, int row_dilation,
314                   int col_dilation, int row_stride, int col_stride,
315                   const Padding& padding,
316                   const std::vector<int64>& explicit_paddings, Tensor* output,
317                   TensorFormat data_format) {
318     if (data_format != FORMAT_NHWC) {
319       ctx->SetStatus(
320           errors::Unimplemented("The Conv2D op currently only supports the "
321                                 "NHWC tensor format for integer types. "
322                                 "The op was given the format: ",
323                                 ToString(data_format)));
324       return;
325     }
326     const int64_t in_depth = GetTensorDim(input, data_format, 'C');
327     OP_REQUIRES(ctx, in_depth == filter.dim_size(2),
328                 errors::Unimplemented(
329                     "The Conv2D op currently does not support grouped "
330                     "convolutions for integer types. A grouped convolution was "
331                     "attempted to be run because the input depth of ",
332                     in_depth, " does not match the filter input depth of ",
333                     filter.dim_size(2)));
334     OP_REQUIRES(
335         ctx, filter.NumElements() > 0,
336         errors::InvalidArgument("filter must not have zero elements "
337                                 "(i.e. all dimensions must be non-zero)"));
338 
339     for (int64_t explicit_padding : explicit_paddings) {
340       if (!FastBoundsCheck(explicit_padding, std::numeric_limits<int>::max())) {
341         ctx->SetStatus(errors::InvalidArgument("filter too large"));
342         return;
343       }
344     }
345     LaunchGeneric<GPUDevice, int32>()(
346         ctx, input, filter, row_stride, col_stride, row_dilation, col_dilation,
347         padding, explicit_paddings, output, data_format);
348   }
349 };
350 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
351 
352 template <typename Device, typename T>
353 class LaunchDeepConvOp {
354  public:
355   static bool Run(OpKernelContext* ctx, const Tensor& input,
356                   const Tensor& filter, int batch, int input_rows,
357                   int input_cols, int in_depth, int filter_rows,
358                   int filter_cols, int pad_rows, int pad_cols, int out_rows,
359                   int /*out_cols*/, int /*out_depth*/, int /*dilation_rows*/,
360                   int /*dilation_cols*/, int /*stride_rows*/,
361                   int /*stride_cols*/, Tensor* /*output*/,
362                   TensorFormat /*data_format*/) {
363     return false;
364   }
365 };
366 
367 // Conditionally launches DeepConv operation based on convolution parameters.
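// DeepConv2D (see deep_conv2d.h) uses a Winograd-style transform that can be
// faster than direct convolution for small filters; CanUseDeepConv2D decides
// at runtime whether the given shapes are expected to benefit from it.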
368 template <>
369 class LaunchDeepConvOp<CPUDevice, float> {
370  public:
371   static bool Run(OpKernelContext* ctx, const Tensor& input,
372                   const Tensor& filter, int batch, int input_rows,
373                   int input_cols, int in_depth, int filter_rows,
374                   int filter_cols, int pad_rows, int pad_cols, int out_rows,
375                   int out_cols, int out_depth, int dilation_rows,
376                   int dilation_cols, int stride_rows, int stride_cols,
377                   Tensor* output, TensorFormat data_format) {
378     if (data_format != FORMAT_NHWC || dilation_rows != 1 ||
379         dilation_cols != 1 ||
380         !CanUseDeepConv2D(stride_rows, stride_cols, filter_rows, filter_cols,
381                           in_depth, out_depth, out_rows, out_cols)) {
382       return false;
383     }
384 
385     Conv2DArgs args;
386     args.batch = batch;
387     args.in_rows = input_rows;
388     args.in_cols = input_cols;
389     args.in_depth = in_depth;
390     args.filter_rows = filter_rows;
391     args.filter_cols = filter_cols;
392     args.pad_rows = pad_rows;
393     args.pad_cols = pad_cols;
394     args.out_rows = out_rows;
395     args.out_cols = out_cols;
396     args.out_depth = out_depth;
397 
398     auto input_ptr = input.template flat<float>().data();
399     auto filter_ptr = filter.template flat<float>().data();
400     auto output_ptr = output->template flat<float>().data();
401 
402     functor::DeepConv2D<CPUDevice, float>()(ctx, args, input_ptr, filter_ptr,
403                                             output_ptr);
404     return true;
405   }
406 };
407 
408 #ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
409 template <typename Device, typename T>
410 class LaunchXsmmConvOp {
411  public:
412   static bool Run(OpKernelContext* ctx, const Tensor& input,
413                   const Tensor& filter, int batch, int input_rows,
414                   int input_cols, int in_depth, int filter_rows,
415                   int filter_cols, int pad_rows, int pad_cols, int out_rows,
416                   int out_cols, int out_depth, int stride_rows, int stride_cols,
417                   int dilation_rows, int dilation_cols, Tensor* output,
418                   TensorFormat data_format) {
419     return false;
420   }
421 };
422 
423 template <>
424 class LaunchXsmmConvOp<CPUDevice, float> {
425  public:
426   static bool Run(OpKernelContext* ctx, const Tensor& input,
427                   const Tensor& filter, int batch, int input_rows,
428                   int input_cols, int in_depth, int filter_rows,
429                   int filter_cols, int pad_rows, int pad_cols, int out_rows,
430                   int out_cols, int out_depth, int dilation_rows,
431                   int dilation_cols, int stride_rows, int stride_cols,
432                   Tensor* output, TensorFormat data_format) {
433     auto num_threads =
434         ctx->device()->tensorflow_cpu_worker_threads()->num_threads;
435     // See libxsmm_dnn.h for this struct definition.
436     libxsmm_dnn_conv_desc desc;
437     desc.N = batch;
438     desc.C = in_depth;
439     desc.H = input_rows;
440     desc.W = input_cols;
441     desc.K = out_depth;
442     desc.R = filter_rows;
443     desc.S = filter_cols;
444     desc.u = stride_rows;
445     desc.v = stride_cols;
446     desc.pad_h = pad_rows;
447     desc.pad_w = pad_cols;
448     desc.pad_h_in = 0;
449     desc.pad_w_in = 0;
450     desc.pad_h_out = 0;
451     desc.pad_w_out = 0;
452     desc.threads = num_threads;
453     desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
454     desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
455     desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;
456     desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
457     desc.options = LIBXSMM_DNN_CONV_OPTION_OVERWRITE;
458     desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
459     desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
460     if (dilation_rows != 1 || dilation_cols != 1 ||
461         !CanUseXsmmConv2D(desc, data_format)) {
462       return false;
463     }
464 
465     auto input_ptr = input.template flat<float>().data();
466     auto filter_ptr = filter.template flat<float>().data();
467     auto output_ptr = output->template flat<float>().data();
468 
469     bool success = functor::XsmmFwdConv2D<CPUDevice, float>()(
470         ctx, desc, input_ptr, filter_ptr, output_ptr);
471     return success;
472   }
473 };
474 #endif
475 
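// Local helper: returns STATUS from the enclosing function when EXP evaluates
// to false.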
476 #define TF_REQUIRES(EXP, STATUS)                \
477   do {                                          \
478     if (!TF_PREDICT_TRUE(EXP)) return (STATUS); \
479   } while (false)
480 
481 Status InitConv2DParameters(const OpKernelConstruction* context,
482                             Conv2DParameters* params) {
483   TF_RETURN_IF_ERROR(context->GetAttr("dilations", &params->dilations));
484   TF_RETURN_IF_ERROR(context->GetAttr("strides", &params->strides));
485   TF_RETURN_IF_ERROR(context->GetAttr("padding", &params->padding));
486   if (context->HasAttr("explicit_paddings")) {
487     TF_RETURN_IF_ERROR(
488         context->GetAttr("explicit_paddings", &params->explicit_paddings));
489   }
490   string data_format_string;
491   TF_RETURN_IF_ERROR(context->GetAttr("data_format", &data_format_string));
492   TF_REQUIRES(FormatFromString(data_format_string, &params->data_format),
493               errors::InvalidArgument("Invalid data format"));
494 
495   const auto& strides = params->strides;
496   const auto& dilations = params->dilations;
497   const auto& data_format = params->data_format;
498 
499   TF_REQUIRES(dilations.size() == 4,
500               errors::InvalidArgument("Sliding window dilations field must "
501                                       "specify 4 dimensions"));
502   TF_REQUIRES(strides.size() == 4,
503               errors::InvalidArgument("Sliding window strides field must "
504                                       "specify 4 dimensions"));
505   const int64_t stride_n = GetTensorDim(strides, data_format, 'N');
506   const int64_t stride_c = GetTensorDim(strides, data_format, 'C');
507   const int64_t stride_h = GetTensorDim(strides, data_format, 'H');
508   const int64_t stride_w = GetTensorDim(strides, data_format, 'W');
509   TF_REQUIRES(
510       stride_n == 1 && stride_c == 1,
511       errors::Unimplemented("Current implementation does not yet support "
512                             "strides in the batch and depth dimensions."));
513   TF_REQUIRES(stride_h > 0 && stride_w > 0,
514               errors::InvalidArgument(
515                   "Row and column strides should be larger than 0."));
516 
517   const int64_t dilation_n = GetTensorDim(dilations, data_format, 'N');
518   const int64_t dilation_c = GetTensorDim(dilations, data_format, 'C');
519   const int64_t dilation_h = GetTensorDim(dilations, data_format, 'H');
520   const int64_t dilation_w = GetTensorDim(dilations, data_format, 'W');
521   TF_REQUIRES(
522       dilation_n == 1 && dilation_c == 1,
523       errors::Unimplemented("Current implementation does not yet support "
524                             "dilations in the batch and depth dimensions."));
525   TF_REQUIRES(
526       dilation_h > 0 && dilation_w > 0,
527       errors::InvalidArgument("Dilated rates should be larger than 0."));
528 
529   TF_RETURN_IF_ERROR(CheckValidPadding(params->padding,
530                                        params->explicit_paddings,
531                                        /*num_dims=*/4, data_format));
532 
533   return Status::OK();
534 }
535 
536 Status ComputeConv2DDimension(const Conv2DParameters& params,
537                               const Tensor& input, const Tensor& filter,
538                               Conv2DDimensions* dimensions) {
539   // Check that 2D convolution input and filter have exactly 4 dimensions.
540   TF_REQUIRES(input.dims() == 4,
541               errors::InvalidArgument("input must be 4-dimensional",
542                                       input.shape().DebugString()));
543   TF_REQUIRES(filter.dims() == 4,
544               errors::InvalidArgument("filter must be 4-dimensional: ",
545                                       filter.shape().DebugString()));
546   for (int i = 0; i < 3; i++) {
547     TF_REQUIRES(
548         FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
549         errors::InvalidArgument("filter too large"));
550   }
551 
552   // The last dimension for input is in_depth. Check that it is the same as the
553   // filter's in_depth or it is evenly divisible by filter's in_depth.
554   const int64_t in_depth_raw = GetTensorDim(input, params.data_format, 'C');
555   const int64_t patch_depth_raw = filter.dim_size(2);
556   TF_REQUIRES(FastBoundsCheck(in_depth_raw, std::numeric_limits<int>::max()),
557               errors::InvalidArgument("Input depth too large"));
558   TF_REQUIRES(FastBoundsCheck(patch_depth_raw, std::numeric_limits<int>::max()),
559               errors::InvalidArgument("Patch depth too large"));
560   const int in_depth = static_cast<int>(in_depth_raw);
561   const int patch_depth = static_cast<int>(patch_depth_raw);
562   TF_REQUIRES(patch_depth > 0,
563               errors::InvalidArgument(
564                   "filter depth must be stricly positive, got ", patch_depth));
565   TF_REQUIRES(in_depth % patch_depth == 0,
566               errors::InvalidArgument(
567                   "input depth must be evenly divisible by filter depth: ",
568                   in_depth, " vs ", patch_depth));
569 
570   // The last dimension for filter is out_depth.
571   const int out_depth = static_cast<int>(filter.dim_size(3));
572 
573   // The second dimension for input is rows/height.
574   // The first dimension for filter is rows/height.
575   const int64_t input_rows_raw = GetTensorDim(input, params.data_format, 'H');
576   TF_REQUIRES(FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()),
577               errors::InvalidArgument("Input rows too large"));
578   const int input_rows = static_cast<int>(input_rows_raw);
579   const int filter_rows = static_cast<int>(filter.dim_size(0));
580 
581   // The third dimension for input is columns/width.
582   // The second dimension for filter is columns/width.
583   const int64_t input_cols_raw = GetTensorDim(input, params.data_format, 'W');
584   TF_REQUIRES(FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()),
585               errors::InvalidArgument("Input cols too large"));
586   const int input_cols = static_cast<int>(input_cols_raw);
587   const int filter_cols = static_cast<int>(filter.dim_size(1));
588 
589   // The first dimension for input is batch.
590   const int64_t batch_raw = GetTensorDim(input, params.data_format, 'N');
591   TF_REQUIRES(FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
592               errors::InvalidArgument("batch is too large"));
593   const int batch = static_cast<int>(batch_raw);
594 
595   // Take the stride and dilation from the second and third dimensions only (we
596   // do not support striding or dilation on the batch or depth dimension).
597   const int stride_rows = GetTensorDim(params.strides, params.data_format, 'H');
598   const int stride_cols = GetTensorDim(params.strides, params.data_format, 'W');
599   const int dilation_rows =
600       GetTensorDim(params.dilations, params.data_format, 'H');
601   const int dilation_cols =
602       GetTensorDim(params.dilations, params.data_format, 'W');
603 
604   int64_t pad_rows_before, pad_rows_after, pad_cols_before, pad_cols_after;
605   if (params.padding == Padding::EXPLICIT) {
606     GetExplicitPaddingForDim(params.explicit_paddings, params.data_format, 'H',
607                              &pad_rows_before, &pad_rows_after);
608     GetExplicitPaddingForDim(params.explicit_paddings, params.data_format, 'W',
609                              &pad_cols_before, &pad_cols_after);
610   }
611 
612   // Compute windowed output sizes for rows and columns.
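  // For reference (standard convolution arithmetic): with the effective filter
  // size f_eff = (filter - 1) * dilation + 1, VALID padding yields
  // out = ceil((in - f_eff + 1) / stride) and SAME padding yields
  // out = ceil(in / stride); EXPLICIT padding uses the before/after amounts
  // supplied above.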
613   int64_t out_rows = 0, out_cols = 0;
614   TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2(
615       input_rows, filter_rows, dilation_rows, stride_rows, params.padding,
616       &out_rows, &pad_rows_before, &pad_rows_after));
617   TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2(
618       input_cols, filter_cols, dilation_cols, stride_cols, params.padding,
619       &out_cols, &pad_cols_before, &pad_cols_after));
620 
621   dimensions->batch = batch;
622   dimensions->input_rows = input_rows;
623   dimensions->input_cols = input_cols;
624   dimensions->in_depth = in_depth;
625   dimensions->filter_rows = filter_rows;
626   dimensions->filter_cols = filter_cols;
627   dimensions->patch_depth = patch_depth;
628   dimensions->out_depth = out_depth;
629   dimensions->stride_rows = stride_rows;
630   dimensions->stride_cols = stride_cols;
631   dimensions->dilation_rows = dilation_rows;
632   dimensions->dilation_cols = dilation_cols;
633   dimensions->out_rows = out_rows;
634   dimensions->out_cols = out_cols;
635   dimensions->pad_rows_before = pad_rows_before;
636   dimensions->pad_rows_after = pad_rows_after;
637   dimensions->pad_cols_before = pad_cols_before;
638   dimensions->pad_cols_after = pad_cols_after;
639 
640   return Status::OK();
641 }
642 
643 #undef TF_REQUIRES
644 
645 template <typename Device, typename T>
646 class Conv2DOp : public BinaryOp<T> {
647  public:
648   explicit Conv2DOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
649     OP_REQUIRES_OK(context, InitConv2DParameters(context, &params_));
650 
651     OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
652     cudnn_use_autotune_ = CudnnUseAutotune();
653   }
654 
655   void Compute(OpKernelContext* context) override {
656     // Input tensor is of the following dimensions:
657     // [ batch, in_rows, in_cols, in_depth ]
658     const Tensor& input = context->input(0);
659 
660     // Input filter is of the following dimensions:
661     // [ filter_rows, filter_cols, in_depth, out_depth]
662     const Tensor& filter = context->input(1);
663 
664     Conv2DDimensions dimensions;
665     OP_REQUIRES_OK(context,
666                    ComputeConv2DDimension(params_, input, filter, &dimensions));
667 
668     TensorShape out_shape = ShapeFromFormat(
669         params_.data_format, dimensions.batch, dimensions.out_rows,
670         dimensions.out_cols, dimensions.out_depth);
671 
672     // Output tensor is of the following dimensions:
673     // [ in_batch, out_rows, out_cols, out_depth ]
674     Tensor* output = nullptr;
675     OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
676 
677     VLOG(2) << "Conv2D: in_depth = " << dimensions.in_depth
678             << ", patch_depth = " << dimensions.patch_depth
679             << ", input_cols = " << dimensions.input_cols
680             << ", filter_cols = " << dimensions.filter_cols
681             << ", input_rows = " << dimensions.input_rows
682             << ", filter_rows = " << dimensions.filter_rows
683             << ", stride_rows = " << dimensions.stride_rows
684             << ", stride_cols = " << dimensions.stride_cols
685             << ", dilation_rows = " << dimensions.dilation_rows
686             << ", dilation_cols = " << dimensions.dilation_cols
687             << ", out_depth = " << dimensions.out_depth;
688 
689     // If there is nothing to compute, return.
690     if (out_shape.num_elements() == 0) {
691       return;
692     }
693 
694 #ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
695     if (params_.padding != EXPLICIT &&
696         LaunchXsmmConvOp<Device, T>::Run(
697             context, input, filter, dimensions.batch, dimensions.input_rows,
698             dimensions.input_cols, dimensions.in_depth, dimensions.filter_rows,
699             dimensions.filter_cols, dimensions.pad_rows_before,
700             dimensions.pad_cols_before, dimensions.out_rows,
701             dimensions.out_cols, dimensions.out_depth, dimensions.dilation_rows,
702             dimensions.dilation_cols, dimensions.stride_rows,
703             dimensions.stride_cols, output, params_.data_format)) {
704       return;
705     }
706 #endif
707 
708     if (params_.padding != EXPLICIT &&
709         LaunchDeepConvOp<Device, T>::Run(
710             context, input, filter, dimensions.batch, dimensions.input_rows,
711             dimensions.input_cols, dimensions.in_depth, dimensions.filter_rows,
712             dimensions.filter_cols, dimensions.pad_rows_before,
713             dimensions.pad_cols_before, dimensions.out_rows,
714             dimensions.out_cols, dimensions.out_depth, dimensions.dilation_rows,
715             dimensions.dilation_cols, dimensions.stride_rows,
716             dimensions.stride_cols, output, params_.data_format)) {
717       return;
718     }
719 
720     launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter,
721               dimensions.dilation_rows, dimensions.dilation_cols,
722               dimensions.stride_rows, dimensions.stride_cols, params_.padding,
723               params_.explicit_paddings, output, params_.data_format);
724   }
725 
726  private:
727   Conv2DParameters params_;
728   bool use_cudnn_;
729   bool cudnn_use_autotune_;
730 
731   LaunchConv2DOp<Device, T> launcher_;
732 
733   TF_DISALLOW_COPY_AND_ASSIGN(Conv2DOp);
734 };
735 
736 #define REGISTER_CPU(T)                                         \
737   REGISTER_KERNEL_BUILDER(                                      \
738       Name("Conv2D").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
739       Conv2DOp<CPUDevice, T>);
740 
741 // If we're using the alternative GEMM-based implementation of Conv2D for the
742 // CPU implementation, don't register this EigenTensor-based version.
743 #if !defined(USE_GEMM_FOR_CONV)
744 TF_CALL_half(REGISTER_CPU);
745 TF_CALL_float(REGISTER_CPU);
746 TF_CALL_double(REGISTER_CPU);
747 TF_CALL_int32(REGISTER_CPU);
748 #endif  // USE_GEMM_FOR_CONV
749 
750 // To be used inside depthwise_conv_op.cc.
751 template struct LaunchConv2DOp<CPUDevice, Eigen::half>;
752 template struct LaunchConv2DOp<CPUDevice, float>;
753 template struct LaunchConv2DOp<CPUDevice, double>;
754 
755 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
756 
757 int64 GetDnnWorkspaceLimit(const string& envvar_in_mb,
758                            int64_t default_value_in_bytes) {
759   const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
760   if (workspace_limit_in_mb_str != nullptr &&
761       strcmp(workspace_limit_in_mb_str, "") != 0) {
762     int64_t scratch_limit_in_mb = -1;
763     if (strings::safe_strto64(workspace_limit_in_mb_str,
764                               &scratch_limit_in_mb)) {
765       return scratch_limit_in_mb * (1 << 20);
766     } else {
767       LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": "
768                    << workspace_limit_in_mb_str;
769     }
770   }
771   return default_value_in_bytes;
772 }
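// Example (illustrative): with TF_CUDNN_WORKSPACE_LIMIT_IN_MB=1024 set in the
// environment, the function above returns 1024 * (1 << 20) bytes (1 GiB); when
// the variable is unset or unparsable, default_value_in_bytes is returned.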
773 
774 
775 template <typename T>
776 void LaunchConv2DOp<GPUDevice, T>::operator()(
777     OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
778     const Tensor& input_param, const Tensor& filter, int row_dilation,
779     int col_dilation, int row_stride, int col_stride, const Padding& padding,
780     const std::vector<int64>& explicit_paddings, Tensor* output,
781     TensorFormat data_format) {
782   using se::dnn::AlgorithmConfig;
783   using se::dnn::AlgorithmDesc;
784   using se::dnn::ProfileResult;
785   auto* stream = ctx->op_device_context()->stream();
786   OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));
787 
788   if (!use_cudnn) {
789     ctx->SetStatus(
790         errors::Unimplemented("Conv2D for GPU is not currently supported "
791                               "without cudnn"));
792     return;
793   }
794 
795   Tensor input = input_param;
796   const int64_t in_batch = GetTensorDim(input, data_format, 'N');
797   int64_t in_rows = GetTensorDim(input, data_format, 'H');
798   int64_t in_cols = GetTensorDim(input, data_format, 'W');
799   const int64_t in_depths = GetTensorDim(input, data_format, 'C');
800   const int64_t patch_rows = filter.dim_size(0);
801   const int64_t patch_cols = filter.dim_size(1);
802   const int64_t patch_depths = filter.dim_size(2);
803 
804   OP_REQUIRES(
805       ctx, filter.NumElements() > 0,
806       errors::InvalidArgument("filter must not have zero elements "
807                               "(i.e. all dimensions must be non-zero)"));
808 
809   // If the filter in-depth (patch_depths) is 1 and smaller than the input
810   // depth, it's a depthwise convolution. More generally, if the filter in-depth
811   // divides but is smaller than the input depth, it is a grouped convolution.
812   bool is_grouped_convolution = patch_depths != in_depths;
813   if (patch_rows == 1 && patch_cols == 1 && !is_grouped_convolution &&
814       row_dilation == 1 && col_dilation == 1 && row_stride == 1 &&
815       col_stride == 1 && data_format == FORMAT_NHWC &&
816       (padding == VALID || padding == SAME)) {
817     // 1x1 filter, so call cublas directly.
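    // Shapes of the GEMM below (row-major view): C[m, n] = A[m, k] * B[k, n],
    // with m = batch * rows * cols, k = input depth, and n = output depth.
    // The filter/input operand order compensates for the column-major
    // convention of the underlying BLAS call.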
818     const uint64 m = in_batch * in_rows * in_cols;
819     const uint64 k = patch_depths;
820     const uint64 n = filter.dim_size(3);
821 
822     auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
823                                 input.template flat<T>().size());
824     auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
825                                 filter.template flat<T>().size());
826     auto c_ptr = AsDeviceMemory(output->template flat<T>().data(),
827                                 output->template flat<T>().size());
828 
829     auto no_transpose = se::blas::Transpose::kNoTranspose;
830     OP_REQUIRES_OK(ctx, stream->ThenBlasGemm(no_transpose, no_transpose, n, m,
831                                              k, b_ptr, n, a_ptr, k, &c_ptr, n));
832     return;
833   } else if (patch_rows == in_rows && patch_cols == in_cols &&
834              !is_grouped_convolution && row_dilation == 1 &&
835              col_dilation == 1 && padding == VALID &&
836              data_format == FORMAT_NHWC) {
837     // The input data and filter have the same height/width, so call cublas
838     // directly.
839     const uint64 m = in_batch;
840     const uint64 k = patch_rows * patch_cols * patch_depths;
841     const uint64 n = filter.dim_size(3);
842 
843     auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
844                                 input.template flat<T>().size());
845     auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
846                                 filter.template flat<T>().size());
847     auto c_ptr = AsDeviceMemory(output->template flat<T>().data(),
848                                 output->template flat<T>().size());
849 
850     auto no_transpose = se::blas::Transpose::kNoTranspose;
851     OP_REQUIRES_OK(ctx, stream->ThenBlasGemm(no_transpose, no_transpose, n, m,
852                                              k, b_ptr, n, a_ptr, k, &c_ptr, n));
853     return;
854   }
855 
856 #if GOOGLE_CUDA
857   // Tensor Core (NVIDIA Volta+ GPUs) supports efficient convolution with fp16
858   // in NHWC data layout. In all other configurations it's more efficient to
859   // run computation in NCHW data format.
860   const bool compute_in_nhwc = DataTypeToEnum<T>::value == DT_HALF &&
861                                stream->GetCudaComputeCapability().IsAtLeast(
862                                    se::CudaComputeCapability::VOLTA);
863 #else
864   // fast NHWC implementation is a CUDA only feature
865   const bool compute_in_nhwc = false;
866 #endif
867 
868   // We only do a one-directional conversion: NHWC->NCHW. We never convert in the
869   // other direction. Grappler layout optimizer selects preferred layout and
870   // adds necessary annotations to the graph.
871   // TODO(ezhulenev): Convert in other direction for fp16?
872   const TensorFormat compute_data_format =
873       (compute_in_nhwc && data_format == FORMAT_NHWC) ? FORMAT_NHWC
874                                                       : FORMAT_NCHW;
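  // In practice: fp16 data that is already in NHWC on a Volta-or-newer GPU
  // stays in NHWC; every other case (other dtypes, older GPUs, ROCm, or NCHW
  // inputs) is computed in NCHW.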
875 
876   VLOG(3) << "Compute Conv2D with cuDNN:"
877           << " data_format=" << ToString(data_format)
878           << " compute_data_format=" << ToString(compute_data_format);
879 
880   const int64_t out_batch = GetTensorDim(*output, data_format, 'N');
881   const int64_t out_rows = GetTensorDim(*output, data_format, 'H');
882   const int64_t out_cols = GetTensorDim(*output, data_format, 'W');
883   const int64_t out_depths = GetTensorDim(*output, data_format, 'C');
884   int64_t padding_top = -1, padding_bottom = -1;
885   int64_t padding_left = -1, padding_right = -1;
886   if (padding == EXPLICIT) {
887     GetExplicitPaddingForDim(explicit_paddings, data_format, 'H', &padding_top,
888                              &padding_bottom);
889     GetExplicitPaddingForDim(explicit_paddings, data_format, 'W', &padding_left,
890                              &padding_right);
891   }
892   int64_t out_rows_check, out_cols_check;
893   Status status = GetWindowedOutputSizeVerboseV2(
894       in_rows, patch_rows, row_dilation, row_stride, padding, &out_rows_check,
895       &padding_top, &padding_bottom);
896   // The status is guaranteed to be OK, because we checked that the output and
897   // padding were valid earlier.
898   TF_CHECK_OK(status);
899   DCHECK_EQ(out_rows, out_rows_check);
900   status = GetWindowedOutputSizeVerboseV2(in_cols, patch_cols, col_dilation,
901                                           col_stride, padding, &out_cols_check,
902                                           &padding_left, &padding_right);
903   TF_CHECK_OK(status);
904   DCHECK_EQ(out_cols, out_cols_check);
905 
906   const int64_t common_padding_rows = std::min(padding_top, padding_bottom);
907   const int64_t common_padding_cols = std::min(padding_left, padding_right);
908   if (padding_top != padding_bottom || padding_left != padding_right) {
909     // cuDNN only supports padding the same amount on the left and right sides,
910     // and on the top and bottom sides. So we manually create a new padded
911     // input tensor such that we can pass it to cuDNN.
912     VLOG(4) << "Pad input tensor:"
913             << " padding_top=" << padding_top
914             << " padding_bottom=" << padding_bottom
915             << " padding_left=" << padding_left
916             << " padding_right=" << padding_right;
917 
918     // TODO(reedwm): In some cases, we can avoid an allocation even if the two
919     // padding sides are different. For example, if the input is 2x2, the filter
920     // is 1x1, the stride is 2, and the padding is (1, 0, 1, 0), the result is
921     // equivalent to as if the padding is (1, 1, 1, 1). Changing the padding in
922     // such a way would allow us to avoid the allocation.
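    // Worked example: with padding_top=2 and padding_bottom=1, the common
    // padding is 1, so one extra row is added to the top of the input here
    // (input_pad_top=1, input_pad_bottom=0) and cuDNN then adds the remaining
    // symmetric row of padding on both sides.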
923     Tensor transformed_input;
924     const int64_t padding_rows_diff = std::abs(padding_bottom - padding_top);
925     const int64_t padding_cols_diff = std::abs(padding_right - padding_left);
926     const int64_t new_in_rows = in_rows + padding_rows_diff;
927     const int64_t new_in_cols = in_cols + padding_cols_diff;
928     OP_REQUIRES_OK(ctx, ctx->allocate_temp(
929                             DataTypeToEnum<T>::value,
930                             ShapeFromFormat(data_format, in_batch, new_in_rows,
931                                             new_in_cols, in_depths),
932                             &transformed_input));
933 
934     const int64_t input_pad_top = padding_top - common_padding_rows;
935     const int64_t input_pad_bottom = padding_bottom - common_padding_rows;
936     const int64_t input_pad_left = padding_left - common_padding_cols;
937     const int64_t input_pad_right = padding_right - common_padding_cols;
938     bool in_bounds =
939         FastBoundsCheck(input_pad_top, std::numeric_limits<int>::max()) &&
940         FastBoundsCheck(input_pad_bottom, std::numeric_limits<int>::max()) &&
941         FastBoundsCheck(input_pad_left, std::numeric_limits<int>::max()) &&
942         FastBoundsCheck(input_pad_right, std::numeric_limits<int>::max());
943     if (!in_bounds) {
944       ctx->SetStatus(errors::InvalidArgument("Padding is too large."));
945       return;
946     }
947     functor::PadInput<GPUDevice, T, int, 4>()(
948         ctx->eigen_device<GPUDevice>(), To32Bit(input_param.tensor<T, 4>()),
949         {{static_cast<int>(input_pad_top), static_cast<int>(input_pad_left)}},
950         {{static_cast<int>(input_pad_bottom),
951           static_cast<int>(input_pad_right)}},
952         To32Bit(transformed_input.tensor<T, 4>()), data_format, T{});
953 
954     input = transformed_input;
955     in_rows = new_in_rows;
956     in_cols = new_in_cols;
957   }
958 
959   if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
960     VLOG(4) << "Convert the input tensor from NHWC to NCHW.";
961 
962     TensorShape nchw_shape =
963         ShapeFromFormat(FORMAT_NCHW, in_batch, in_rows, in_cols, in_depths);
964     if (in_depths > 1) {
965       Tensor transformed_input;
966       OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
967                                              nchw_shape, &transformed_input));
968       functor::NHWCToNCHW<GPUDevice, T, 4>()(
969           ctx->eigen_device<GPUDevice>(),
970           const_cast<const Tensor&>(input).tensor<T, 4>(),
971           transformed_input.tensor<T, 4>());
972       input = transformed_input;
973     } else {
974       // If depth <= 1, then just reshape.
975       CHECK(input.CopyFrom(input, nchw_shape));
976     }
977   } else {
978     CHECK(data_format == compute_data_format)  // Crash OK
979         << "Illegal data and compute format pair:"
980         << " data_format=" << ToString(data_format)
981         << " compute_data_format=" << ToString(compute_data_format);
982   }
983 
984   CHECK(common_padding_rows >= 0 && common_padding_cols >= 0)  // Crash OK
985       << "Negative row or col paddings: (" << common_padding_rows << ", "
986       << common_padding_cols << ")";
987 
988   constexpr auto kComputeInNHWC =
989       std::make_tuple(se::dnn::DataLayout::kBatchYXDepth,
990                       se::dnn::FilterLayout::kOutputYXInput);
991   constexpr auto kComputeInNCHW =
992       std::make_tuple(se::dnn::DataLayout::kBatchDepthYX,
993                       se::dnn::FilterLayout::kOutputInputYX);
994 
995   se::dnn::DataLayout compute_data_layout;
996   se::dnn::FilterLayout filter_layout;
997 
998   std::tie(compute_data_layout, filter_layout) =
999       compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW;
1000 
1001   se::dnn::BatchDescriptor input_desc;
1002   input_desc.set_count(in_batch)
1003       .set_feature_map_count(in_depths)
1004       .set_height(in_rows)
1005       .set_width(in_cols)
1006       .set_layout(compute_data_layout);
1007   se::dnn::BatchDescriptor output_desc;
1008   output_desc.set_count(out_batch)
1009       .set_height(out_rows)
1010       .set_width(out_cols)
1011       .set_feature_map_count(out_depths)
1012       .set_layout(compute_data_layout);
1013   se::dnn::FilterDescriptor filter_desc;
1014   filter_desc.set_input_filter_height(patch_rows)
1015       .set_input_filter_width(patch_cols)
1016       .set_input_feature_map_count(patch_depths)
1017       .set_output_feature_map_count(filter.dim_size(3))
1018       .set_layout(filter_layout);
1019   se::dnn::ConvolutionDescriptor conv_desc;
1020   conv_desc.set_vertical_dilation_rate(row_dilation)
1021       .set_horizontal_dilation_rate(col_dilation)
1022       .set_vertical_filter_stride(row_stride)
1023       .set_horizontal_filter_stride(col_stride)
1024       .set_zero_padding_height(common_padding_rows)
1025       .set_zero_padding_width(common_padding_cols)
1026       .set_group_count(in_depths / patch_depths);
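  // A group count greater than one makes cuDNN perform a grouped convolution;
  // it is 1 in the common case where the filter depth equals the input depth.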
1027 
1028   Tensor transformed_filter;
1029 
1030   const auto transform_filter = [&](FilterTensorFormat dst_format) -> Status {
1031     VLOG(4) << "Transform filter tensor from " << ToString(FORMAT_HWIO)
1032             << " to " << ToString(dst_format);
1033 
1034     TensorShape dst_shape =
1035         dst_format == FORMAT_OIHW
1036             ? TensorShape({filter.dim_size(3), filter.dim_size(2),
1037                            filter.dim_size(0), filter.dim_size(1)})
1038             : TensorShape({filter.dim_size(3), filter.dim_size(0),
1039                            filter.dim_size(1), filter.dim_size(2)});
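    // For example, a [3, 3, 64, 128] HWIO filter becomes [128, 64, 3, 3] in
    // OIHW (used for NCHW compute) or [128, 3, 3, 64] in OHWI (used for NHWC
    // compute).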
1040 
1041     TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value, dst_shape,
1042                                           &transformed_filter));
1043     functor::TransformFilter<GPUDevice, T, int, 4>()(
1044         ctx->eigen_device<GPUDevice>(), dst_format,
1045         To32Bit(filter.tensor<T, 4>()),
1046         To32Bit(transformed_filter.tensor<T, 4>()));
1047 
1048     return Status::OK();
1049   };
1050 
1051   if (compute_data_format == FORMAT_NCHW) {
1052     OP_REQUIRES_OK(ctx, transform_filter(FORMAT_OIHW));
1053   } else if (compute_data_format == FORMAT_NHWC) {
1054     OP_REQUIRES_OK(ctx, transform_filter(FORMAT_OHWI));
1055   } else {
1056     ctx->SetStatus(errors::InvalidArgument("Invalid compute data format: ",
1057                                            ToString(compute_data_format)));
1058     return;
1059   }
1060 
1061   Tensor transformed_output;
1062   if (data_format != compute_data_format) {
1063     VLOG(4) << "Allocate temporary memory for output in compute data format";
1064     OP_REQUIRES_OK(
1065         ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
1066                                 ShapeFromFormat(compute_data_format, out_batch,
1067                                                 out_rows, out_cols, out_depths),
1068                                 &transformed_output));
1069   } else {
1070     transformed_output = *output;
1071   }
1072 
1073   auto input_ptr = AsDeviceMemory(input.template flat<T>().data(),
1074                                   input.template flat<T>().size());
1075   auto filter_ptr =
1076       AsDeviceMemory(transformed_filter.template flat<T>().data(),
1077                      transformed_filter.template flat<T>().size());
1078   auto output_ptr =
1079       AsDeviceMemory(transformed_output.template flat<T>().data(),
1080                      transformed_output.template flat<T>().size());
1081 
1082   static int64_t ConvolveScratchSize = GetDnnWorkspaceLimit(
1083       // default value is in bytes despite the name of the environment variable
1084       "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB
1085   );
1086 
1087   int device_id = stream->parent()->device_ordinal();
1088   DataType dtype = input.dtype();
1089   ConvParameters conv_parameters = {in_batch,             // batch
1090                                     in_depths,            // in_depths
1091                                     {{in_rows,            // in_rows
1092                                       in_cols}},          // in_cols
1093                                     compute_data_format,  // compute_data_format
1094                                     out_depths,           // out_depths
1095                                     {{patch_rows,         // filter_rows
1096                                       patch_cols,         // filter_cols
1097                                       patch_depths}},     // filter_depths
1098                                     {{row_dilation,       // dilation_rows
1099                                       col_dilation}},     // dilation_cols
1100                                     {{row_stride,         // stride_rows
1101                                       col_stride}},       // stride_cols
1102                                     {{common_padding_rows,    // padding_rows
1103                                       common_padding_cols}},  // padding_cols
1104                                     dtype,                    // tensor datatype
1105                                     device_id,                // device_id
1106                                     conv_desc.group_count()};
1107   AlgorithmConfig algorithm_config;
1108 #if TENSORFLOW_USE_ROCM
1109   // cudnn_use_autotune is applicable only to the CUDA flow;
1110   // for ROCm/MIOpen, we need to call GetMIOpenConvolveAlgorithms explicitly
1111   // if we do not have a cached algorithm_config for this conv_parameters
1112   cudnn_use_autotune = true;
1113 #endif
1114 
  if (cudnn_use_autotune &&
      !AutotuneConv::GetInstance()->Find(conv_parameters, &algorithm_config)) {
    profiler::ScopedAnnotation annotation("cudnn_autotuning");
    std::vector<std::unique_ptr<se::dnn::ConvolveExecutionPlan>> plans;
#if GOOGLE_CUDA
    std::vector<AlgorithmDesc> algorithms;
    std::vector<AlgorithmConfig> configs;
    if (CudnnUseFrontend()) {
      OP_REQUIRES(
          ctx,
          stream->parent()->GetConvolveExecutionPlans(
              se::dnn::ConvolutionKind::FORWARD, se::dnn::ToDataType<T>::value,
              stream, input_desc, filter_desc, output_desc, conv_desc, &plans),
          errors::Unknown("Failed to get convolution algorithm. This is "
                          "probably because cuDNN failed to initialize, so try "
                          "looking to see if a warning log message was printed "
                          "above."));
      for (const auto& plan : plans) {
        configs.push_back(
            AlgorithmConfig(AlgorithmDesc{plan->getTag(), plan->get_raw_desc()},
                            plan->getWorkspaceSize()));
      }
    } else {
      OP_REQUIRES(
          ctx, stream->parent()->GetConvolveAlgorithms(&algorithms),
          errors::Unknown("Failed to get convolution algorithm. This is "
                          "probably because cuDNN failed to initialize, so try "
                          "looking to see if a warning log message was printed "
                          "above."));
      for (const auto& algorithm : algorithms) {
        configs.push_back(AlgorithmConfig(algorithm));
      }
    }

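    // Wrap the output buffer with redzones so that out-of-bounds writes by a
    // candidate algorithm can be detected during profiling.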
    se::TfAllocatorAdapter tf_allocator_adapter(ctx->device()->GetAllocator({}),
                                                stream);
    se::RedzoneAllocator rz_allocator(stream, &tf_allocator_adapter,
                                      se::GpuAsmOpts());
    se::DeviceMemory<T> output_tensor(
        WrapRedzoneBestEffort(&rz_allocator, output_ptr));

    std::vector<tensorflow::AutotuneResult> results;
    // TODO(reedwm): Warn if determinism is enabled after autotune is run
    for (const auto& profile_config : configs) {
      // TODO(zhengxq): profile each algorithm multiple times for better
      // accuracy.
      se::RedzoneAllocator rz_scratch_allocator(
          stream, &tf_allocator_adapter, se::GpuAsmOpts(),
          /*memory_limit=*/ConvolveScratchSize);
      DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
      se::ScratchAllocator* allocator_used =
          !RedzoneCheckDisabled()
              ? static_cast<se::ScratchAllocator*>(&rz_scratch_allocator)
              : static_cast<se::ScratchAllocator*>(&scratch_allocator);

      ProfileResult profile_result;
      Status cudnn_launch_status;
      if (CudnnUseFrontend()) {
        cudnn_launch_status = stream->ConvolveWithExecutionPlan(
            input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
            output_desc, &output_tensor, allocator_used, profile_config,
            &profile_result);
      } else {
        cudnn_launch_status = stream->ConvolveWithAlgorithm(
            input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
            output_desc, &output_tensor, allocator_used, profile_config,
            &profile_result);
      }

      if (cudnn_launch_status.ok() && profile_result.is_valid()) {
        results.emplace_back();
        auto& result = results.back();
        if (CudnnUseFrontend()) {
          result.mutable_cuda_conv_plan()->set_exec_plan_id(
              profile_config.algorithm()->exec_plan_id());
        } else {
          result.mutable_conv()->set_algorithm(
              profile_config.algorithm()->algo_id());
          result.mutable_conv()->set_tensor_ops_enabled(
              profile_config.algorithm()->tensor_ops_enabled());
        }

        result.set_scratch_bytes(
            !RedzoneCheckDisabled()
                ? rz_scratch_allocator.TotalAllocatedBytesExcludingRedzones()
                : scratch_allocator.TotalByteSize());
        *result.mutable_run_time() = proto_utils::ToDurationProto(
            absl::Milliseconds(profile_result.elapsed_time_in_ms()));

        CheckRedzones(rz_scratch_allocator, &result);
        CheckRedzones(rz_allocator, &result);
      } else if (CudnnUseFrontend()) {
        // When CuDNN frontend APIs are used, we need to make sure the profiling
        // results are a one-to-one mapping of the "plans". So, we insert dummy
        // results when the execution fails.
        results.emplace_back();
        auto& result = results.back();
        result.mutable_failure()->set_kind(AutotuneResult::UNKNOWN);
        result.mutable_failure()->set_msg(
            absl::StrCat("Profiling failure on CUDNN engine: ",
                         profile_config.algorithm()->exec_plan_id()));
      }
    }

#elif TENSORFLOW_USE_ROCM
    DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);

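    // If MIOpen returns a single result, its profiling data is used directly;
    // otherwise each returned candidate is re-profiled below.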
    std::vector<ProfileResult> algorithms;
    OP_REQUIRES(
        ctx,
        stream->parent()->GetMIOpenConvolveAlgorithms(
            se::dnn::ConvolutionKind::FORWARD, se::dnn::ToDataType<T>::value,
            stream, input_desc, input_ptr, filter_desc, filter_ptr, output_desc,
            output_ptr, conv_desc, &scratch_allocator, &algorithms),
        errors::Unknown(
            "Failed to get convolution algorithm. This is probably "
            "because MIOpen failed to initialize, so try looking to "
            "see if a warning log message was printed above."));
    se::DeviceMemory<T> output_tensor = output_ptr;

    std::vector<tensorflow::AutotuneResult> results;
    if (algorithms.size() == 1) {
      auto profile_result = algorithms[0];
      results.emplace_back();
      auto& result = results.back();
      result.mutable_conv()->set_algorithm(
          profile_result.algorithm().algo_id());
      result.mutable_conv()->set_tensor_ops_enabled(
          profile_result.algorithm().tensor_ops_enabled());

      result.set_scratch_bytes(profile_result.scratch_size());
      *result.mutable_run_time() = proto_utils::ToDurationProto(
          absl::Milliseconds(profile_result.elapsed_time_in_ms()));
    } else {
      for (auto miopen_algorithm : algorithms) {
        auto profile_algorithm = miopen_algorithm.algorithm();
        ProfileResult profile_result;
        auto miopen_launch_status = stream->ConvolveWithAlgorithm(
            input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
            output_desc, &output_ptr, &scratch_allocator,
            AlgorithmConfig(profile_algorithm, miopen_algorithm.scratch_size()),
            &profile_result);
        if (miopen_launch_status.ok() && profile_result.is_valid()) {
          results.emplace_back();
          auto& result = results.back();
          result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
          result.mutable_conv()->set_tensor_ops_enabled(
              profile_algorithm.tensor_ops_enabled());

          result.set_scratch_bytes(scratch_allocator.TotalByteSize());
          *result.mutable_run_time() = proto_utils::ToDurationProto(
              absl::Milliseconds(profile_result.elapsed_time_in_ms()));
        }
      }
    }
#endif
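    // Log the profiling results, pick the best algorithm, and cache it for
    // subsequent convolutions with the same parameters.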
    LogConvAutotuneResults(se::dnn::ConvolutionKind::FORWARD,
                           se::dnn::ToDataType<T>::value, input_ptr, filter_ptr,
                           output_tensor, input_desc, filter_desc, output_desc,
                           conv_desc, stream->parent(), results);

    if (CudnnUseFrontend()) {
      OP_REQUIRES_OK(
          ctx, BestCudnnConvAlgorithm(results, &plans, &algorithm_config));

    } else {
      OP_REQUIRES_OK(
          ctx, BestCudnnConvAlgorithm(results, nullptr, &algorithm_config));
    }

    AutotuneConv::GetInstance()->Insert(conv_parameters, algorithm_config);
  }

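  // Run the convolution with the configuration selected above (or the cached
  // or default one when autotuning was skipped).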
  Status cudnn_launch_status;
  DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
  if (CudnnUseFrontend()) {
    if (algorithm_config.algorithm().has_value()) {
      VLOG(4) << "Conv2D Execution Plan: "
              << algorithm_config.algorithm()->exec_plan_id();
    } else {
      VLOG(4) << "Convolution Autotune has been turned off";
    }
    cudnn_launch_status = stream->ConvolveWithExecutionPlan(
        input_desc, input_ptr, filter_desc, filter_ptr, conv_desc, output_desc,
        &output_ptr, &scratch_allocator, algorithm_config, nullptr);
  } else {
    VLOG(4) << "Convolution Algorithm: "
            << algorithm_config.algorithm()->algo_id();
    VLOG(4) << "tensor_ops_enabled: "
            << algorithm_config.algorithm()->tensor_ops_enabled();

    cudnn_launch_status = stream->ConvolveWithAlgorithm(
        input_desc, input_ptr, filter_desc, filter_ptr, conv_desc, output_desc,
        &output_ptr, &scratch_allocator, algorithm_config, nullptr);
  }

  if (!cudnn_launch_status.ok()) {
    ctx->SetStatus(cudnn_launch_status);
  }

  if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
    VLOG(4) << "Convert the output tensor back from NCHW to NHWC.";
    functor::NCHWToNHWC<GPUDevice, T, 4>()(
        ctx->eigen_device<GPUDevice>(),
        const_cast<const Tensor&>(transformed_output).tensor<T, 4>(),
        output->tensor<T, 4>());
  }
}

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                                 \
  template <>                                                               \
  void SpatialConvolution<GPUDevice, T>::operator()(                        \
      const GPUDevice& d, typename TTypes<T, 4>::Tensor output,             \
      typename TTypes<T, 4>::ConstTensor input,                             \
      typename TTypes<T, 4>::ConstTensor filter, int row_stride,            \
      int col_stride, int row_dilation, int col_dilation,                   \
      const Eigen::PaddingType& padding,                                    \
      const Eigen::NoOpOutputKernel& output_kernel);                        \
  template <>                                                               \
  void SpatialConvolution<GPUDevice, T>::operator()(                        \
      const GPUDevice& d, typename TTypes<T, 4>::Tensor output,             \
      typename TTypes<T, 4>::ConstTensor input,                             \
      typename TTypes<T, 4>::ConstTensor filter, int row_stride,            \
      int col_stride, int row_dilation, int col_dilation, int padding_top,  \
      int padding_bottom, int padding_left, int padding_right,              \
      const Eigen::NoOpOutputKernel& output_kernel);                        \
  extern template struct SpatialConvolution<GPUDevice, T>;                  \
  template <>                                                               \
  void MatMulConvFunctor<GPUDevice, T>::operator()(                         \
      const GPUDevice& d, typename TTypes<T, 2>::Tensor out,                \
      typename TTypes<T, 2>::ConstTensor in0,                               \
      typename TTypes<T, 2>::ConstTensor in1,                               \
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair, \
      const Eigen::NoOpOutputKernel& output_kernel);                        \
  extern template struct MatMulConvFunctor<GPUDevice, T>;                   \
  template <>                                                               \
  void TransformFilter<GPUDevice, T, int, 4>::operator()(                   \
      const GPUDevice& d, FilterTensorFormat dst_filter_format,             \
      typename TTypes<T, 4, int>::ConstTensor in,                           \
      typename TTypes<T, 4, int>::Tensor out);                              \
  extern template struct TransformFilter<GPUDevice, T, int, 4>;             \
  template <>                                                               \
  void PadInput<GPUDevice, T, int, 4>::operator()(                          \
      const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,       \
      const std::array<int, 2>& padding_left,                               \
      const std::array<int, 2>& padding_right,                              \
      typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format,     \
      const T& padding_value);                                              \
  extern template struct PadInput<GPUDevice, T, int, 4>

DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(double);
DECLARE_GPU_SPEC(int32);
#undef DECLARE_GPU_SPEC

}  // namespace functor

// Registration of the GPU implementations.
REGISTER_KERNEL_BUILDER(
    Name("Conv2D").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
    Conv2DOp<GPUDevice, Eigen::half>);
REGISTER_KERNEL_BUILDER(
    Name("Conv2D").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    Conv2DOp<GPUDevice, float>);
REGISTER_KERNEL_BUILDER(
    Name("Conv2D").Device(DEVICE_GPU).TypeConstraint<double>("T"),
    Conv2DOp<GPUDevice, double>);
REGISTER_KERNEL_BUILDER(
    Name("Conv2D").Device(DEVICE_GPU).TypeConstraint<int32>("T"),
    Conv2DOp<GPUDevice, int32>);

// To be used inside depthwise_conv_op.cc.
template struct LaunchConv2DOp<GPUDevice, float>;
template struct LaunchConv2DOp<GPUDevice, Eigen::half>;
template struct LaunchConv2DOp<GPUDevice, double>;

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow