1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #define USE_EIGEN_TENSOR
17 #define EIGEN_USE_THREADS
18 
19 #include "tensorflow/core/framework/kernel_shape_util.h"
20 #include "tensorflow/core/framework/numeric_op.h"
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/framework/register_types.h"
23 #include "tensorflow/core/framework/tensor.h"
24 #include "tensorflow/core/framework/tensor_shape.h"
25 #include "tensorflow/core/framework/tensor_slice.h"
26 #include "tensorflow/core/framework/tensor_util.h"
27 #include "tensorflow/core/kernels/conv_2d.h"
28 #include "tensorflow/core/kernels/conv_3d.h"
29 #include "tensorflow/core/kernels/conv_grad_ops.h"
30 #include "tensorflow/core/kernels/conv_grad_shape_utils.h"
31 #include "tensorflow/core/kernels/conv_ops_gpu.h"
32 #include "tensorflow/core/lib/core/errors.h"
33 #include "tensorflow/core/lib/gtl/inlined_vector.h"
34 #include "tensorflow/core/util/padding.h"
35 #include "tensorflow/core/util/tensor_format.h"
36 #include "tensorflow/core/util/use_cudnn.h"
37 #include "tensorflow/core/util/work_sharder.h"
38 
39 #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
40 #include "tensorflow/core/kernels/eigen_contraction_kernel.h"
41 #endif
42 
43 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
44 #include "tensorflow/core/platform/stream_executor.h"
45 using stream_executor::dnn::DimIndex;
46 #include "tensorflow/core/protobuf/autotuning.pb.h"
47 #include "tensorflow/core/util/proto/proto_utils.h"
48 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
49 #if GOOGLE_CUDA
50 #include "third_party/gpus/cudnn/cudnn.h"
51 #include "tensorflow/stream_executor/gpu/gpu_asm_opts.h"
52 #include "tensorflow/stream_executor/gpu/redzone_allocator.h"
53 #include "tensorflow/stream_executor/tf_allocator_adapter.h"
54 #endif  // GOOGLE_CUDA
55 
56 namespace {
57 
58 // TODO(ezhulenev): Split this file into conv_grad_filter_ops_3d.cc and
59 // conv_grad_input_ops_3d.cc.
60 
61 // TODO(ezhulenev): Generalize Col2im and Im2col for 2-d and 3-d kernels.
62 
63 // "Depth" is already used for the channel dimension, so for the third spatial
64 // dimension in this file we use "plane", although in NDHWC layout it's
65 // indicated with a "D".
66 
67 // Returns, in 'im_data' (assumed to be zero-initialized), the image patch in
68 // storage order (planes, height, width, depth), constructed from patches in
69 // 'col_data', which is required to be in storage order (out_planes *
70 // out_height * out_width, filter_planes, filter_height, filter_width, in_depth).
71 //
72 // Based on 2-dimensional implementation written by Yangqing Jia (jiayq).
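//
// Concretely, the number of patches along each spatial dimension matches the
// convolution output size computed below, e.g.
//   planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1,
// and each of the planes_col * height_col * width_col patches contributes
// filter_p * filter_h * filter_w * depth values that are accumulated (+=) into
// the overlapping regions of 'im_data'.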
73 template <typename T>
74 void Col2im(const T* col_data, const int depth, const int planes,
75             const int height, const int width, const int filter_p,
76             const int filter_h, const int filter_w, const int pad_pt,
77             const int pad_t, const int pad_l, const int pad_pb, const int pad_b,
78             const int pad_r, const int stride_p, const int stride_h,
79             const int stride_w, T* im_data) {
80   const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1;
81   const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
82   const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
83   int p_pad = -pad_pt;
84   for (int p = 0; p < planes_col; ++p) {
85     int h_pad = -pad_t;
86     for (int h = 0; h < height_col; ++h) {
87       int w_pad = -pad_l;
88       for (int w = 0; w < width_col; ++w) {
89         T* im_patch_data =
90             im_data + (p_pad * height * width + h_pad * width + w_pad) * depth;
91         for (int ip = p_pad; ip < p_pad + filter_p; ++ip) {
92           for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
93             for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
94               if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 &&
95                   iw < width) {
96                 for (int i = 0; i < depth; ++i) {
97                   im_patch_data[i] += col_data[i];
98                 }
99               }
100               im_patch_data += depth;
101               col_data += depth;
102             }
103             // Skip the remaining (width - filter_w) columns to reach the next row of this patch.
104             im_patch_data += depth * (width - filter_w);
105           }
106           // Skip the remaining (height - filter_h) rows to reach the next plane of this patch.
107           im_patch_data += (depth * width) * (height - filter_h);
108         }
109         w_pad += stride_w;
110       }
111       h_pad += stride_h;
112     }
113     p_pad += stride_p;
114   }
115 }
116 
117 // Returns in 'col_data', image patches in storage order (planes, height, width,
118 // depth) extracted from image at 'input_data', which is required to be in
119 // storage order (batch, planes, height, width, depth).
120 //
121 // Based on 2-dimensional implementation written by Yangqing Jia (jiayq).
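//
// This is the inverse layout transformation of Col2im above: for every output
// location, one contiguous row of filter_p * filter_h * filter_w * depth input
// values is copied into 'col_data', with out-of-bounds locations filled with
// zeros.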
122 template <typename T>
123 void Im2col(const T* input_data, const int depth, const int planes,
124             const int height, const int width, const int filter_p,
125             const int filter_h, const int filter_w, const int pad_pt,
126             const int pad_t, const int pad_l, const int pad_pb, const int pad_b,
127             const int pad_r, const int stride_p, const int stride_h,
128             const int stride_w, T* col_data) {
129   const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1;
130   const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
131   const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
132 
133   int p_pad = -pad_pt;
134   for (int p = 0; p < planes_col; ++p) {
135     int h_pad = -pad_t;
136     for (int h = 0; h < height_col; ++h) {
137       int w_pad = -pad_l;
138       for (int w = 0; w < width_col; ++w) {
139         for (int ip = p_pad; ip < p_pad + filter_p; ++ip) {
140           for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
141             for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
142               if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 &&
143                   iw < width) {
144                 memcpy(col_data,
145                        input_data +
146                            (ip * height * width + ih * width + iw) * depth,
147                        sizeof(T) * depth);
148               } else {
149                 // Out-of-bounds input locations are simply padded with zeros.
150                 memset(col_data, 0, sizeof(T) * depth);
151               }
152               col_data += depth;
153             }
154           }
155         }
156         w_pad += stride_w;
157       }
158       h_pad += stride_h;
159     }
160     p_pad += stride_p;
161   }
162 }
163 
164 }  // namespace
165 
166 namespace tensorflow {
167 
168 typedef Eigen::ThreadPoolDevice CPUDevice;
169 typedef Eigen::GpuDevice GPUDevice;
170 
171 // Backprop for input that offloads computation to
172 // Eigen::CuboidConvolutionBackwardInput.
173 template <typename Device, class T>
174 class Conv3DBackpropInputOp : public OpKernel {
175  public:
176   explicit Conv3DBackpropInputOp(OpKernelConstruction* context)
177       : OpKernel(context),
178         data_format_(FORMAT_NHWC),
179         takes_shape_(type_string().find("V2") != std::string::npos) {
180     // data_format is only available in V2.
181     if (takes_shape_) {
182       string data_format;
183       OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
184       OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
185                   errors::InvalidArgument("Invalid data format"));
186       OP_REQUIRES(
187           context, data_format_ == FORMAT_NHWC,
188           errors::InvalidArgument(
189               "Conv3DBackpropInputOpV2 only supports NDHWC on the CPU."));
190     }
191 
192     OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
193     OP_REQUIRES(context, dilation_.size() == 5,
194                 errors::InvalidArgument("Dilation rates field must "
195                                         "specify 5 dimensions"));
196     OP_REQUIRES(context,
197                 (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
198                  GetTensorDim(dilation_, data_format_, 'N') == 1),
199                 errors::InvalidArgument(
200                     "Current implementation does not yet support "
201                     "dilation rates in the batch and depth dimensions."));
202 
203     // TODO(yangzihao): Add CPU version of dilated conv 3D.
204     OP_REQUIRES(context,
205                 (GetTensorDim(dilation_, data_format_, '0') == 1 &&
206                  GetTensorDim(dilation_, data_format_, '1') == 1 &&
207                  GetTensorDim(dilation_, data_format_, '2') == 1),
208                 errors::InvalidArgument(
209                     "Current CPU implementation does not yet support "
210                     "dilation rates larger than 1."));
211 
212     OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
213     OP_REQUIRES(context, stride_.size() == 5,
214                 errors::InvalidArgument("Sliding window strides field must "
215                                         "specify 5 dimensions"));
216     OP_REQUIRES(
217         context,
218         (GetTensorDim(stride_, data_format_, 'C') == 1 &&
219          GetTensorDim(stride_, data_format_, 'N') == 1),
220         errors::InvalidArgument("Current implementation does not yet support "
221                                 "strides in the batch and depth dimensions."));
222     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
223   }
224 
225   void Compute(OpKernelContext* context) override {
226     const Tensor& filter = context->input(1);
227     const TensorShape& filter_shape = filter.shape();
228 
229     const Tensor& out_backprop = context->input(2);
230     const TensorShape& out_backprop_shape = out_backprop.shape();
231 
232     TensorShape input_shape;
233     if (takes_shape_) {
234       const Tensor& input_sizes = context->input(0);
235       // tensor::MakeShape is able to handle both DT_INT32 and DT_INT64 for
236       // input_sizes.
237       OP_REQUIRES_OK(context, tensor::MakeShape(input_sizes, &input_shape));
238     } else {
239       input_shape = context->input(0).shape();
240     }
241 
242     ConvBackpropDimensions dims;
243     OP_REQUIRES_OK(context, ConvBackpropComputeDimensions(
244                                 "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
245                                 input_shape, filter_shape, out_backprop_shape,
246                                 stride_, padding_, data_format_, &dims));
247 
248     Tensor* in_backprop;
249     OP_REQUIRES_OK(context,
250                    context->allocate_output(0, input_shape, &in_backprop));
251 
252     functor::CuboidConvolutionBackwardInput<Device, T>()(
253         context->eigen_device<Device>(),
254         in_backprop->tensor<T, 5>(),                     // input_backward
255         filter.tensor<T, 5>(),                           // filter
256         out_backprop.tensor<T, 5>(),                     // output_backward
257         static_cast<int>(dims.spatial_dims[0].stride),   // stride_planes
258         static_cast<int>(dims.spatial_dims[1].stride),   // stride_rows
259         static_cast<int>(dims.spatial_dims[2].stride));  // stride_cols
260   }
261 
262  private:
263   std::vector<int32> dilation_;
264   std::vector<int32> stride_;
265   Padding padding_;
266   TensorFormat data_format_;
267   bool takes_shape_;
268 
269   TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropInputOp);
270 };
271 
272 // Custom backprop for input that explicitly does the work sharding and calls
273 // Eigen only to multiply matrices.
274 template <typename Device, class T>
275 class Conv3DCustomBackpropInputOp : public OpKernel {
276   // Limit the maximum size of allocated temporary buffer to
277   // kMaxTempAllocationOverhead times the size of the input tensors (input,
278   // filter, out_backprop). If the size of the temporary buffer exceeds this
279   // limit, fallback on Eigen implementation.
280   static constexpr int kMaxTempAllocationOverhead = 25;
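  // In other words, the temporary col_buffer allocated in Compute() may hold at
  // most kMaxTempAllocationOverhead times as many elements as the input, filter
  // and out_backprop tensors combined; above that, this kernel falls back to
  // the Eigen-based CuboidConvolutionBackwardInput path.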
281 
282  public:
283   explicit Conv3DCustomBackpropInputOp(OpKernelConstruction* context)
284       : OpKernel(context),
285         data_format_(FORMAT_NHWC),
286         takes_shape_(type_string().find("V2") != std::string::npos) {
287     // data_format is only available in V2.
288     if (takes_shape_) {
289       string data_format;
290       OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
291       OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
292                   errors::InvalidArgument("Invalid data format"));
293       OP_REQUIRES(
294           context, data_format_ == FORMAT_NHWC,
295           errors::InvalidArgument(
296               "Conv3DBackpropInputOpV2 only supports NDHWC on the CPU."));
297     }
298 
299     OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
300     OP_REQUIRES(context, dilation_.size() == 5,
301                 errors::InvalidArgument("Dilation rates field must "
302                                         "specify 5 dimensions"));
303     OP_REQUIRES(context,
304                 (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
305                  GetTensorDim(dilation_, data_format_, 'N') == 1),
306                 errors::InvalidArgument(
307                     "Current implementation does not yet support "
308                     "dilation rates in the batch and depth dimensions."));
309 
310     // TODO(yangzihao): Add CPU version of dilated conv 3D.
311     OP_REQUIRES(context,
312                 (GetTensorDim(dilation_, data_format_, '0') == 1 &&
313                  GetTensorDim(dilation_, data_format_, '1') == 1 &&
314                  GetTensorDim(dilation_, data_format_, '2') == 1),
315                 errors::InvalidArgument(
316                     "Current CPU implementation does not yet support "
317                     "dilation rates larger than 1."));
318 
319     OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
320     OP_REQUIRES(context, stride_.size() == 5,
321                 errors::InvalidArgument("Sliding window strides field must "
322                                         "specify 5 dimensions"));
323     OP_REQUIRES(
324         context,
325         (GetTensorDim(stride_, data_format_, 'C') == 1 &&
326          GetTensorDim(stride_, data_format_, 'N') == 1),
327         errors::InvalidArgument("Current implementation does not yet support "
328                                 "strides in the batch and depth dimensions."));
329     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
330   }
331 
332   void Compute(OpKernelContext* context) override {
333     const Tensor& filter = context->input(1);
334     const TensorShape& filter_shape = filter.shape();
335 
336     const Tensor& out_backprop = context->input(2);
337     const TensorShape& out_backprop_shape = out_backprop.shape();
338 
339     TensorShape input_shape;
340     if (takes_shape_) {
341       const Tensor& input_sizes = context->input(0);
342       // tensor::MakeShape is able to handle both DT_INT32 and DT_INT64 for
343       // input_sizes.
344       OP_REQUIRES_OK(context, tensor::MakeShape(input_sizes, &input_shape));
345     } else {
346       input_shape = context->input(0).shape();
347     }
348 
349     ConvBackpropDimensions dims;
350     OP_REQUIRES_OK(context, ConvBackpropComputeDimensions(
351                                 "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
352                                 input_shape, filter_shape, out_backprop_shape,
353                                 stride_, padding_, data_format_, &dims));
354 
355     Tensor* in_backprop;
356     OP_REQUIRES_OK(context,
357                    context->allocate_output(0, input_shape, &in_backprop));
358 
359     int64 top_pad_planes, bottom_pad_planes;
360     int64 top_pad_rows, bottom_pad_rows;
361     int64 left_pad_cols, right_pad_cols;
362 
363     OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
364                                 dims.spatial_dims[0].input_size,
365                                 dims.spatial_dims[0].filter_size,
366                                 dims.spatial_dims[0].stride, padding_,
367                                 &dims.spatial_dims[0].output_size,
368                                 &top_pad_planes, &bottom_pad_planes));
369     OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
370                                 dims.spatial_dims[1].input_size,
371                                 dims.spatial_dims[1].filter_size,
372                                 dims.spatial_dims[1].stride, padding_,
373                                 &dims.spatial_dims[1].output_size,
374                                 &top_pad_rows, &bottom_pad_rows));
375     OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
376                                 dims.spatial_dims[2].input_size,
377                                 dims.spatial_dims[2].filter_size,
378                                 dims.spatial_dims[2].stride, padding_,
379                                 &dims.spatial_dims[2].output_size,
380                                 &left_pad_cols, &right_pad_cols));
381 
382     // TODO(ezhulenev): Extract work size and shard estimation to shared
383     // functions in conv_grad_ops, and update 2d convolution backprop.
384 
385     // The total dimension size of each kernel.
386     const int64 filter_total_size =
387         dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size *
388         dims.spatial_dims[2].filter_size * dims.in_depth;
389 
390     // The output image size is the spatial size of the output.
391     const int64 output_image_size = dims.spatial_dims[0].output_size *
392                                     dims.spatial_dims[1].output_size *
393                                     dims.spatial_dims[2].output_size;
394 
395     const auto cache_sizes = Eigen::internal::CacheSizes();
396     const ptrdiff_t l3_cache_size = cache_sizes.m_l3;
397 
398     // Use L3 cache size as target working set size.
399     const size_t target_working_set_size = l3_cache_size / sizeof(T);
400 
401     // Calculate size of matrices involved in MatMul: C = A x B.
402     const int64 size_A = output_image_size * dims.out_depth;
403 
404     const int64 size_B = filter_total_size * dims.out_depth;
405 
406     const int64 size_C = output_image_size * filter_total_size;
407 
408     const int64 work_unit_size = size_A + size_B + size_C;
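    // Per image: A is the out_backprop slice (output_image_size x out_depth),
    // B is the filter viewed as (filter_total_size x out_depth), and C is the
    // col_buffer patch matrix (output_image_size x filter_total_size) that is
    // later scattered back into the input gradient by Col2im.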
409 
410     auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
411 
412     // Use parallel tensor contractions if there is no batching.
413     //
414     // Compared to Conv2D code, this version is missing work size estimation. In
415     // benchmarks I didn't find a case when it's beneficial to run parallel
416     // contraction compared to sharding and matmuls.
417     const bool use_parallel_contraction = dims.batch_size == 1;
418 
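    // With batching, choose shard_size so that one shard's A, B and C matrices
    // together fit in the L3-sized working set (ceiling division).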
419     const size_t shard_size =
420         use_parallel_contraction
421             ? 1
422             : (target_working_set_size + work_unit_size - 1) / work_unit_size;
423 
424     // Total number of elements in all the tensors used by this kernel.
425     int64 total_tensor_elements = input_shape.num_elements() +
426                                   filter_shape.num_elements() +
427                                   out_backprop_shape.num_elements();
428 
429     // Shape of the temporary workspace buffer.
430     TensorShape col_buffer_shape = {static_cast<int64>(shard_size),
431                                     static_cast<int64>(output_image_size),
432                                     static_cast<int64>(filter_total_size)};
433     int64 col_buffer_elements = col_buffer_shape.num_elements();
434 
435     // If the temporary allocation overhead is too large, fallback on Eigen
436     // implementation which requires much less memory.
437     int64 col_buffer_overhead = col_buffer_elements / total_tensor_elements;
438     if (col_buffer_overhead > kMaxTempAllocationOverhead) {
439       VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropInputOp: "
440                  "col_buffer_overhead="
441               << col_buffer_overhead;
442 
443       functor::CuboidConvolutionBackwardInput<Device, T>()(
444           context->eigen_device<Device>(),
445           in_backprop->tensor<T, 5>(),                     // input_backward
446           filter.tensor<T, 5>(),                           // filter
447           out_backprop.tensor<T, 5>(),                     // output_backward
448           static_cast<int>(dims.spatial_dims[0].stride),   // stride_planes
449           static_cast<int>(dims.spatial_dims[1].stride),   // stride_rows
450           static_cast<int>(dims.spatial_dims[2].stride));  // stride_cols
451 
452       return;
453     }
454 
455     Tensor col_buffer;
456     OP_REQUIRES_OK(context,
457                    context->allocate_temp(DataTypeToEnum<T>::value,
458                                           col_buffer_shape, &col_buffer));
459 
460     // The input offset corresponding to a single input image.
461     const int64 input_offset = dims.spatial_dims[0].input_size *
462                                dims.spatial_dims[1].input_size *
463                                dims.spatial_dims[2].input_size * dims.in_depth;
464 
465     // The output offset corresponding to a single output image.
466     const int64 output_offset =
467         dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size *
468         dims.spatial_dims[2].output_size * dims.out_depth;
469 
470     const T* filter_data = filter.template flat<T>().data();
471     T* col_buffer_data = col_buffer.template flat<T>().data();
472     const T* out_backprop_data = out_backprop.template flat<T>().data();
473 
474     auto in_backprop_flat = in_backprop->template flat<T>();
475     T* input_backprop_data = in_backprop_flat.data();
476     in_backprop_flat.device(context->eigen_device<Device>()) =
477         in_backprop_flat.constant(T(0));
478 
479     if (use_parallel_contraction) {
480       typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
481                                Eigen::Unaligned>
482           TensorMap;
483       typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
484                                Eigen::Unaligned>
485           ConstTensorMap;
486 
487       // Initialize contraction dims (we need to transpose 'B' below).
488       Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
489       contract_dims[0].first = 1;
490       contract_dims[0].second = 1;
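      // With A of shape (output_image_size, out_depth) and B of shape
      // (filter_total_size, out_depth), contracting dimension 1 of both is
      // equivalent to C = A * B^T, matching the sharded path below.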
491 
492       for (int image_id = 0; image_id < dims.batch_size; ++image_id) {
493         // Compute gradient into col_buffer.
494         TensorMap C(col_buffer_data, output_image_size, filter_total_size);
495 
496         ConstTensorMap A(out_backprop_data + output_offset * image_id,
497                          output_image_size, dims.out_depth);
498         ConstTensorMap B(filter_data, filter_total_size, dims.out_depth);
499 
500         C.device(context->eigen_cpu_device()) = A.contract(B, contract_dims);
501 
502         Col2im<T>(col_buffer_data, dims.in_depth,
503                   // Input spatial dimensions.
504                   dims.spatial_dims[0].input_size,  // input planes
505                   dims.spatial_dims[1].input_size,  // input rows
506                   dims.spatial_dims[2].input_size,  // input cols
507                   // Filter spatial dimensions.
508                   dims.spatial_dims[0].filter_size,  // filter planes
509                   dims.spatial_dims[1].filter_size,  // filter rows
510                   dims.spatial_dims[2].filter_size,  // filter cols
511                   // Spatial padding.
512                   top_pad_planes, top_pad_rows, left_pad_cols,
513                   bottom_pad_planes, bottom_pad_rows, right_pad_cols,
514                   // Spatial striding.
515                   dims.spatial_dims[0].stride,  // stride planes
516                   dims.spatial_dims[1].stride,  // stride rows
517                   dims.spatial_dims[2].stride,  // stride cols
518                   input_backprop_data);
519 
520         input_backprop_data += input_offset;
521       }
522     } else {
523       typedef Eigen::Map<
524           Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
525           MatrixMap;
526       typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic,
527                                              Eigen::RowMajor>>
528           ConstMatrixMap;
529 
530       for (int image_id = 0; image_id < dims.batch_size;
531            image_id += shard_size) {
532         const int shard_limit =
533             std::min(static_cast<int>(shard_size),
534                      static_cast<int>(dims.batch_size) - image_id);
535 
536         auto shard = [&dims, &top_pad_planes, &top_pad_rows, &left_pad_cols,
537                       &bottom_pad_planes, &bottom_pad_rows, &right_pad_cols,
538                       &output_image_size, &filter_total_size,
539                       &input_backprop_data, &col_buffer_data,
540                       &out_backprop_data, &filter_data, &input_offset,
541                       &output_offset, &size_C](int64 start, int64 limit) {
542           for (int shard_id = start; shard_id < limit; ++shard_id) {
543             T* im2col_buf = col_buffer_data + shard_id * size_C;
544             T* input_data = input_backprop_data + shard_id * input_offset;
545             const T* out_data = out_backprop_data + shard_id * output_offset;
546 
547             // Compute gradient into 'im2col_buf'.
548             MatrixMap C(im2col_buf, output_image_size, filter_total_size);
549 
550             ConstMatrixMap A(out_data, output_image_size, dims.out_depth);
551             ConstMatrixMap B(filter_data, filter_total_size, dims.out_depth);
552 
553             C.noalias() = A * B.transpose();
554 
555             Col2im<T>(im2col_buf, dims.in_depth,
556                       // Input spatial dimensions.
557                       dims.spatial_dims[0].input_size,  // input planes
558                       dims.spatial_dims[1].input_size,  // input rows
559                       dims.spatial_dims[2].input_size,  // input cols
560                       // Filter spatial dimensions.
561                       dims.spatial_dims[0].filter_size,  // filter planes
562                       dims.spatial_dims[1].filter_size,  // filter rows
563                       dims.spatial_dims[2].filter_size,  // filter cols
564                       // Spatial padding.
565                       top_pad_planes, top_pad_rows, left_pad_cols,
566                       bottom_pad_planes, bottom_pad_rows, right_pad_cols,
567                       // Spatial striding.
568                       dims.spatial_dims[0].stride,  // stride planes
569                       dims.spatial_dims[1].stride,  // stride rows
570                       dims.spatial_dims[2].stride,  // stride cols
571                       input_data);
572           }
573         };
574         Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
575               work_unit_size, shard);
576 
577         input_backprop_data += input_offset * shard_limit;
578         out_backprop_data += output_offset * shard_limit;
579       }
580     }
581   }
582 
583  private:
584   std::vector<int32> dilation_;
585   std::vector<int32> stride_;
586   Padding padding_;
587   TensorFormat data_format_;
588   bool takes_shape_;
589 
590   TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropInputOp);
591 };
592 
593 // Custom backprop input kernel is 30% - 4x faster when compiled with AVX2
594 // than the default Eigen implementation (at the cost of ~2x-8x peak memory usage).
595 
596 #define REGISTER_CPU_KERNEL(T)                                                 \
597   REGISTER_KERNEL_BUILDER(                                                     \
598       Name("Conv3DBackpropInput").Device(DEVICE_CPU).TypeConstraint<T>("T"),   \
599       Conv3DCustomBackpropInputOp<CPUDevice, T>);                              \
600   REGISTER_KERNEL_BUILDER(                                                     \
601       Name("Conv3DBackpropInputV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
602       Conv3DCustomBackpropInputOp<CPUDevice, T>);                              \
603   REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput")                          \
604                               .Device(DEVICE_CPU)                              \
605                               .Label("custom")                                 \
606                               .TypeConstraint<T>("T"),                         \
607                           Conv3DCustomBackpropInputOp<CPUDevice, T>);          \
608   REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2")                        \
609                               .Device(DEVICE_CPU)                              \
610                               .Label("custom")                                 \
611                               .TypeConstraint<T>("T"),                         \
612                           Conv3DCustomBackpropInputOp<CPUDevice, T>);          \
613   REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput")                          \
614                               .Device(DEVICE_CPU)                              \
615                               .Label("eigen_tensor")                           \
616                               .TypeConstraint<T>("T"),                         \
617                           Conv3DBackpropInputOp<CPUDevice, T>);                \
618   REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2")                        \
619                               .Device(DEVICE_CPU)                              \
620                               .Label("eigen_tensor")                           \
621                               .TypeConstraint<T>("T"),                         \
622                           Conv3DBackpropInputOp<CPUDevice, T>);
623 
624 TF_CALL_half(REGISTER_CPU_KERNEL);
625 TF_CALL_float(REGISTER_CPU_KERNEL);
626 TF_CALL_double(REGISTER_CPU_KERNEL);
627 #undef REGISTER_CPU_KERNEL
628 
629 // Backprop for filter that offloads computation to
630 // Eigen::CuboidConvolutionBackwardFilter.
631 template <typename Device, class T>
632 class Conv3DBackpropFilterOp : public OpKernel {
633  public:
634   explicit Conv3DBackpropFilterOp(OpKernelConstruction* context)
635       : OpKernel(context),
636         data_format_(FORMAT_NHWC),
637         takes_shape_(type_string().find("V2") != std::string::npos) {
638     // data_format is only available in V2.
639     if (takes_shape_) {
640       string data_format;
641       OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
642       OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
643                   errors::InvalidArgument("Invalid data format"));
644       OP_REQUIRES(
645           context, data_format_ == FORMAT_NHWC,
646           errors::InvalidArgument(
647               "Conv3DBackpropFilterOpV2 only supports NDHWC on the CPU."));
648     }
649 
650     OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
651     OP_REQUIRES(context, dilation_.size() == 5,
652                 errors::InvalidArgument("Dilation rates field must "
653                                         "specify 5 dimensions"));
654     OP_REQUIRES(context,
655                 (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
656                  GetTensorDim(dilation_, data_format_, 'N') == 1),
657                 errors::InvalidArgument(
658                     "Current implementation does not yet support "
659                     "dilation rates in the batch and depth dimensions."));
660 
661     // TODO(yangzihao): Add CPU version of dilated conv 3D.
662     OP_REQUIRES(context,
663                 (GetTensorDim(dilation_, data_format_, '0') == 1 &&
664                  GetTensorDim(dilation_, data_format_, '1') == 1 &&
665                  GetTensorDim(dilation_, data_format_, '2') == 1),
666                 errors::InvalidArgument(
667                     "Current CPU implementation does not yet support "
668                     "dilation rates larger than 1."));
669 
670     OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
671     OP_REQUIRES(context, stride_.size() == 5,
672                 errors::InvalidArgument("Sliding window strides field must "
673                                         "specify 5 dimensions"));
674     OP_REQUIRES(
675         context,
676         (GetTensorDim(stride_, data_format_, 'C') == 1 &&
677          GetTensorDim(stride_, data_format_, 'N') == 1),
678         errors::InvalidArgument("Current implementation does not yet support "
679                                 "strides in the batch and depth dimensions."));
680     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
681   }
682 
683   void Compute(OpKernelContext* context) override {
684     const Tensor& input = context->input(0);
685     const TensorShape& input_shape = input.shape();
686 
687     const Tensor& out_backprop = context->input(2);
688     const TensorShape& out_backprop_shape = out_backprop.shape();
689 
690     TensorShape filter_shape;
691     if (takes_shape_) {
692       const Tensor& filter_sizes = context->input(1);
693       OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
694                                   filter_sizes.vec<int32>(), &filter_shape));
695     } else {
696       filter_shape = context->input(1).shape();
697     }
698 
699     ConvBackpropDimensions dims;
700     OP_REQUIRES_OK(context,
701                    ConvBackpropComputeDimensions(
702                        "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
703                        input_shape, filter_shape, out_backprop_shape, stride_,
704                        padding_, data_format_, &dims));
705 
706     Tensor* filter_backprop;
707     OP_REQUIRES_OK(context,
708                    context->allocate_output(0, filter_shape, &filter_backprop));
709 
710     if (input_shape.num_elements() == 0) {
711       filter_backprop->template flat<T>().setZero();
712       return;
713     }
714 
715     functor::CuboidConvolutionBackwardFilter<Device, T>()(
716         context->eigen_device<Device>(),
717         filter_backprop->tensor<T, 5>(),                 // filter_backward
718         input.tensor<T, 5>(),                            // input
719         out_backprop.tensor<T, 5>(),                     // output_backward
720         static_cast<int>(dims.spatial_dims[0].stride),   // stride_planes
721         static_cast<int>(dims.spatial_dims[1].stride),   // stride_rows
722         static_cast<int>(dims.spatial_dims[2].stride));  // stride_cols
723   }
724 
725  private:
726   std::vector<int32> dilation_;
727   std::vector<int32> stride_;
728   Padding padding_;
729   TensorFormat data_format_;
730   bool takes_shape_;
731 
732   TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropFilterOp);
733 };
734 
735 // Custom backprop for filter that explicitly does the work sharding and calls
736 // Eigen only to multiply matrices.
737 template <typename Device, class T>
738 class Conv3DCustomBackpropFilterOp : public OpKernel {
739   // Limit the maximum size of allocated temporary buffer to
740   // kMaxTempAllocationOverhead times the size of the input tensors (input,
741   // filter, out_backprop). If the size of the temporary buffer exceeds this
742   // limit, fallback on Eigen implementation.
743   static constexpr int kMaxTempAllocationOverhead = 25;
744 
745  public:
746   explicit Conv3DCustomBackpropFilterOp(OpKernelConstruction* context)
747       : OpKernel(context),
748         data_format_(FORMAT_NHWC),
749         takes_shape_(type_string().find("V2") != std::string::npos) {
750     // data_format is only available in V2.
751     if (takes_shape_) {
752       string data_format;
753       OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
754       OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
755                   errors::InvalidArgument("Invalid data format"));
756       OP_REQUIRES(
757           context, data_format_ == FORMAT_NHWC,
758           errors::InvalidArgument(
759               "Conv3DBackpropFilterOpV2 only supports NDHWC on the CPU."));
760     }
761 
762     OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
763     OP_REQUIRES(context, dilation_.size() == 5,
764                 errors::InvalidArgument("Dilation rates field must "
765                                         "specify 5 dimensions"));
766     OP_REQUIRES(context,
767                 (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
768                  GetTensorDim(dilation_, data_format_, 'N') == 1),
769                 errors::InvalidArgument(
770                     "Current implementation does not yet support "
771                     "dilation rates in the batch and depth dimensions."));
772 
773     // TODO(yangzihao): Add CPU version of dilated conv 3D.
774     OP_REQUIRES(context,
775                 (GetTensorDim(dilation_, data_format_, '0') == 1 &&
776                  GetTensorDim(dilation_, data_format_, '1') == 1 &&
777                  GetTensorDim(dilation_, data_format_, '2') == 1),
778                 errors::InvalidArgument(
779                     "Current CPU implementation does not yet support "
780                     "dilation rates larger than 1."));
781 
782     OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
783     OP_REQUIRES(context, stride_.size() == 5,
784                 errors::InvalidArgument("Sliding window strides field must "
785                                         "specify 5 dimensions"));
786     OP_REQUIRES(
787         context,
788         (GetTensorDim(stride_, data_format_, 'C') == 1 &&
789          GetTensorDim(stride_, data_format_, 'N') == 1),
790         errors::InvalidArgument("Current implementation does not yet support "
791                                 "strides in the batch and depth dimensions."));
792     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
793   }
794 
795   void Compute(OpKernelContext* context) override {
796     const Tensor& input = context->input(0);
797     const TensorShape& input_shape = input.shape();
798 
799     const Tensor& out_backprop = context->input(2);
800     const TensorShape& out_backprop_shape = out_backprop.shape();
801 
802     TensorShape filter_shape;
803     if (takes_shape_) {
804       const Tensor& filter_sizes = context->input(1);
805       OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
806                                   filter_sizes.vec<int32>(), &filter_shape));
807     } else {
808       filter_shape = context->input(1).shape();
809     }
810 
811     ConvBackpropDimensions dims;
812     OP_REQUIRES_OK(context,
813                    ConvBackpropComputeDimensions(
814                        "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
815                        input_shape, filter_shape, out_backprop_shape, stride_,
816                        padding_, data_format_, &dims));
817 
818     Tensor* filter_backprop;
819     OP_REQUIRES_OK(context,
820                    context->allocate_output(0, filter_shape, &filter_backprop));
821 
822     if (input_shape.num_elements() == 0) {
823       filter_backprop->template flat<T>().setZero();
824       return;
825     }
826 
827     int64 top_pad_planes, bottom_pad_planes;
828     int64 top_pad_rows, bottom_pad_rows;
829     int64 left_pad_cols, right_pad_cols;
830 
831     OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
832                                 dims.spatial_dims[0].input_size,
833                                 dims.spatial_dims[0].filter_size,
834                                 dims.spatial_dims[0].stride, padding_,
835                                 &dims.spatial_dims[0].output_size,
836                                 &top_pad_planes, &bottom_pad_planes));
837     OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
838                                 dims.spatial_dims[1].input_size,
839                                 dims.spatial_dims[1].filter_size,
840                                 dims.spatial_dims[1].stride, padding_,
841                                 &dims.spatial_dims[1].output_size,
842                                 &top_pad_rows, &bottom_pad_rows));
843     OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
844                                 dims.spatial_dims[2].input_size,
845                                 dims.spatial_dims[2].filter_size,
846                                 dims.spatial_dims[2].stride, padding_,
847                                 &dims.spatial_dims[2].output_size,
848                                 &left_pad_cols, &right_pad_cols));
849 
850     // TODO(ezhulenev): Extract work size and shard estimation to shared
851     // functions in conv_grad_ops, and update 2d convolution backprop.
852 
853     // The total dimension size of each kernel.
854     const int64 filter_total_size =
855         dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size *
856         dims.spatial_dims[2].filter_size * dims.in_depth;
857     // The output image size is the spatial size of the output.
858     const int64 output_image_size = dims.spatial_dims[0].output_size *
859                                     dims.spatial_dims[1].output_size *
860                                     dims.spatial_dims[2].output_size;
861 
862     // Shard 'batch' images (volumes) into 'shard_size' groups of images
863     // (volumes) to be fed into the parallel matmul. Calculate 'shard_size' by
864     // dividing the L3 cache size ('target_working_set_size') by the matmul size
865     // of an individual image ('work_unit_size').
866 
867     const auto cache_sizes = Eigen::internal::CacheSizes();
868     const ptrdiff_t l3_cache_size = cache_sizes.m_l3;
869 
870     // TODO(andydavis)
871     // *) Consider reducing 'target_working_set_size' if L3 is shared by
872     //    other concurrently running tensorflow ops.
873     const size_t target_working_set_size = l3_cache_size / sizeof(T);
874 
875     const int64 size_A = output_image_size * filter_total_size;
876 
877     const int64 size_B = output_image_size * dims.out_depth;
878 
879     const int64 size_C = filter_total_size * dims.out_depth;
880 
881     const int64 work_unit_size = size_A + size_B + size_C;
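    // Per image: A is the im2col patch matrix (output_image_size x
    // filter_total_size), B is the out_backprop slice (output_image_size x
    // out_depth), and C is the filter gradient (filter_total_size x out_depth)
    // accumulated across all shards.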
882 
883     const size_t shard_size =
884         (target_working_set_size + work_unit_size - 1) / work_unit_size;
885 
886     // Total number of elements in all the tensors used by this kernel.
887     int64 total_tensor_elements = input_shape.num_elements() +
888                                   filter_shape.num_elements() +
889                                   out_backprop_shape.num_elements();
890 
891     // Shape of the temporary workspace buffer.
892     TensorShape col_buffer_shape = {static_cast<int64>(shard_size),
893                                     static_cast<int64>(output_image_size),
894                                     static_cast<int64>(filter_total_size)};
895     int64 col_buffer_elements = col_buffer_shape.num_elements();
896 
897     // If the temporary allocation overhead is too large, fallback on Eigen
898     // implementation which requires much less memory.
899     int64 col_buffer_overhead = col_buffer_elements / total_tensor_elements;
900     if (col_buffer_overhead > kMaxTempAllocationOverhead) {
901       VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropFilterOp: "
902                  "col_buffer_overhead="
903               << col_buffer_overhead;
904 
905       functor::CuboidConvolutionBackwardFilter<Device, T>()(
906           context->eigen_device<Device>(),
907           filter_backprop->tensor<T, 5>(),                 // filter_backward
908           input.tensor<T, 5>(),                            // input
909           out_backprop.tensor<T, 5>(),                     // output_backward
910           static_cast<int>(dims.spatial_dims[0].stride),   // stride_planes
911           static_cast<int>(dims.spatial_dims[1].stride),   // stride_rows
912           static_cast<int>(dims.spatial_dims[2].stride));  // stride_cols
913 
914       return;
915     }
916 
917     Tensor col_buffer;
918     OP_REQUIRES_OK(context,
919                    context->allocate_temp(DataTypeToEnum<T>::value,
920                                           col_buffer_shape, &col_buffer));
921 
922     // The input offset corresponding to a single input image.
923     const int64 input_offset = dims.spatial_dims[0].input_size *
924                                dims.spatial_dims[1].input_size *
925                                dims.spatial_dims[2].input_size * dims.in_depth;
926     // The output offset corresponding to a single output image.
927     const int64 output_offset =
928         dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size *
929         dims.spatial_dims[2].output_size * dims.out_depth;
930 
931     const T* input_data = input.template flat<T>().data();
932     T* col_buffer_data = col_buffer.template flat<T>().data();
933     const T* out_backprop_data = out_backprop.template flat<T>().data();
934     T* filter_backprop_data = filter_backprop->template flat<T>().data();
935 
936     typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
937                              Eigen::Unaligned>
938         TensorMap;
939     typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
940                              Eigen::Unaligned>
941         ConstTensorMap;
942 
943     TensorMap C(filter_backprop_data, filter_total_size, dims.out_depth);
944     C.setZero();
945 
946     // Initialize contraction dims (we need to transpose 'A' below).
947     Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
948     contract_dims[0].first = 0;
949     contract_dims[0].second = 0;
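    // With A of shape (output_image_size * shard_limit, filter_total_size) and
    // B of shape (output_image_size * shard_limit, out_depth), contracting
    // dimension 0 of both is equivalent to C += A^T * B.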
950 
951     auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
952 
953     for (int image_id = 0; image_id < dims.batch_size; image_id += shard_size) {
954       const int shard_limit =
955           std::min(static_cast<int>(shard_size),
956                    static_cast<int>(dims.batch_size) - image_id);
957 
958       auto shard = [&input_data, &col_buffer_data, &dims, &top_pad_planes,
959                     &top_pad_rows, &left_pad_cols, &bottom_pad_planes,
960                     &bottom_pad_rows, &right_pad_cols, &input_offset,
961                     &size_A](int64 start, int64 limit) {
962         for (int shard_id = start; shard_id < limit; ++shard_id) {
963           const T* input_data_shard = input_data + shard_id * input_offset;
964           T* col_data_shard = col_buffer_data + shard_id * size_A;
965 
966           // When we compute the gradient with respect to the filters, we need
967           // to do im2col to allow gemm-type computation.
968           Im2col<T>(input_data_shard, dims.in_depth,
969                     // Input spatial dimensions.
970                     dims.spatial_dims[0].input_size,  // input planes
971                     dims.spatial_dims[1].input_size,  // input rows
972                     dims.spatial_dims[2].input_size,  // input cols
973                     // Filter spatial dimensions.
974                     dims.spatial_dims[0].filter_size,  // filter planes
975                     dims.spatial_dims[1].filter_size,  // filter rows
976                     dims.spatial_dims[2].filter_size,  // filter cols
977                     // Spatial padding.
978                     top_pad_planes, top_pad_rows, left_pad_cols,
979                     bottom_pad_planes, bottom_pad_rows, right_pad_cols,
980                     // Spatial striding.
981                     dims.spatial_dims[0].stride,  // stride planes
982                     dims.spatial_dims[1].stride,  // stride rows
983                     dims.spatial_dims[2].stride,  // stride cols
984                     col_data_shard);
985         }
986       };
987       Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
988             size_A, shard);
989 
990       ConstTensorMap A(col_buffer_data, output_image_size * shard_limit,
991                        filter_total_size);
992       ConstTensorMap B(out_backprop_data, output_image_size * shard_limit,
993                        dims.out_depth);
994 
995       // Gradient with respect to filter.
996       C.device(context->eigen_cpu_device()) += A.contract(B, contract_dims);
997 
998       input_data += input_offset * shard_limit;
999       out_backprop_data += output_offset * shard_limit;
1000     }
1001   }
1002 
1003  private:
1004   std::vector<int32> dilation_;
1005   std::vector<int32> stride_;
1006   Padding padding_;
1007   TensorFormat data_format_;
1008   bool takes_shape_;
1009 
1010   TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropFilterOp);
1011 };
1012 
1013 // Custom backprop filter kernel is 30% - 4x faster when compiled with AVX2
1014 // than the default Eigen implementation (at the cost of ~2x-8x peak memory usage).
1015 
1016 #define REGISTER_CPU_KERNEL(T)                                                \
1017   REGISTER_KERNEL_BUILDER(                                                    \
1018       Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
1019       Conv3DCustomBackpropFilterOp<CPUDevice, T>);                            \
1020   REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")                      \
1021                               .Device(DEVICE_CPU)                             \
1022                               .TypeConstraint<T>("T"),                        \
1023                           Conv3DCustomBackpropFilterOp<CPUDevice, T>);        \
1024   REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter")                        \
1025                               .Device(DEVICE_CPU)                             \
1026                               .Label("custom")                                \
1027                               .TypeConstraint<T>("T"),                        \
1028                           Conv3DCustomBackpropFilterOp<CPUDevice, T>);        \
1029   REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")                      \
1030                               .Device(DEVICE_CPU)                             \
1031                               .Label("custom")                                \
1032                               .TypeConstraint<T>("T"),                        \
1033                           Conv3DCustomBackpropFilterOp<CPUDevice, T>);        \
1034   REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter")                        \
1035                               .Device(DEVICE_CPU)                             \
1036                               .Label("eigen_tensor")                          \
1037                               .TypeConstraint<T>("T"),                        \
1038                           Conv3DBackpropFilterOp<CPUDevice, T>);              \
1039   REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")                      \
1040                               .Device(DEVICE_CPU)                             \
1041                               .Label("eigen_tensor")                          \
1042                               .TypeConstraint<T>("T"),                        \
1043                           Conv3DBackpropFilterOp<CPUDevice, T>);
1044 
1045 TF_CALL_float(REGISTER_CPU_KERNEL);
1046 TF_CALL_double(REGISTER_CPU_KERNEL);
1047 #undef REGISTER_CPU_KERNEL
1048 
1049 // WARNING: Eigen::half is not trivially copyable and can't be used in
1050 // custom backprop filter kernel because of memcpy and memset in Im2col.
1051 #define REGISTER_CPU_KERNEL(T)                                                \
1052   REGISTER_KERNEL_BUILDER(                                                    \
1053       Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
1054       Conv3DBackpropFilterOp<CPUDevice, T>);                                  \
1055   REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")                      \
1056                               .Device(DEVICE_CPU)                             \
1057                               .TypeConstraint<T>("T"),                        \
1058                           Conv3DBackpropFilterOp<CPUDevice, T>);
1059 
1060 TF_CALL_half(REGISTER_CPU_KERNEL);
1061 #undef REGISTER_CPU_KERNEL
1062 
1063 // GPU definitions of both ops.
1064 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1065 // Forward declarations of the functor specializations for GPU.
1066 // This ensures that the custom implementation is used instead of the default
1067 // Eigen one (which is used for CPU).
1068 namespace functor {
1069 #define DECLARE_GPU_SPEC(T)                                           \
1070   template <>                                                         \
1071   void TransformFilter<GPUDevice, T, int, 5>::operator()(             \
1072       const GPUDevice& d, FilterTensorFormat dst_filter_format,       \
1073       typename TTypes<T, 5, int>::ConstTensor in,                     \
1074       typename TTypes<T, 5, int>::Tensor out);                        \
1075   template <>                                                         \
1076   void ReverseTransformFilter<GPUDevice, T, 5>::operator()(           \
1077       const GPUDevice& d, FilterTensorFormat src_filter_format,       \
1078       typename TTypes<T, 5>::ConstTensor in,                          \
1079       typename TTypes<T, 5>::Tensor out);                             \
1080   template <>                                                         \
1081   void PadInput<GPUDevice, T, int, 5>::operator()(                    \
1082       const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
1083       const std::array<int, 3>& padding_left,                         \
1084       const std::array<int, 3>& padding_right,                        \
1085       typename TTypes<T, 5, int>::Tensor out, TensorFormat format,    \
1086       const T& padding_value);
1087 
1088 DECLARE_GPU_SPEC(Eigen::half);
1089 DECLARE_GPU_SPEC(float);
1090 DECLARE_GPU_SPEC(double);
1091 #undef DECLARE_GPU_SPEC
1092 }  // namespace functor
1093 
1094 // A dummy type to group backward data autotune results together.
1095 struct Conv3dBackwardDataAutoTuneGroup {
1096   static string name() { return "Conv3dBwdData"; }
1097 };
1098 typedef AutoTuneSingleton<Conv3dBackwardDataAutoTuneGroup, ConvParameters,
1099                           se::dnn::AlgorithmConfig>
1100     AutoTuneConv3dBwdData;
1101 
1102 template <typename T>
1103 class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
1104  public:
1105   explicit Conv3DBackpropInputOp(OpKernelConstruction* context)
1106       : OpKernel(context),
1107         data_format_(FORMAT_NHWC),
1108         takes_shape_(type_string().find("V2") != std::string::npos) {
1109     // data_format is only available in V2.
1110     if (takes_shape_) {
1111       string data_format;
1112       OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
1113       OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
1114                   errors::InvalidArgument("Invalid data format"));
1115     }
1116     OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
1117     OP_REQUIRES(context, dilation_.size() == 5,
1118                 errors::InvalidArgument("Dilation rates field must "
1119                                         "specify 5 dimensions"));
1120     OP_REQUIRES(context,
1121                 (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
1122                  GetTensorDim(dilation_, data_format_, 'N') == 1),
1123                 errors::InvalidArgument(
1124                     "Current implementation does not yet support "
1125                     "dilation rates in the batch and depth dimensions."));
1126     OP_REQUIRES(
1127         context,
1128         (GetTensorDim(dilation_, data_format_, '0') > 0 &&
1129          GetTensorDim(dilation_, data_format_, '1') > 0 &&
1130          GetTensorDim(dilation_, data_format_, '2') > 0),
1131         errors::InvalidArgument("Dilation rates should be larger than 0."));
1132     OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
1133     OP_REQUIRES(context, stride_.size() == 5,
1134                 errors::InvalidArgument("Sliding window strides field must "
1135                                         "specify 5 dimensions"));
1136     OP_REQUIRES(
1137         context,
1138         (GetTensorDim(stride_, data_format_, 'C') == 1 &&
1139          GetTensorDim(stride_, data_format_, 'N') == 1),
1140         errors::InvalidArgument("Current implementation does not yet support "
1141                                 "strides in the batch and depth dimensions."));
1142     OP_REQUIRES(
1143         context,
1144         (GetTensorDim(stride_, data_format_, '0') > 0 &&
1145          GetTensorDim(stride_, data_format_, '1') > 0 &&
1146          GetTensorDim(stride_, data_format_, '2') > 0),
1147         errors::InvalidArgument("Spatial strides should be larger than 0."));
1148     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
1149     cudnn_use_autotune_ = CudnnUseAutotune();
1150   }
1151   void Compute(OpKernelContext* context) override {
1152     const Tensor& filter = context->input(1);
1153     const TensorShape& filter_shape = filter.shape();
1154 
1155     const Tensor& out_backprop = context->input(2);
1156     const TensorShape& out_backprop_shape = out_backprop.shape();
1157 
1158     TensorShape input_shape;
1159     if (takes_shape_) {
1160       const Tensor& input_sizes = context->input(0);
1161       OP_REQUIRES_OK(context, tensor::MakeShape(input_sizes, &input_shape));
1162     } else {
1163       input_shape = context->input(0).shape();
1164     }
1165 
1166     ConvBackpropDimensions dims;
1167     OP_REQUIRES_OK(context, ConvBackpropComputeDimensionsV2(
1168                                 "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
1169                                 input_shape, filter_shape, out_backprop_shape,
1170                                 dilation_, stride_, padding_,
1171                                 /*explicit_paddings=*/{}, data_format_, &dims));
1172 
1173     Tensor* in_backprop;
1174     OP_REQUIRES_OK(context,
1175                    context->allocate_output(0, input_shape, &in_backprop));
1176 
1177     auto* stream = context->op_device_context()->stream();
1178     OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
1179 
1180     bool is_grouped_convolution = filter_shape.dim_size(3) != dims.in_depth;
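    // Fast path: a non-grouped 1x1x1 convolution with unit strides and unit
    // dilations in NDHWC reduces to a single matrix multiply. A sketch of the
    // shapes, derived from the m/k/n values computed below:
    //
    //   out_backprop : [m, k] = [batch * planes * rows * cols, out_depth]
    //   filter       : [n, k] = [in_depth, out_depth]  (used transposed)
    //   in_backprop  : [m, n] = [batch * planes * rows * cols, in_depth]
    //
    // i.e. in_backprop = out_backprop * filter^T, computed with a single
    // ThenBlasGemm call instead of a full convolution.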
1181     if (!is_grouped_convolution && dims.filter_size(0) == 1 &&
1182         dims.filter_size(1) == 1 && dims.filter_size(2) == 1 &&
1183         dims.dilation(0) == 1 && dims.dilation(1) == 1 &&
1184         dims.dilation(2) == 1 && dims.stride(0) == 1 && dims.stride(1) == 1 &&
1185         dims.stride(2) == 1 && data_format_ == FORMAT_NHWC) {
1186       const uint64 m = dims.batch_size * dims.input_size(0) *
1187                        dims.input_size(1) * dims.input_size(2);
1188       const uint64 k = dims.out_depth;
1189       const uint64 n = dims.in_depth;
1190 
1191       auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
1192                                   out_backprop.template flat<T>().size());
1193       auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
1194                                   filter.template flat<T>().size());
1195       auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
1196                                   in_backprop->template flat<T>().size());
1197 
1198       auto transpose = se::blas::Transpose::kTranspose;
1199       auto no_transpose = se::blas::Transpose::kNoTranspose;
1200 
1201       bool blas_launch_status =
1202           stream
1203               ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
1204                              a_ptr, k, 0.0f, &c_ptr, n)
1205               .ok();
1206       if (!blas_launch_status) {
1207         context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
1208                                             ", n=", n, ", k=", k));
1209       }
1210       return;
1211     } else if (!is_grouped_convolution &&
1212                dims.filter_size(0) == dims.input_size(0) &&
1213                dims.filter_size(1) == dims.input_size(1) &&
1214                dims.filter_size(2) == dims.input_size(2) &&
1215                padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
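      // Fast path: with VALID padding and a filter that exactly covers the
      // input, the output is spatially 1x1x1, so the gradient is again a
      // single GEMM. A sketch of the shapes, derived from the m/k/n values
      // computed below:
      //
      //   out_backprop : [m, k] = [batch, out_depth]
      //   filter       : [n, k] = [planes * rows * cols * in_depth, out_depth]
      //   in_backprop  : [m, n] = [batch, planes * rows * cols * in_depth]
      //
      // i.e. in_backprop = out_backprop * filter^T.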
1216       const uint64 m = dims.batch_size;
1217       const uint64 k = dims.out_depth;
1218       const uint64 n = dims.input_size(0) * dims.input_size(1) *
1219                        dims.input_size(2) * dims.in_depth;
1220 
1221       auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
1222                                   out_backprop.template flat<T>().size());
1223       auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
1224                                   filter.template flat<T>().size());
1225       auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
1226                                   in_backprop->template flat<T>().size());
1227 
1228       auto transpose = se::blas::Transpose::kTranspose;
1229       auto no_transpose = se::blas::Transpose::kNoTranspose;
1230 
1231       bool blas_launch_status =
1232           stream
1233               ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
1234                              a_ptr, k, 0.0f, &c_ptr, n)
1235               .ok();
1236       if (!blas_launch_status) {
1237         context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
1238                                             ", n=", n, ", k=", k));
1239       }
1240       return;
1241     }
1242 
1243     int padding_planes = dims.SpatialPadding(padding_, 0);
1244     int padding_rows = dims.SpatialPadding(padding_, 1);
1245     int padding_cols = dims.SpatialPadding(padding_, 2);
1246     const bool planes_odd = (padding_planes % 2 != 0);
1247     const bool rows_odd = (padding_rows % 2 != 0);
1248     const bool cols_odd = (padding_cols % 2 != 0);
1249 
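    // With SAME padding the total padding along a dimension can be odd, in
    // which case TensorFlow pads asymmetrically (less before than after) while
    // cuDNN only accepts symmetric padding. For example, with input size 5,
    // filter size 2 and stride 1, SAME padding needs (5 - 1) * 1 + 2 - 5 = 1
    // padded element in total. In that case the gradient is computed on an
    // input enlarged by one plane/row/column and the extra slice is cropped
    // after the cuDNN call (see the PadInput call with negative padding
    // amounts further down).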
1250     TensorShape compatible_input_shape;
1251     if (rows_odd || cols_odd || planes_odd) {
1252       // cuDNN only supports the same amount of padding on both sides.
1253       compatible_input_shape = {
1254           dims.batch_size,
1255           dims.in_depth,
1256           dims.input_size(0) + planes_odd,
1257           dims.input_size(1) + rows_odd,
1258           dims.input_size(2) + cols_odd,
1259       };
1260     } else {
1261       compatible_input_shape = {dims.batch_size, dims.in_depth,
1262                                 dims.input_size(0), dims.input_size(1),
1263                                 dims.input_size(2)};
1264     }
1265 
1266     CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
1267         << "Negative paddings: (" << padding_rows << ", " << padding_cols
1268         << ", " << padding_planes << ")";
1269 
1270 #if GOOGLE_CUDA
1271     const bool compute_in_nhwc =
1272         CUDNN_VERSION >= 8000 && DataTypeToEnum<T>::value == DT_HALF;
1273 #else
1274     // The fast NDHWC implementation is a CUDA-only feature.
1275     const bool compute_in_nhwc = false;
1276 #endif
1277     const TensorFormat compute_data_format =
1278         (compute_in_nhwc && data_format_ == FORMAT_NHWC) ? FORMAT_NHWC
1279                                                          : FORMAT_NCHW;
1280 
1281     VLOG(3) << "Compute Conv3DBackpropInput with cuDNN:"
1282             << " data_format=" << ToString(data_format_)
1283             << " compute_data_format=" << ToString(compute_data_format);
1284 
1285     constexpr auto kComputeInNHWC =
1286         std::make_tuple(se::dnn::DataLayout::kBatchYXDepth,
1287                         se::dnn::FilterLayout::kOutputYXInput);
1288     constexpr auto kComputeInNCHW =
1289         std::make_tuple(se::dnn::DataLayout::kBatchDepthYX,
1290                         se::dnn::FilterLayout::kOutputInputYX);
1291 
1292     se::dnn::DataLayout compute_data_layout;
1293     se::dnn::FilterLayout filter_layout;
1294 
1295     std::tie(compute_data_layout, filter_layout) =
1296         compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW;
1297 
1298     se::dnn::BatchDescriptor input_desc(3);
1299     input_desc.set_count(dims.batch_size)
1300         .set_spatial_dim(DimIndex::X, compatible_input_shape.dim_size(4))
1301         .set_spatial_dim(DimIndex::Y, compatible_input_shape.dim_size(3))
1302         .set_spatial_dim(DimIndex::Z, compatible_input_shape.dim_size(2))
1303         .set_feature_map_count(dims.in_depth)
1304         .set_layout(compute_data_layout);
1305     se::dnn::BatchDescriptor output_desc(3);
1306     output_desc.set_count(dims.batch_size)
1307         .set_spatial_dim(DimIndex::X, dims.output_size(2))
1308         .set_spatial_dim(DimIndex::Y, dims.output_size(1))
1309         .set_spatial_dim(DimIndex::Z, dims.output_size(0))
1310         .set_feature_map_count(dims.out_depth)
1311         .set_layout(compute_data_layout);
1312     se::dnn::FilterDescriptor filter_desc(3);
1313     filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
1314         .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
1315         .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
1316         .set_input_feature_map_count(filter_shape.dim_size(3))
1317         .set_output_feature_map_count(filter_shape.dim_size(4))
1318         .set_layout(filter_layout);
1319     se::dnn::ConvolutionDescriptor conv_desc(3);
1320     conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
1321         .set_dilation_rate(DimIndex::Y, dims.dilation(1))
1322         .set_dilation_rate(DimIndex::Z, dims.dilation(0))
1323         .set_filter_stride(DimIndex::X, dims.stride(2))
1324         .set_filter_stride(DimIndex::Y, dims.stride(1))
1325         .set_filter_stride(DimIndex::Z, dims.stride(0))
1326         .set_zero_padding(DimIndex::X, padding_cols / 2)
1327         .set_zero_padding(DimIndex::Y, padding_rows / 2)
1328         .set_zero_padding(DimIndex::Z, padding_planes / 2)
1329         .set_group_count(dims.in_depth / filter_shape.dim_size(3));
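    // Note on the descriptors above: DimIndex::X/Y/Z refer to the spatial
    // dimensions from innermost to outermost, so spatial index 2 (cols) maps
    // to X, 1 (rows) to Y and 0 (planes) to Z. The group count is
    // dims.in_depth divided by the filter's input-channel dimension, which is
    // 1 for a regular convolution and > 1 for grouped convolutions.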
1330 
1331     // Shape: out, in, z, y, x (OIHW) or out, z, y, x, in (OHWI).
1332     Tensor transformed_filter;
1333     auto dst_format =
1334         compute_data_format == FORMAT_NCHW ? FORMAT_OIHW : FORMAT_OHWI;
1335     TensorShape dst_shape =
1336         dst_format == FORMAT_OIHW
1337             ? TensorShape({filter_shape.dim_size(4), filter_shape.dim_size(3),
1338                            dims.filter_size(0), dims.filter_size(1),
1339                            dims.filter_size(2)})
1340             : TensorShape({filter_shape.dim_size(4), dims.filter_size(0),
1341                            dims.filter_size(1), dims.filter_size(2),
1342                            filter_shape.dim_size(3)});
1343     OP_REQUIRES_OK(context,
1344                    context->allocate_temp(DataTypeToEnum<T>::value, dst_shape,
1345                                           &transformed_filter));
1346 
1347     functor::TransformFilter<GPUDevice, T, int, 5>()(
1348         context->eigen_device<GPUDevice>(), dst_format,
1349         To32Bit(filter.tensor<T, 5>()),
1350         To32Bit(transformed_filter.tensor<T, 5>()));
1351 
1352     // Shape: batch, filters, z, y, x.
1353     Tensor transformed_out_backprop;
1354     if (data_format_ == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
1355       TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
1356                                 dims.output_size(0), dims.output_size(1),
1357                                 dims.output_size(2)};
1358       if (dims.out_depth > 1) {
1359         OP_REQUIRES_OK(context, context->allocate_temp(
1360                                     DataTypeToEnum<T>::value, nchw_shape,
1361                                     &transformed_out_backprop));
1362         functor::NHWCToNCHW<GPUDevice, T, 5>()(
1363             context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(),
1364             transformed_out_backprop.tensor<T, 5>());
1365       } else {
1366         CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
1367       }
1368     } else {
1369       transformed_out_backprop = out_backprop;
1370     }
1371     // Shape: batch, filters, z, y, x.
1372     Tensor pre_transformed_in_backprop;
1373     OP_REQUIRES_OK(context,
1374                    context->allocate_temp(
1375                        DataTypeToEnum<T>::value,
1376                        ShapeFromFormat(compute_data_format,
1377                                        compatible_input_shape.dim_size(0),
1378                                        {{compatible_input_shape.dim_size(2),
1379                                          compatible_input_shape.dim_size(3),
1380                                          compatible_input_shape.dim_size(4)}},
1381                                        compatible_input_shape.dim_size(1)),
1382                        &pre_transformed_in_backprop));
1383 
1384     auto out_backprop_ptr =
1385         AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
1386                        transformed_out_backprop.template flat<T>().size());
1387     auto filter_ptr =
1388         AsDeviceMemory(transformed_filter.template flat<T>().data(),
1389                        transformed_filter.template flat<T>().size());
1390     auto in_backprop_ptr =
1391         AsDeviceMemory(pre_transformed_in_backprop.template flat<T>().data(),
1392                        pre_transformed_in_backprop.template flat<T>().size());
1393 
1394     static int64 ConvolveBackwardDataScratchSize = GetDnnWorkspaceLimit(
1395         "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32);  // 4GB by default
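    // GetDnnWorkspaceLimit presumably reads the scratch-space limit from the
    // environment (in megabytes) and falls back to the byte value given here,
    // 1LL << 32 = 4 GiB. For example, a hypothetical deployment could cap the
    // cuDNN workspace at 1 GiB with:
    //
    //   export TF_CUDNN_WORKSPACE_LIMIT_IN_MB=1024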
1396 
1397     const int device_id = stream->parent()->device_ordinal();
1398     // To make sure the Conv3DBackpropInputV2 op gets the correct dtype, we
1399     // infer the dtype from the 2nd input, i.e., out_backprop.
1400     DataType dtype = context->input(2).dtype();
1401     const ConvParameters conv_parameters = {
1402         dims.batch_size,
1403         dims.in_depth,
1404         {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
1405         compute_data_format,
1406         dims.out_depth,
1407         {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
1408         {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
1409         {{dims.stride(0), dims.stride(1), dims.stride(2)}},
1410         {{padding_planes, padding_rows, padding_cols}},
1411         dtype,
1412         device_id,
1413         conv_desc.group_count()};
1414 
1415     using se::dnn::AlgorithmConfig;
1416     using se::dnn::AlgorithmDesc;
1417     using se::dnn::ProfileResult;
1418 #if TENSORFLOW_USE_ROCM
1419     // cudnn_use_autotune is applicable only to the CUDA flow.
1420     // For ROCm/MIOpen, we need to call GetMIOpenConvolveAlgorithms explicitly
1421     // if we do not have a cached algorithm_config for these conv_parameters.
1422     cudnn_use_autotune_ = true;
1423 #endif
1424     AlgorithmConfig algorithm_config;
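    // Autotuning: on a cache miss (with autotune enabled), every algorithm
    // reported by the backend is launched once on the real tensors, its run
    // time and scratch usage are recorded as AutotuneResult protos, the
    // results are logged, and BestCudnnConvAlgorithm picks the winner, which
    // is cached in AutoTuneConv3dBwdData for subsequent calls with the same
    // ConvParameters.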
1425     if (cudnn_use_autotune_ && !AutoTuneConv3dBwdData::GetInstance()->Find(
1426                                    conv_parameters, &algorithm_config)) {
1427 #if GOOGLE_CUDA
1428       se::TfAllocatorAdapter tf_allocator_adapter(
1429           context->device()->GetAllocator({}), stream);
1430       se::RedzoneAllocator rz_allocator(stream, &tf_allocator_adapter,
1431                                         se::GpuAsmOpts());
1432       se::DeviceMemory<T> in_backprop_ptr_rz(
1433           WrapRedzoneBestEffort(&rz_allocator, in_backprop_ptr));
1434       std::vector<AlgorithmDesc> algorithms;
1435       CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
1436           conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(
1437               stream->parent()),
1438           &algorithms));
1439 
1440       std::vector<tensorflow::AutotuneResult> results;
1441       for (const auto& profile_algorithm : algorithms) {
1442         // TODO(zhengxq): profile each algorithm multiple times for better
1443         // accuracy.
1444         DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
1445                                               context);
1446         se::RedzoneAllocator rz_scratch_allocator(
1447             stream, &tf_allocator_adapter, se::GpuAsmOpts(),
1448             /*memory_limit=*/ConvolveBackwardDataScratchSize);
1449         se::ScratchAllocator* allocator_used =
1450             !RedzoneCheckDisabled()
1451                 ? static_cast<se::ScratchAllocator*>(&rz_scratch_allocator)
1452                 : static_cast<se::ScratchAllocator*>(&scratch_allocator);
1453         ProfileResult profile_result;
1454         auto cudnn_launch_status = stream->ConvolveBackwardDataWithAlgorithm(
1455             filter_desc, filter_ptr, output_desc, out_backprop_ptr, conv_desc,
1456             input_desc, &in_backprop_ptr_rz, allocator_used,
1457             AlgorithmConfig(profile_algorithm), &profile_result);
1458         if (cudnn_launch_status.ok()) {
1459           if (profile_result.is_valid()) {
1460             results.emplace_back();
1461             auto& result = results.back();
1462             result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
1463             result.mutable_conv()->set_tensor_ops_enabled(
1464                 profile_algorithm.tensor_ops_enabled());
1465             result.set_scratch_bytes(
1466                 !RedzoneCheckDisabled()
1467                     ? rz_scratch_allocator
1468                           .TotalAllocatedBytesExcludingRedzones()
1469                     : scratch_allocator.TotalByteSize());
1470             *result.mutable_run_time() = proto_utils::ToDurationProto(
1471                 absl::Milliseconds(profile_result.elapsed_time_in_ms()));
1472 
1473             // TODO(george): they don't do results at all??
1474             CheckRedzones(rz_scratch_allocator, &result);
1475             CheckRedzones(rz_allocator, &result);
1476           }
1477         }
1478       }
1479 #elif TENSORFLOW_USE_ROCM
1480       DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
1481                                             context);
1482       std::vector<ProfileResult> algorithms;
1483       CHECK(stream->parent()->GetMIOpenConvolveAlgorithms(
1484           se::dnn::ConvolutionKind::BACKWARD_DATA,
1485           se::dnn::ToDataType<T>::value, stream, input_desc, in_backprop_ptr,
1486           filter_desc, filter_ptr, output_desc, out_backprop_ptr, conv_desc,
1487           &scratch_allocator, &algorithms));
1488       std::vector<tensorflow::AutotuneResult> results;
1489       for (auto miopen_algorithm : algorithms) {
1490         auto profile_algorithm = miopen_algorithm.algorithm();
1491         ProfileResult profile_result;
1492         auto miopen_launch_status = stream->ConvolveBackwardDataWithAlgorithm(
1493             filter_desc, filter_ptr, output_desc, out_backprop_ptr, conv_desc,
1494             input_desc, &in_backprop_ptr, &scratch_allocator,
1495             AlgorithmConfig(profile_algorithm, miopen_algorithm.scratch_size()),
1496             &profile_result);
1497         if (miopen_launch_status.ok()) {
1498           if (profile_result.is_valid()) {
1499             results.emplace_back();
1500             auto& result = results.back();
1501             result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
1502             result.mutable_conv()->set_tensor_ops_enabled(
1503                 profile_algorithm.tensor_ops_enabled());
1504             result.set_scratch_bytes(scratch_allocator.TotalByteSize());
1505             *result.mutable_run_time() = proto_utils::ToDurationProto(
1506                 absl::Milliseconds(profile_result.elapsed_time_in_ms()));
1507           }
1508         }
1509       }
1510 #endif
1511       LogConvAutotuneResults(se::dnn::ConvolutionKind::BACKWARD_DATA,
1512                              se::dnn::ToDataType<T>::value, in_backprop_ptr,
1513                              filter_ptr, out_backprop_ptr, input_desc,
1514                              filter_desc, output_desc, conv_desc,
1515                              stream->parent(), results);
1516       OP_REQUIRES_OK(context,
1517                      BestCudnnConvAlgorithm(results, &algorithm_config));
1518       AutoTuneConv3dBwdData::GetInstance()->Insert(conv_parameters,
1519                                                    algorithm_config);
1520     }
1521     DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
1522                                           context);
1523     auto cudnn_launch_status = stream->ConvolveBackwardDataWithAlgorithm(
1524         filter_desc, filter_ptr, output_desc, out_backprop_ptr, conv_desc,
1525         input_desc, &in_backprop_ptr, &scratch_allocator, algorithm_config,
1526         nullptr);
1527 
1528     if (!cudnn_launch_status.ok()) {
1529       context->SetStatus(cudnn_launch_status);
1530     }
1531 
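    // If the input was virtually enlarged to make the padding symmetric, the
    // gradient computed above has one extra plane/row/column; crop it by
    // "padding" with negative amounts so that in_backprop matches the
    // requested input shape.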
1532     if (rows_odd || cols_odd || planes_odd) {
1533       Tensor in_backprop_remove_padding;
1534       OP_REQUIRES_OK(
1535           context, context->allocate_temp(
1536                        DataTypeToEnum<T>::value,
1537                        ShapeFromFormat(compute_data_format, dims.batch_size,
1538                                        {{dims.input_size(0), dims.input_size(1),
1539                                          dims.input_size(2)}},
1540                                        dims.in_depth),
1541                        &in_backprop_remove_padding));
1542 
1543       // Remove the padding for odd spatial dimensions.
1544       functor::PadInput<GPUDevice, T, int, 5>()(
1545           context->eigen_device<GPUDevice>(),
1546           To32Bit(const_cast<const Tensor&>(pre_transformed_in_backprop)
1547                       .tensor<T, 5>()),
1548           {{0, 0, 0}}, {{-planes_odd, -rows_odd, -cols_odd}},
1549           To32Bit(in_backprop_remove_padding.tensor<T, 5>()),
1550           compute_data_format, T{});
1551 
1552       pre_transformed_in_backprop = in_backprop_remove_padding;
1553     }
1554 
1555     if (data_format_ == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
1556       auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
1557       functor::NCHWToNHWC<GPUDevice, T, 5>()(
1558           context->eigen_device<GPUDevice>(),
1559           toConstTensor(pre_transformed_in_backprop).template tensor<T, 5>(),
1560           in_backprop->tensor<T, 5>());
1561     } else {
1562       *in_backprop = pre_transformed_in_backprop;
1563     }
1564   }
1565 
1566  private:
1567   std::vector<int32> dilation_;
1568   std::vector<int32> stride_;
1569   Padding padding_;
1570   TensorFormat data_format_;
1571   bool takes_shape_;
1572   bool cudnn_use_autotune_;
1573 };
1574 
1575 // A dummy type to group backward filter autotune results together.
1576 struct Conv3dBackwardFilterAutoTuneGroup {
1577   static string name() { return "Conv3dBwdFilter"; }
1578 };
1579 typedef AutoTuneSingleton<Conv3dBackwardFilterAutoTuneGroup, ConvParameters,
1580                           se::dnn::AlgorithmConfig>
1581     AutoTuneConv3dBwdFilter;
1582 
1583 template <typename T>
1584 class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
1585  public:
1586   explicit Conv3DBackpropFilterOp(OpKernelConstruction* context)
1587       : OpKernel(context),
1588         data_format_(FORMAT_NHWC),
1589         takes_shape_(type_string().find("V2") != std::string::npos) {
1590     // data_format is only available in V2.
1591     if (takes_shape_) {
1592       string data_format;
1593       OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
1594       OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
1595                   errors::InvalidArgument("Invalid data format"));
1596     }
1597     OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
1598     OP_REQUIRES(context, dilation_.size() == 5,
1599                 errors::InvalidArgument("Dilation rates field must "
1600                                         "specify 5 dimensions"));
1601     OP_REQUIRES(context,
1602                 (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
1603                  GetTensorDim(dilation_, data_format_, 'N') == 1),
1604                 errors::InvalidArgument(
1605                     "Current implementation does not yet support "
1606                     "dilation rates in the batch and depth dimensions."));
1607     OP_REQUIRES(
1608         context,
1609         (GetTensorDim(dilation_, data_format_, '0') > 0 &&
1610          GetTensorDim(dilation_, data_format_, '1') > 0 &&
1611          GetTensorDim(dilation_, data_format_, '2') > 0),
1612         errors::InvalidArgument("Dilation rates should be larger than 0."));
1613     OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
1614     OP_REQUIRES(context, stride_.size() == 5,
1615                 errors::InvalidArgument("Sliding window strides field must "
1616                                         "specify 5 dimensions"));
1617     OP_REQUIRES(
1618         context,
1619         (GetTensorDim(stride_, data_format_, 'C') == 1 &&
1620          GetTensorDim(stride_, data_format_, 'N') == 1),
1621         errors::InvalidArgument("Current implementation does not yet support "
1622                                 "strides in the batch and depth dimensions."));
1623     OP_REQUIRES(
1624         context,
1625         (GetTensorDim(stride_, data_format_, '0') > 0 &&
1626          GetTensorDim(stride_, data_format_, '1') > 0 &&
1627          GetTensorDim(stride_, data_format_, '2') > 0),
1628         errors::InvalidArgument("Spatial strides should be larger than 0."));
1629     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
1630     cudnn_use_autotune_ = CudnnUseAutotune();
1631   }
1632 
1633   void Compute(OpKernelContext* context) override {
1634     const Tensor& input = context->input(0);
1635     const TensorShape& input_shape = input.shape();
1636 
1637     const Tensor& out_backprop = context->input(2);
1638     const TensorShape& out_backprop_shape = out_backprop.shape();
1639 
1640     TensorShape filter_shape;
1641     if (takes_shape_) {
1642       const Tensor& filter_sizes = context->input(1);
1643       OP_REQUIRES_OK(context, tensor::MakeShape(filter_sizes, &filter_shape));
1644     } else {
1645       filter_shape = context->input(1).shape();
1646     }
1647 
1648     ConvBackpropDimensions dims;
1649     OP_REQUIRES_OK(
1650         context,
1651         ConvBackpropComputeDimensionsV2(
1652             "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3, input_shape,
1653             filter_shape, out_backprop_shape, dilation_, stride_, padding_,
1654             /*explicit_paddings=*/{}, data_format_, &dims));
1655 
1656     Tensor* filter_backprop;
1657     OP_REQUIRES_OK(context,
1658                    context->allocate_output(0, filter_shape, &filter_backprop));
1659 
1660     auto* stream = context->op_device_context()->stream();
1661     OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
1662 
1663     bool is_grouped_convolution = filter_shape.dim_size(3) != dims.in_depth;
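    // Fast path: for a non-grouped 1x1x1 convolution with unit strides and
    // unit dilations in NDHWC, the filter gradient is a single GEMM. A sketch
    // of the shapes, derived from the m/k/n values computed below:
    //
    //   input           : [k, m] = [batch * planes * rows * cols, in_depth]
    //   out_backprop    : [k, n] = [batch * planes * rows * cols, out_depth]
    //   filter_backprop : [m, n] = [in_depth, out_depth]
    //
    // i.e. filter_backprop = input^T * out_backprop, accumulated over the
    // batch and all spatial positions.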
1664     if (!is_grouped_convolution && dims.filter_size(1) == 1 &&
1665         dims.filter_size(2) == 1 && dims.filter_size(0) == 1 &&
1666         dims.dilation(2) == 1 && dims.dilation(1) == 1 &&
1667         dims.dilation(0) == 1 && dims.stride(2) == 1 && dims.stride(1) == 1 &&
1668         dims.stride(0) == 1 && data_format_ == FORMAT_NHWC) {
1669       const uint64 m = dims.in_depth;
1670       const uint64 k = dims.batch_size * dims.input_size(1) *
1671                        dims.input_size(2) * dims.input_size(0);
1672       const uint64 n = dims.out_depth;
1673 
1674       // The shape of output backprop is
1675       //   [batch, out_z, out_y, out_x, out_depth]
1676       // From cublas's perspective, it is: n x k
1677       auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
1678                                   out_backprop.template flat<T>().size());
1679 
1680       // The shape of input is:
1681       //   [batch, in_z, in_y, in_x, in_depth],
1682       // From cublas's perspective, it is: m x k
1683       auto b_ptr = AsDeviceMemory(input.template flat<T>().data(),
1684                                   input.template flat<T>().size());
1685 
1686       // The shape of the filter backprop is:
1687       //   [1, 1, 1, in_depth, out_depth]
1688       // From cublas's perspective, it is: n x m
1689       auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
1690                                   filter_backprop->template flat<T>().size());
1691 
1692       bool blas_launch_status =
1693           stream
1694               ->ThenBlasGemm(se::blas::Transpose::kNoTranspose,
1695                              se::blas::Transpose::kTranspose, n, m, k, 1.0f,
1696                              a_ptr, n, b_ptr, m, 0.0f, &c_ptr, n)
1697               .ok();
1698       if (!blas_launch_status) {
1699         context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
1700                                             ", n=", n, ", k=", k));
1701       }
1702       return;
1703     } else if (!is_grouped_convolution &&
1704                dims.filter_size(0) == dims.input_size(0) &&
1705                dims.filter_size(1) == dims.input_size(1) &&
1706                dims.filter_size(2) == dims.input_size(2) &&
1707                padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
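      // Fast path: with VALID padding and a filter covering the whole input,
      // the output is spatially 1x1x1 and the filter gradient is again one
      // GEMM. A sketch of the shapes, derived from the m/k/n values below:
      //
      //   input           : [k, m] = [batch, planes * rows * cols * in_depth]
      //   out_backprop    : [k, n] = [batch, out_depth]
      //   filter_backprop : [m, n] = [planes * rows * cols * in_depth,
      //                               out_depth]
      //
      // i.e. filter_backprop = input^T * out_backprop, summed over the batch.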
1708       const uint64 m = dims.input_size(0) * dims.input_size(1) *
1709                        dims.input_size(2) * dims.in_depth;
1710       const uint64 k = dims.batch_size;
1711       const uint64 n = dims.out_depth;
1712 
1713       auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
1714                                   input.template flat<T>().size());
1715       auto b_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
1716                                   out_backprop.template flat<T>().size());
1717       auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
1718                                   filter_backprop->template flat<T>().size());
1719 
1720       bool blas_launch_status =
1721           stream
1722               ->ThenBlasGemm(se::blas::Transpose::kNoTranspose,
1723                              se::blas::Transpose::kTranspose, n, m, k, 1.0f,
1724                              b_ptr, n, a_ptr, m, 0.0f, &c_ptr, n)
1725               .ok();
1726       if (!blas_launch_status) {
1727         context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
1728                                             ", n=", n, ", k=", k));
1729       }
1730       return;
1731     }
1732 
1733     int padding_planes = dims.SpatialPadding(padding_, 0);
1734     int padding_rows = dims.SpatialPadding(padding_, 1);
1735     int padding_cols = dims.SpatialPadding(padding_, 2);
1736     const bool planes_odd = (padding_planes % 2 != 0);
1737     const bool rows_odd = (padding_rows % 2 != 0);
1738     const bool cols_odd = (padding_cols % 2 != 0);
1739 
1740     Tensor compatible_input;
1741     if (rows_odd || cols_odd || planes_odd) {
1742       OP_REQUIRES_OK(context,
1743                      context->allocate_temp(
1744                          DataTypeToEnum<T>::value,
1745                          ShapeFromFormat(data_format_, dims.batch_size,
1746                                          {{dims.input_size(0) + planes_odd,
1747                                            dims.input_size(1) + rows_odd,
1748                                            dims.input_size(2) + cols_odd}},
1749                                          dims.in_depth),
1750                          &compatible_input));
1751       functor::PadInput<GPUDevice, T, int, 5>()(
1752           context->template eigen_device<GPUDevice>(),
1753           To32Bit(input.tensor<T, 5>()), {{0, 0, 0}},
1754           {{planes_odd, rows_odd, cols_odd}},
1755           To32Bit(compatible_input.tensor<T, 5>()), data_format_, T{});
1756     } else {
1757       compatible_input = input;
1758     }
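    // Unlike the backward-data path, which crops its result afterwards, the
    // filter gradient is computed directly on an input padded by one extra
    // plane/row/column whenever the total SAME padding is odd.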
1759 
1760     CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
1761         << "Negative paddings: (" << padding_rows << ", " << padding_cols
1762         << ", " << padding_planes << ")";
1763 
1764 #if GOOGLE_CUDA
1765     const bool compute_in_nhwc =
1766         CUDNN_VERSION >= 8000 && DataTypeToEnum<T>::value == DT_HALF;
1767 #else
1768     // The fast NDHWC implementation is a CUDA-only feature.
1769     const bool compute_in_nhwc = false;
1770 #endif
1771     const TensorFormat compute_data_format =
1772         (compute_in_nhwc && data_format_ == FORMAT_NHWC) ? FORMAT_NHWC
1773                                                          : FORMAT_NCHW;
1774 
1775     VLOG(3) << "Compute Conv3DBackpropFilter with cuDNN:"
1776             << " data_format=" << ToString(data_format_)
1777             << " compute_data_format=" << ToString(compute_data_format);
1778 
1779     constexpr auto kComputeInNHWC =
1780         std::make_tuple(se::dnn::DataLayout::kBatchYXDepth,
1781                         se::dnn::FilterLayout::kOutputYXInput);
1782     constexpr auto kComputeInNCHW =
1783         std::make_tuple(se::dnn::DataLayout::kBatchDepthYX,
1784                         se::dnn::FilterLayout::kOutputInputYX);
1785 
1786     se::dnn::DataLayout compute_data_layout;
1787     se::dnn::FilterLayout filter_layout;
1788 
1789     std::tie(compute_data_layout, filter_layout) =
1790         compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW;
1791 
1792     se::dnn::BatchDescriptor input_desc(3);
1793     input_desc.set_count(dims.batch_size)
1794         .set_spatial_dim(DimIndex::X,
1795                          GetTensorDim(compatible_input, data_format_, '2'))
1796         .set_spatial_dim(DimIndex::Y,
1797                          GetTensorDim(compatible_input, data_format_, '1'))
1798         .set_spatial_dim(DimIndex::Z,
1799                          GetTensorDim(compatible_input, data_format_, '0'))
1800         .set_feature_map_count(dims.in_depth)
1801         .set_layout(compute_data_layout);
1802     se::dnn::BatchDescriptor output_desc(3);
1803     output_desc.set_count(dims.batch_size)
1804         .set_spatial_dim(DimIndex::X, dims.output_size(2))
1805         .set_spatial_dim(DimIndex::Y, dims.output_size(1))
1806         .set_spatial_dim(DimIndex::Z, dims.output_size(0))
1807         .set_feature_map_count(dims.out_depth)
1808         .set_layout(compute_data_layout);
1809     se::dnn::FilterDescriptor filter_desc(3);
1810     filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
1811         .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
1812         .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
1813         .set_input_feature_map_count(filter_shape.dim_size(3))
1814         .set_output_feature_map_count(filter_shape.dim_size(4))
1815         .set_layout(filter_layout);
1816     se::dnn::ConvolutionDescriptor conv_desc(3);
1817     conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
1818         .set_dilation_rate(DimIndex::Y, dims.dilation(1))
1819         .set_dilation_rate(DimIndex::Z, dims.dilation(0))
1820         .set_filter_stride(DimIndex::X, dims.stride(2))
1821         .set_filter_stride(DimIndex::Y, dims.stride(1))
1822         .set_filter_stride(DimIndex::Z, dims.stride(0))
1823         .set_zero_padding(DimIndex::X, padding_cols / 2)
1824         .set_zero_padding(DimIndex::Y, padding_rows / 2)
1825         .set_zero_padding(DimIndex::Z, padding_planes / 2)
1826         .set_group_count(dims.in_depth / filter_shape.dim_size(3));
1827 
1828     Tensor pre_transformed_filter_backprop;
1829     auto dst_format =
1830         compute_data_format == FORMAT_NCHW ? FORMAT_OIHW : FORMAT_OHWI;
1831     TensorShape dst_shape =
1832         dst_format == FORMAT_OIHW
1833             ? TensorShape({filter_shape.dim_size(4), filter_shape.dim_size(3),
1834                            dims.filter_size(0), dims.filter_size(1),
1835                            dims.filter_size(2)})
1836             : TensorShape({filter_shape.dim_size(4), dims.filter_size(0),
1837                            dims.filter_size(1), dims.filter_size(2),
1838                            filter_shape.dim_size(3)});
1839     OP_REQUIRES_OK(context,
1840                    context->allocate_temp(DataTypeToEnum<T>::value, dst_shape,
1841                                           &pre_transformed_filter_backprop));
1842 
1843     Tensor transformed_out_backprop;
1844     if (data_format_ == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
1845       VLOG(4) << "Convert the `out_backprop` tensor from NDHWC to NCDHW.";
1846       TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
1847                                 dims.output_size(0), dims.output_size(1),
1848                                 dims.output_size(2)};
1849       OP_REQUIRES_OK(
1850           context, context->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
1851                                           &transformed_out_backprop));
1852       if (dims.out_depth > 1) {
1853         functor::NHWCToNCHW<GPUDevice, T, 5>()(
1854             context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(),
1855             transformed_out_backprop.tensor<T, 5>());
1856       } else {
1857         CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
1858       }
1859     } else {
1860       transformed_out_backprop = out_backprop;
1861     }
1862     Tensor transformed_input;
1863     if (data_format_ == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
1864       VLOG(4) << "Convert the `input` tensor from NDHWC to NCDHW.";
1865       TensorShape nchw_shape = {
1866           dims.batch_size, dims.in_depth, compatible_input.dim_size(1),
1867           compatible_input.dim_size(2), compatible_input.dim_size(3)};
1868       if (dims.in_depth > 1) {
1869         OP_REQUIRES_OK(context,
1870                        context->allocate_temp(DataTypeToEnum<T>::value,
1871                                               nchw_shape, &transformed_input));
1872         functor::NHWCToNCHW<GPUDevice, T, 5>()(
1873             context->eigen_device<GPUDevice>(),
1874             const_cast<const Tensor&>(compatible_input).tensor<T, 5>(),
1875             transformed_input.tensor<T, 5>());
1876       } else {
1877         CHECK(transformed_input.CopyFrom(compatible_input, nchw_shape));
1878       }
1879     } else {
1880       transformed_input = compatible_input;
1881     }
1882 
1883     auto out_backprop_ptr =
1884         AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
1885                        transformed_out_backprop.template flat<T>().size());
1886     auto filter_backprop_ptr = AsDeviceMemory(
1887         pre_transformed_filter_backprop.template flat<T>().data(),
1888         pre_transformed_filter_backprop.template flat<T>().size());
1889     auto input_ptr =
1890         AsDeviceMemory(transformed_input.template flat<T>().data(),
1891                        transformed_input.template flat<T>().size());
1892 
1893     static int64 ConvolveBackwardFilterScratchSize = GetDnnWorkspaceLimit(
1894         "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32);  // 4GB by default
1895 
1896     const int device_id = stream->parent()->device_ordinal();
1897     DataType dtype = input.dtype();
1898     const ConvParameters conv_parameters = {
1899         dims.batch_size,
1900         dims.in_depth,
1901         {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
1902         compute_data_format,
1903         dims.out_depth,
1904         {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
1905         {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
1906         {{dims.stride(0), dims.stride(1), dims.stride(2)}},
1907         {{padding_planes, padding_rows, padding_cols}},
1908         dtype,
1909         device_id,
1910         conv_desc.group_count()};
1911 
1912     using se::dnn::AlgorithmConfig;
1913     using se::dnn::AlgorithmDesc;
1914     using se::dnn::ProfileResult;
1915 #if TENSORFLOW_USE_ROCM
1916     // cudnn_use_autotune is applicable only to the CUDA flow.
1917     // For ROCm/MIOpen, we need to call GetMIOpenConvolveAlgorithms explicitly
1918     // if we do not have a cached algorithm_config for these conv_parameters.
1919     cudnn_use_autotune_ = true;
1920 #endif
1921     AlgorithmConfig algorithm_config;
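    // The autotuning block below mirrors the backward-data path: on a cache
    // miss each candidate algorithm is profiled once, the results are logged,
    // and the best algorithm is cached in AutoTuneConv3dBwdFilter. The only
    // extra twist is the improved error message for the cuDNN 8.0.3/8.0.4
    // regression handled further down.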
1922     if (cudnn_use_autotune_ && !AutoTuneConv3dBwdFilter::GetInstance()->Find(
1923                                    conv_parameters, &algorithm_config)) {
1924 #if GOOGLE_CUDA
1925       std::vector<AlgorithmDesc> algorithms;
1926       CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(
1927           conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(
1928               stream->parent()),
1929           &algorithms));
1930 
1931       std::vector<tensorflow::AutotuneResult> results;
1932       for (const auto& profile_algorithm : algorithms) {
1933         // TODO(zhengxq): profile each algorithm multiple times for better
1934         // accuracy.
1935         DnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
1936                                               context);
1937         ProfileResult profile_result;
1938         auto cudnn_launch_status = stream->ConvolveBackwardFilterWithAlgorithm(
1939             input_desc, input_ptr, output_desc, out_backprop_ptr, conv_desc,
1940             filter_desc, &filter_backprop_ptr, &scratch_allocator,
1941             AlgorithmConfig(profile_algorithm), &profile_result);
1942         if (cudnn_launch_status.ok()) {
1943           if (profile_result.is_valid()) {
1944             results.emplace_back();
1945             auto& result = results.back();
1946             result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
1947             result.mutable_conv()->set_tensor_ops_enabled(
1948                 profile_algorithm.tensor_ops_enabled());
1949             result.set_scratch_bytes(scratch_allocator.TotalByteSize());
1950             *result.mutable_run_time() = proto_utils::ToDurationProto(
1951                 absl::Milliseconds(profile_result.elapsed_time_in_ms()));
1952           }
1953         }
1954       }
1955 #elif TENSORFLOW_USE_ROCM
1956       DnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
1957                                             context);
1958       std::vector<ProfileResult> algorithms;
1959       CHECK(stream->parent()->GetMIOpenConvolveAlgorithms(
1960           se::dnn::ConvolutionKind::BACKWARD_FILTER,
1961           se::dnn::ToDataType<T>::value, stream, input_desc, input_ptr,
1962           filter_desc, filter_backprop_ptr, output_desc, out_backprop_ptr,
1963           conv_desc, &scratch_allocator, &algorithms));
1964 
1965       std::vector<tensorflow::AutotuneResult> results;
1966       for (auto miopen_algorithm : algorithms) {
1967         auto profile_algorithm = miopen_algorithm.algorithm();
1968         ProfileResult profile_result;
1969         auto cudnn_launch_status = stream->ConvolveBackwardFilterWithAlgorithm(
1970             input_desc, input_ptr, output_desc, out_backprop_ptr, conv_desc,
1971             filter_desc, &filter_backprop_ptr, &scratch_allocator,
1972             AlgorithmConfig(profile_algorithm, miopen_algorithm.scratch_size()),
1973             &profile_result);
1974         if (cudnn_launch_status.ok()) {
1975           if (profile_result.is_valid()) {
1976             results.emplace_back();
1977             auto& result = results.back();
1978             result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
1979             result.mutable_conv()->set_tensor_ops_enabled(
1980                 profile_algorithm.tensor_ops_enabled());
1981             result.set_scratch_bytes(scratch_allocator.TotalByteSize());
1982             *result.mutable_run_time() = proto_utils::ToDurationProto(
1983                 absl::Milliseconds(profile_result.elapsed_time_in_ms()));
1984           }
1985         }
1986       }
1987 #endif
1988       LogConvAutotuneResults(se::dnn::ConvolutionKind::BACKWARD_FILTER,
1989                              se::dnn::ToDataType<T>::value, input_ptr,
1990                              filter_backprop_ptr, out_backprop_ptr, input_desc,
1991                              filter_desc, output_desc, conv_desc,
1992                              stream->parent(), results);
1993       Status s = BestCudnnConvAlgorithm(results, &algorithm_config);
1994 #if GOOGLE_CUDA
1995       if (s.code() == error::NOT_FOUND) {
1996         size_t version = cudnnGetVersion();
1997         // For cuDNN 8.0.3 and 8.0.4, no cudnnConvolutionBwdFilterAlgo_t will
1998         // work in certain cases. In such cases we improve the error message.
1999         // This is fixed in cuDNN 8.0.5. For more context, see:
2000         // https://github.com/tensorflow/tensorflow/issues/46589
2001         if (version == 8003 || version == 8004) {
2002           std::string version_str = (version == 8003 ? "8.0.3" : "8.0.4");
2003           s = errors::NotFound(
2004               "No algorithm worked! Please try upgrading to cuDNN 8.0.5. You "
2005               "are using cuDNN ",
2006               version_str, ", which has a bug causing this error.");
2007         }
2008       }
2009 #endif
2010       OP_REQUIRES_OK(context, s);
2011       AutoTuneConv3dBwdFilter::GetInstance()->Insert(conv_parameters,
2012                                                      algorithm_config);
2013     }
2014     DnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
2015                                           context);
2016     auto cudnn_launch_status = stream->ConvolveBackwardFilterWithAlgorithm(
2017         input_desc, input_ptr, output_desc, out_backprop_ptr, conv_desc,
2018         filter_desc, &filter_backprop_ptr, &scratch_allocator, algorithm_config,
2019         nullptr);
2020 
2021     if (!cudnn_launch_status.ok()) {
2022       context->SetStatus(cudnn_launch_status);
2023     }
2024 
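    // The gradient was accumulated in the compute filter layout (dst_format);
    // ReverseTransformFilter converts it back to TensorFlow's 3-D filter
    // layout [z, y, x, in_channels, out_channels] expected for the output.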
2025     auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
2026     functor::ReverseTransformFilter<GPUDevice, T, 5>()(
2027         context->eigen_device<GPUDevice>(), /*src_filter_format=*/dst_format,
2028         toConstTensor(pre_transformed_filter_backprop).template tensor<T, 5>(),
2029         filter_backprop->tensor<T, 5>());
2030   }
2031 
2032  private:
2033   std::vector<int32> dilation_;
2034   std::vector<int32> stride_;
2035   Padding padding_;
2036   TensorFormat data_format_;
2037   bool takes_shape_;
2038   bool cudnn_use_autotune_;
2039 };
2040 
2041 #define REGISTER_GPU_KERNEL(T)                                                \
2042   REGISTER_KERNEL_BUILDER(                                                    \
2043       Name("Conv3DBackpropInput").Device(DEVICE_GPU).TypeConstraint<T>("T"),  \
2044       Conv3DBackpropInputOp<GPUDevice, T>);                                   \
2045   REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2")                       \
2046                               .Device(DEVICE_GPU)                             \
2047                               .TypeConstraint<T>("T")                         \
2048                               .HostMemory("input_sizes"),                     \
2049                           Conv3DBackpropInputOp<GPUDevice, T>);               \
2050   REGISTER_KERNEL_BUILDER(                                                    \
2051       Name("Conv3DBackpropFilter").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
2052       Conv3DBackpropFilterOp<GPUDevice, T>);                                  \
2053   REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")                      \
2054                               .Device(DEVICE_GPU)                             \
2055                               .TypeConstraint<T>("T")                         \
2056                               .HostMemory("filter_sizes"),                    \
2057                           Conv3DBackpropFilterOp<GPUDevice, T>);
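// The *_sizes inputs of the V2 ops are small int32 shape vectors that are read
// on the host (via tensor::MakeShape) to determine the output shape, which is
// why they are pinned to host memory with HostMemory() in the registrations
// above.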
2058 TF_CALL_half(REGISTER_GPU_KERNEL);
2059 TF_CALL_float(REGISTER_GPU_KERNEL);
2060 TF_CALL_double(REGISTER_GPU_KERNEL);
2061 #undef REGISTER_GPU_KERNEL
2062 
2063 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
2064 
2065 }  // namespace tensorflow
2066