1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #define USE_EIGEN_TENSOR
17 #define EIGEN_USE_THREADS
18 
19 #include "tensorflow/core/framework/kernel_shape_util.h"
20 #include "tensorflow/core/framework/numeric_op.h"
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/framework/register_types.h"
23 #include "tensorflow/core/framework/tensor.h"
24 #include "tensorflow/core/framework/tensor_shape.h"
25 #include "tensorflow/core/framework/tensor_slice.h"
26 #include "tensorflow/core/framework/tensor_util.h"
27 #include "tensorflow/core/kernels/conv_2d.h"
28 #include "tensorflow/core/kernels/conv_3d.h"
29 #include "tensorflow/core/kernels/conv_grad_ops.h"
30 #include "tensorflow/core/kernels/conv_grad_shape_utils.h"
31 #include "tensorflow/core/kernels/conv_ops_gpu.h"
32 #include "tensorflow/core/lib/core/errors.h"
33 #include "tensorflow/core/lib/gtl/inlined_vector.h"
34 #include "tensorflow/core/profiler/lib/scoped_annotation.h"
35 #include "tensorflow/core/util/padding.h"
36 #include "tensorflow/core/util/tensor_format.h"
37 #include "tensorflow/core/util/use_cudnn.h"
38 #include "tensorflow/core/util/work_sharder.h"
39 
40 #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
41 #include "tensorflow/core/kernels/eigen_contraction_kernel.h"
42 #endif
43 
44 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
45 #include "tensorflow/core/platform/stream_executor.h"
46 using stream_executor::dnn::DimIndex;
47 #include "tensorflow/core/protobuf/autotuning.pb.h"
48 #include "tensorflow/core/util/autotune_maps/conv_parameters.h"
49 #include "tensorflow/core/util/proto/proto_utils.h"
50 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
51 #if GOOGLE_CUDA
52 #include "third_party/gpus/cudnn/cudnn.h"
53 #include "tensorflow/stream_executor/gpu/gpu_asm_opts.h"
54 #include "tensorflow/stream_executor/gpu/redzone_allocator.h"
55 #include "tensorflow/stream_executor/tf_allocator_adapter.h"
56 #endif  // GOOGLE_CUDA
57 
58 namespace {
59 
60 // TODO(ezhulenev): Split this file into conv_grad_filter_ops_3d.cc and
61 // conv_grad_input_ops_3d.cc.
62 
63 // TODO(ezhulenev): Generalize Col2im and Im2col for 2-d and 3-d kernels.
64 
65 // "Depth" is already used for the channel dimension, so for the third spatial
66 // dimension in this file we use "plane", although in NDHWC layout it's
67 // indicated with a "D".
68 
69 // Returns in 'im_data' (assumed to be zero-initialized) image patch in storage
70 // order (planes, height, width, depth), constructed from patches in 'col_data',
71 // which is required to be in storage order (out_planes * out_height *
72 // out_width, filter_planes, filter_height, filter_width, in_depth).
73 //
74 // Based on 2-dimensional implementation written by Yangqing Jia (jiayq).
75 template <typename T>
76 void Col2im(const T* col_data, const int depth, const int planes,
77             const int height, const int width, const int filter_p,
78             const int filter_h, const int filter_w, const int pad_pt,
79             const int pad_t, const int pad_l, const int pad_pb, const int pad_b,
80             const int pad_r, const int stride_p, const int stride_h,
81             const int stride_w, T* im_data) {
82   const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1;
83   const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
84   const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
85   int p_pad = -pad_pt;
86   for (int p = 0; p < planes_col; ++p) {
87     int h_pad = -pad_t;
88     for (int h = 0; h < height_col; ++h) {
89       int w_pad = -pad_l;
90       for (int w = 0; w < width_col; ++w) {
91         T* im_patch_data =
92             im_data + (p_pad * height * width + h_pad * width + w_pad) * depth;
93         for (int ip = p_pad; ip < p_pad + filter_p; ++ip) {
94           for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
95             for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
96               if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 &&
97                   iw < width) {
98                 for (int i = 0; i < depth; ++i) {
99                   im_patch_data[i] += col_data[i];
100                 }
101               }
102               im_patch_data += depth;
103               col_data += depth;
104             }
105             // Skip the rest of this row: depth * (width - filter_w) elements.
106             im_patch_data += depth * (width - filter_w);
107           }
109           // Skip the rest of this plane: depth * width * (height - filter_h) elements.
109           im_patch_data += (depth * width) * (height - filter_h);
110         }
111         w_pad += stride_w;
112       }
113       h_pad += stride_h;
114     }
115     p_pad += stride_p;
116   }
117 }
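// A minimal worked example of the patch-count arithmetic used above, with
// assumed values that are not taken from any caller in this file: for
// planes = 5, pad_pt = pad_pb = 1, filter_p = 3 and stride_p = 2,
//   planes_col = (5 + 1 + 1 - 3) / 2 + 1 = 3,
// i.e. three patch positions along the plane dimension; the height and width
// counts follow the same formula.
static_assert((5 + 1 + 1 - 3) / 2 + 1 == 3,
              "illustrative only: expected three plane positions");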
118 
119 // Returns in 'col_data', image patches in storage order (planes, height, width,
120 // depth) extracted from image at 'input_data', which is required to be in
121 // storage order (batch, planes, height, width, depth).
122 //
123 // Based on 2-dimensional implementation written by Yangqing Jia (jiayq).
124 template <typename T>
125 void Im2col(const T* input_data, const int depth, const int planes,
126             const int height, const int width, const int filter_p,
127             const int filter_h, const int filter_w, const int pad_pt,
128             const int pad_t, const int pad_l, const int pad_pb, const int pad_b,
129             const int pad_r, const int stride_p, const int stride_h,
130             const int stride_w, T* col_data) {
131   const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1;
132   const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
133   const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
134 
135   int p_pad = -pad_pt;
136   for (int p = 0; p < planes_col; ++p) {
137     int h_pad = -pad_t;
138     for (int h = 0; h < height_col; ++h) {
139       int w_pad = -pad_l;
140       for (int w = 0; w < width_col; ++w) {
141         for (int ip = p_pad; ip < p_pad + filter_p; ++ip) {
142           for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
143             for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
144               if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 &&
145                   iw < width) {
146                 memcpy(col_data,
147                        input_data +
148                            (ip * height * width + ih * width + iw) * depth,
149                        sizeof(T) * depth);
150               } else {
151                 // Out-of-bounds positions are simply padded with zero.
152                 memset(col_data, 0, sizeof(T) * depth);
153               }
154               col_data += depth;
155             }
156           }
157         }
158         w_pad += stride_w;
159       }
160       h_pad += stride_h;
161     }
162     p_pad += stride_p;
163   }
164 }
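// Sketch of the buffer geometry these helpers assume (illustrative numbers,
// not from this file): each output position owns one contiguous run of
// filter_p * filter_h * filter_w * depth values in 'col_data', so a 2x2x2
// filter over depth 3 yields runs of 24 elements. Col2im scatters such runs
// back into the (planes, height, width, depth) image with overlap-add, while
// Im2col gathers them, zero-filling wherever the filter window hangs over the
// padded border.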
165 
166 }  // namespace
167 
168 namespace tensorflow {
169 
170 typedef Eigen::ThreadPoolDevice CPUDevice;
171 typedef Eigen::GpuDevice GPUDevice;
172 
173 // Backprop for input that offloads computation to
174 // Eigen::CuboidConvolutionBackwardInput.
175 template <typename Device, class T>
176 class Conv3DBackpropInputOp : public OpKernel {
177  public:
178   explicit Conv3DBackpropInputOp(OpKernelConstruction* context)
179       : OpKernel(context),
180         data_format_(FORMAT_NHWC),
181         takes_shape_(type_string().find("V2") != std::string::npos) {
182     // data_format is only available in V2.
183     if (takes_shape_) {
184       string data_format;
185       OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
186       OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
187                   errors::InvalidArgument("Invalid data format"));
188       OP_REQUIRES(
189           context, data_format_ == FORMAT_NHWC,
190           errors::InvalidArgument(
191               "Conv3DBackpropInputOpV2 only supports NDHWC on the CPU."));
192     }
193 
194     OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
195     OP_REQUIRES(context, dilation_.size() == 5,
196                 errors::InvalidArgument("Dilation rates field must "
197                                         "specify 5 dimensions"));
198     OP_REQUIRES(context,
199                 (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
200                  GetTensorDim(dilation_, data_format_, 'N') == 1),
201                 errors::InvalidArgument(
202                     "Current implementation does not yet support "
203                     "dilation rates in the batch and depth dimensions."));
204 
205     // TODO(yangzihao): Add CPU version of dilated conv 3D.
206     OP_REQUIRES(context,
207                 (GetTensorDim(dilation_, data_format_, '0') == 1 &&
208                  GetTensorDim(dilation_, data_format_, '1') == 1 &&
209                  GetTensorDim(dilation_, data_format_, '2') == 1),
210                 errors::InvalidArgument(
211                     "Current CPU implementation does not yet support "
212                     "dilation rates larger than 1."));
213 
214     OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
215     OP_REQUIRES(context, stride_.size() == 5,
216                 errors::InvalidArgument("Sliding window strides field must "
217                                         "specify 5 dimensions"));
218     OP_REQUIRES(
219         context,
220         (GetTensorDim(stride_, data_format_, 'C') == 1 &&
221          GetTensorDim(stride_, data_format_, 'N') == 1),
222         errors::InvalidArgument("Current implementation does not yet support "
223                                 "strides in the batch and depth dimensions."));
224     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
225   }
226 
227   void Compute(OpKernelContext* context) override {
228     const Tensor& filter = context->input(1);
229     const TensorShape& filter_shape = filter.shape();
230 
231     const Tensor& out_backprop = context->input(2);
232     const TensorShape& out_backprop_shape = out_backprop.shape();
233 
234     TensorShape input_shape;
235     if (takes_shape_) {
236       const Tensor& input_sizes = context->input(0);
237       // tensor::MakeShape is able to handle both DT_INT32 and DT_INT64 for
238       // input_sizes.
239       OP_REQUIRES_OK(context, tensor::MakeShape(input_sizes, &input_shape));
240     } else {
241       input_shape = context->input(0).shape();
242     }
243 
244     OP_REQUIRES(context, input_shape.dims() == 5,
245                 errors::InvalidArgument("input tensor must have 5 dimensions"));
246     OP_REQUIRES(
247         context, filter_shape.dims() == 5,
248         errors::InvalidArgument("filter_sizes tensor must have 5 dimensions"));
249     OP_REQUIRES(
250         context, out_backprop_shape.dims() == 5,
251         errors::InvalidArgument("out_backprop tensor must have 5 dimensions"));
252     OP_REQUIRES(
253         context, input_shape.dim_size(4) == filter_shape.dim_size(3),
254         errors::InvalidArgument("input and filter_sizes must have the same "
255                                 "number of channels. Got ",
256                                 input_shape.dim_size(4), " for input and ",
257                                 filter_shape.dim_size(3), " for filter_sizes"));
258     OP_REQUIRES(
259         context, out_backprop_shape.dim_size(4) == filter_shape.dim_size(4),
260         errors::InvalidArgument("out_backprop and filter_sizes must have the "
261                                 "same number of channels. Got ",
262                                 out_backprop_shape.dim_size(4),
263                                 " for out_backprop and ",
264                                 filter_shape.dim_size(4), " for filter_sizes"));
265 
266     ConvBackpropDimensions dims;
267     OP_REQUIRES_OK(context, ConvBackpropComputeDimensions(
268                                 "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
269                                 input_shape, filter_shape, out_backprop_shape,
270                                 stride_, padding_, data_format_, &dims));
271 
272     Tensor* in_backprop;
273     OP_REQUIRES_OK(context,
274                    context->allocate_output(0, input_shape, &in_backprop));
275 
276     functor::CuboidConvolutionBackwardInput<Device, T>()(
277         context->eigen_device<Device>(),
278         in_backprop->tensor<T, 5>(),                     // input_backward
279         filter.tensor<T, 5>(),                           // filter
280         out_backprop.tensor<T, 5>(),                     // output_backward
281         static_cast<int>(dims.spatial_dims[0].stride),   // stride_planes
282         static_cast<int>(dims.spatial_dims[1].stride),   // stride_rows
283         static_cast<int>(dims.spatial_dims[2].stride));  // stride_cols
284   }
285 
286  private:
287   std::vector<int32> dilation_;
288   std::vector<int32> stride_;
289   Padding padding_;
290   TensorFormat data_format_;
291   bool takes_shape_;
292 
293   TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropInputOp);
294 };
295 
296 // Custom backprop for input that explicitly does the work sharding and calls
297 // Eigen only to multiply matrices.
298 template <typename Device, class T>
299 class Conv3DCustomBackpropInputOp : public OpKernel {
300   // Limit the maximum size of allocated temporary buffer to
301   // kMaxTempAllocationOverhead times the size of the input tensors (input,
302   // filter, out_backprop). If the size of the temporary buffer exceeds this
303   // limit, fallback on Eigen implementation.
304   static constexpr int kMaxTempAllocationOverhead = 25;
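  // Illustrative only (numbers assumed, not measured): if input, filter and
  // out_backprop together hold 4M elements, any col_buffer larger than
  // 25 * 4M = 100M elements triggers the Eigen fallback below.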
305 
306  public:
307   explicit Conv3DCustomBackpropInputOp(OpKernelConstruction* context)
308       : OpKernel(context),
309         data_format_(FORMAT_NHWC),
310         takes_shape_(type_string().find("V2") != std::string::npos) {
311     // data_format is only available in V2.
312     if (takes_shape_) {
313       string data_format;
314       OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
315       OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
316                   errors::InvalidArgument("Invalid data format"));
317       OP_REQUIRES(
318           context, data_format_ == FORMAT_NHWC,
319           errors::InvalidArgument(
320               "Conv3DBackpropInputOpV2 only supports NDHWC on the CPU."));
321     }
322 
323     OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
324     OP_REQUIRES(context, dilation_.size() == 5,
325                 errors::InvalidArgument("Dilation rates field must "
326                                         "specify 5 dimensions"));
327     OP_REQUIRES(context,
328                 (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
329                  GetTensorDim(dilation_, data_format_, 'N') == 1),
330                 errors::InvalidArgument(
331                     "Current implementation does not yet support "
332                     "dilation rates in the batch and depth dimensions."));
333 
334     // TODO(yangzihao): Add CPU version of dilated conv 3D.
335     OP_REQUIRES(context,
336                 (GetTensorDim(dilation_, data_format_, '0') == 1 &&
337                  GetTensorDim(dilation_, data_format_, '1') == 1 &&
338                  GetTensorDim(dilation_, data_format_, '2') == 1),
339                 errors::InvalidArgument(
340                     "Current CPU implementation does not yet support "
341                     "dilation rates larger than 1."));
342 
343     OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
344     OP_REQUIRES(context, stride_.size() == 5,
345                 errors::InvalidArgument("Sliding window strides field must "
346                                         "specify 5 dimensions"));
347     OP_REQUIRES(
348         context,
349         (GetTensorDim(stride_, data_format_, 'C') == 1 &&
350          GetTensorDim(stride_, data_format_, 'N') == 1),
351         errors::InvalidArgument("Current implementation does not yet support "
352                                 "strides in the batch and depth dimensions."));
353     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
354   }
355 
356   void Compute(OpKernelContext* context) override {
357     const Tensor& filter = context->input(1);
358     const TensorShape& filter_shape = filter.shape();
359 
360     const Tensor& out_backprop = context->input(2);
361     const TensorShape& out_backprop_shape = out_backprop.shape();
362 
363     TensorShape input_shape;
364     if (takes_shape_) {
365       const Tensor& input_sizes = context->input(0);
366       // tensor::MakeShape is able to handle both DT_INT32 and DT_INT64 for
367       // input_sizes.
368       OP_REQUIRES_OK(context, tensor::MakeShape(input_sizes, &input_shape));
369     } else {
370       input_shape = context->input(0).shape();
371     }
372 
373     OP_REQUIRES(context, input_shape.dims() == 5,
374                 errors::InvalidArgument("input tensor must have 5 dimensions"));
375     OP_REQUIRES(
376         context, filter_shape.dims() == 5,
377         errors::InvalidArgument("filter_sizes tensor must have 5 dimensions"));
378     OP_REQUIRES(
379         context, out_backprop_shape.dims() == 5,
380         errors::InvalidArgument("out_backprop tensor must have 5 dimensions"));
381     OP_REQUIRES(
382         context, input_shape.dim_size(4) == filter_shape.dim_size(3),
383         errors::InvalidArgument("input and filter_sizes must have the same "
384                                 "number of channels. Got ",
385                                 input_shape.dim_size(4), " for input and ",
386                                 filter_shape.dim_size(3), " for filter_sizes"));
387     OP_REQUIRES(
388         context, out_backprop_shape.dim_size(4) == filter_shape.dim_size(4),
389         errors::InvalidArgument("out_backprop and filter_sizes must have the "
390                                 "same number of channels. Got ",
391                                 out_backprop_shape.dim_size(4),
392                                 " for out_backprop and ",
393                                 filter_shape.dim_size(4), " for filter_sizes"));
394 
395     ConvBackpropDimensions dims;
396     OP_REQUIRES_OK(context, ConvBackpropComputeDimensions(
397                                 "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
398                                 input_shape, filter_shape, out_backprop_shape,
399                                 stride_, padding_, data_format_, &dims));
400 
401     Tensor* in_backprop;
402     OP_REQUIRES_OK(context,
403                    context->allocate_output(0, input_shape, &in_backprop));
404 
405     int64_t top_pad_planes, bottom_pad_planes;
406     int64_t top_pad_rows, bottom_pad_rows;
407     int64_t left_pad_cols, right_pad_cols;
408 
409     OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
410                                 dims.spatial_dims[0].input_size,
411                                 dims.spatial_dims[0].filter_size,
412                                 dims.spatial_dims[0].stride, padding_,
413                                 &dims.spatial_dims[0].output_size,
414                                 &top_pad_planes, &bottom_pad_planes));
415     OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
416                                 dims.spatial_dims[1].input_size,
417                                 dims.spatial_dims[1].filter_size,
418                                 dims.spatial_dims[1].stride, padding_,
419                                 &dims.spatial_dims[1].output_size,
420                                 &top_pad_rows, &bottom_pad_rows));
421     OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
422                                 dims.spatial_dims[2].input_size,
423                                 dims.spatial_dims[2].filter_size,
424                                 dims.spatial_dims[2].stride, padding_,
425                                 &dims.spatial_dims[2].output_size,
426                                 &left_pad_cols, &right_pad_cols));
427 
428     // TODO(ezhulenev): Extract work size and shard estimation to shared
429     // functions in conv_grad_ops, and update 2d convolution backprop.
430 
431     // The total dimension size of each kernel.
432     const int64_t filter_total_size =
433         dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size *
434         dims.spatial_dims[2].filter_size * dims.in_depth;
435 
436     // The output image size is the spatial size of the output.
437     const int64_t output_image_size = dims.spatial_dims[0].output_size *
438                                       dims.spatial_dims[1].output_size *
439                                       dims.spatial_dims[2].output_size;
440 
441     const auto cache_sizes = Eigen::internal::CacheSizes();
442     const ptrdiff_t l3_cache_size = cache_sizes.m_l3;
443 
444     // Use L3 cache size as target working set size.
445     const size_t target_working_set_size = l3_cache_size / sizeof(T);
446 
447     // Calculate size of matrices involved in MatMul: C = A x B.
448     const int64_t size_A = output_image_size * dims.out_depth;
449 
450     const int64_t size_B = filter_total_size * dims.out_depth;
451 
452     const int64_t size_C = output_image_size * filter_total_size;
453 
454     const int64_t work_unit_size = size_A + size_B + size_C;
455 
456     auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
457 
458     // Use parallel tensor contractions if there is no batching.
459     //
460     // Compared to Conv2D code, this version is missing work size estimation. In
461     // benchmarks I didn't find a case when it's beneficial to run parallel
462     // contraction compared to sharding and matmuls.
463     const bool use_parallel_contraction = dims.batch_size == 1;
464 
465     OP_REQUIRES(
466         context, work_unit_size > 0,
467         errors::InvalidArgument("input, filter_sizes and out_backprop tensors "
468                                 "must all have at least 1 element"));
469 
470     const size_t shard_size =
471         use_parallel_contraction
472             ? 1
473             : (target_working_set_size + work_unit_size - 1) / work_unit_size;
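    // Worked example for the sharded (batched) branch, with assumed numbers
    // rather than measurements: an 8 MB L3 and T = float give a target working
    // set of 2,000,000 elements; with work_unit_size = 500,000 elements,
    // shard_size = (2,000,000 + 500,000 - 1) / 500,000 = 4, so four images
    // share each pass over the temporary col_buffer.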
474 
475     // Total number of elements in all the tensors used by this kernel.
476     int64_t total_tensor_elements = input_shape.num_elements() +
477                                     filter_shape.num_elements() +
478                                     out_backprop_shape.num_elements();
479 
480     // Shape of the temporary workspace buffer.
481     TensorShape col_buffer_shape = {static_cast<int64>(shard_size),
482                                     static_cast<int64>(output_image_size),
483                                     static_cast<int64>(filter_total_size)};
484     int64_t col_buffer_elements = col_buffer_shape.num_elements();
485 
486     // If the temporary allocation overhead is too large, fallback on Eigen
487     // implementation which requires much less memory.
488     int64_t col_buffer_overhead = col_buffer_elements / total_tensor_elements;
489     if (col_buffer_overhead > kMaxTempAllocationOverhead) {
490       VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropInputOp: "
491                  "col_buffer_overhead="
492               << col_buffer_overhead;
493 
494       functor::CuboidConvolutionBackwardInput<Device, T>()(
495           context->eigen_device<Device>(),
496           in_backprop->tensor<T, 5>(),                     // input_backward
497           filter.tensor<T, 5>(),                           // filter
498           out_backprop.tensor<T, 5>(),                     // output_backward
499           static_cast<int>(dims.spatial_dims[0].stride),   // stride_planes
500           static_cast<int>(dims.spatial_dims[1].stride),   // stride_rows
501           static_cast<int>(dims.spatial_dims[2].stride));  // stride_cols
502 
503       return;
504     }
505 
506     Tensor col_buffer;
507     OP_REQUIRES_OK(context,
508                    context->allocate_temp(DataTypeToEnum<T>::value,
509                                           col_buffer_shape, &col_buffer));
510 
511     // The input offset corresponding to a single input image.
512     const int64_t input_offset =
513         dims.spatial_dims[0].input_size * dims.spatial_dims[1].input_size *
514         dims.spatial_dims[2].input_size * dims.in_depth;
515 
516     // The output offset corresponding to a single output image.
517     const int64_t output_offset =
518         dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size *
519         dims.spatial_dims[2].output_size * dims.out_depth;
520 
521     const T* filter_data = filter.template flat<T>().data();
522     T* col_buffer_data = col_buffer.template flat<T>().data();
523     const T* out_backprop_data = out_backprop.template flat<T>().data();
524 
525     auto in_backprop_flat = in_backprop->template flat<T>();
526     T* input_backprop_data = in_backprop_flat.data();
527     in_backprop_flat.device(context->eigen_device<Device>()) =
528         in_backprop_flat.constant(T(0));
529 
530     if (use_parallel_contraction) {
531       typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
532                                Eigen::Unaligned>
533           TensorMap;
534       typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
535                                Eigen::Unaligned>
536           ConstTensorMap;
537 
538       // Initialize contraction dims (we need to transpose 'B' below).
539       Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
540       contract_dims[0].first = 1;
541       contract_dims[0].second = 1;
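      // Contracting index 1 of A (out_depth) with index 1 of B (out_depth)
      // evaluates C = A * B^T, matching the explicit B.transpose() in the
      // non-parallel branch below.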
542 
543       for (int image_id = 0; image_id < dims.batch_size; ++image_id) {
544         // Compute gradient into col_buffer.
545         TensorMap C(col_buffer_data, output_image_size, filter_total_size);
546 
547         ConstTensorMap A(out_backprop_data + output_offset * image_id,
548                          output_image_size, dims.out_depth);
549         ConstTensorMap B(filter_data, filter_total_size, dims.out_depth);
550 
551         C.device(context->eigen_cpu_device()) = A.contract(B, contract_dims);
552 
553         Col2im<T>(col_buffer_data, dims.in_depth,
554                   // Input spatial dimensions.
555                   dims.spatial_dims[0].input_size,  // input planes
556                   dims.spatial_dims[1].input_size,  // input rows
557                   dims.spatial_dims[2].input_size,  // input cols
558                   // Filter spatial dimensions.
559                   dims.spatial_dims[0].filter_size,  // filter planes
560                   dims.spatial_dims[1].filter_size,  // filter rows
561                   dims.spatial_dims[2].filter_size,  // filter cols
562                   // Spatial padding.
563                   top_pad_planes, top_pad_rows, left_pad_cols,
564                   bottom_pad_planes, bottom_pad_rows, right_pad_cols,
565                   // Spatial striding.
566                   dims.spatial_dims[0].stride,  // stride planes
567                   dims.spatial_dims[1].stride,  // stride rows
568                   dims.spatial_dims[2].stride,  // stride cols
569                   input_backprop_data);
570 
571         input_backprop_data += input_offset;
572       }
573     } else {
574       typedef Eigen::Map<
575           Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
576           MatrixMap;
577       typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic,
578                                              Eigen::RowMajor>>
579           ConstMatrixMap;
580 
581       for (int image_id = 0; image_id < dims.batch_size;
582            image_id += shard_size) {
583         const int shard_limit =
584             std::min(static_cast<int>(shard_size),
585                      static_cast<int>(dims.batch_size) - image_id);
586 
587         auto shard = [&dims, &top_pad_planes, &top_pad_rows, &left_pad_cols,
588                       &bottom_pad_planes, &bottom_pad_rows, &right_pad_cols,
589                       &output_image_size, &filter_total_size,
590                       &input_backprop_data, &col_buffer_data,
591                       &out_backprop_data, &filter_data, &input_offset,
592                       &output_offset, &size_C](int64_t start, int64_t limit) {
593           for (int shard_id = start; shard_id < limit; ++shard_id) {
594             T* im2col_buf = col_buffer_data + shard_id * size_C;
595             T* input_data = input_backprop_data + shard_id * input_offset;
596             const T* out_data = out_backprop_data + shard_id * output_offset;
597 
598             // Compute gradient into 'im2col_buf'.
599             MatrixMap C(im2col_buf, output_image_size, filter_total_size);
600 
601             ConstMatrixMap A(out_data, output_image_size, dims.out_depth);
602             ConstMatrixMap B(filter_data, filter_total_size, dims.out_depth);
603 
604             C.noalias() = A * B.transpose();
605 
606             Col2im<T>(im2col_buf, dims.in_depth,
607                       // Input spatial dimensions.
608                       dims.spatial_dims[0].input_size,  // input planes
609                       dims.spatial_dims[1].input_size,  // input rows
610                       dims.spatial_dims[2].input_size,  // input cols
611                       // Filter spatial dimensions.
612                       dims.spatial_dims[0].filter_size,  // filter planes
613                       dims.spatial_dims[1].filter_size,  // filter rows
614                       dims.spatial_dims[2].filter_size,  // filter cols
615                       // Spatial padding.
616                       top_pad_planes, top_pad_rows, left_pad_cols,
617                       bottom_pad_planes, bottom_pad_rows, right_pad_cols,
618                       // Spatial striding.
619                       dims.spatial_dims[0].stride,  // stride planes
620                       dims.spatial_dims[1].stride,  // stride rows
621                       dims.spatial_dims[2].stride,  // stride cols
622                       input_data);
623           }
624         };
625         Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
626               work_unit_size, shard);
627 
628         input_backprop_data += input_offset * shard_limit;
629         out_backprop_data += output_offset * shard_limit;
630       }
631     }
632   }
633 
634  private:
635   std::vector<int32> dilation_;
636   std::vector<int32> stride_;
637   Padding padding_;
638   TensorFormat data_format_;
639   bool takes_shape_;
640 
641   TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropInputOp);
642 };
643 
644 // Custom backprop input kernel is 30%-4x faster when compiled with AVX2 than
645 // default Eigen implementation (at the cost of ~2x-8x peak memory usage).
646 
647 #define REGISTER_CPU_KERNEL(T)                                                 \
648   REGISTER_KERNEL_BUILDER(                                                     \
649       Name("Conv3DBackpropInput").Device(DEVICE_CPU).TypeConstraint<T>("T"),   \
650       Conv3DCustomBackpropInputOp<CPUDevice, T>);                              \
651   REGISTER_KERNEL_BUILDER(                                                     \
652       Name("Conv3DBackpropInputV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
653       Conv3DCustomBackpropInputOp<CPUDevice, T>);                              \
654   REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput")                          \
655                               .Device(DEVICE_CPU)                              \
656                               .Label("custom")                                 \
657                               .TypeConstraint<T>("T"),                         \
658                           Conv3DCustomBackpropInputOp<CPUDevice, T>);          \
659   REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2")                        \
660                               .Device(DEVICE_CPU)                              \
661                               .Label("custom")                                 \
662                               .TypeConstraint<T>("T"),                         \
663                           Conv3DCustomBackpropInputOp<CPUDevice, T>);          \
664   REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput")                          \
665                               .Device(DEVICE_CPU)                              \
666                               .Label("eigen_tensor")                           \
667                               .TypeConstraint<T>("T"),                         \
668                           Conv3DBackpropInputOp<CPUDevice, T>);                \
669   REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2")                        \
670                               .Device(DEVICE_CPU)                              \
671                               .Label("eigen_tensor")                           \
672                               .TypeConstraint<T>("T"),                         \
673                           Conv3DBackpropInputOp<CPUDevice, T>);
674 
675 TF_CALL_half(REGISTER_CPU_KERNEL);
676 TF_CALL_float(REGISTER_CPU_KERNEL);
677 TF_CALL_double(REGISTER_CPU_KERNEL);
678 #undef REGISTER_CPU_KERNEL
679 
680 // Backprop for filter that offloads computation to
681 // Eigen::CuboidConvolutionBackwardFilter.
682 template <typename Device, class T>
683 class Conv3DBackpropFilterOp : public OpKernel {
684  public:
685   explicit Conv3DBackpropFilterOp(OpKernelConstruction* context)
686       : OpKernel(context),
687         data_format_(FORMAT_NHWC),
688         takes_shape_(type_string().find("V2") != std::string::npos) {
689     // data_format is only available in V2.
690     if (takes_shape_) {
691       string data_format;
692       OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
693       OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
694                   errors::InvalidArgument("Invalid data format"));
695       OP_REQUIRES(
696           context, data_format_ == FORMAT_NHWC,
697           errors::InvalidArgument(
698               "Conv3DBackpropFilterOpV2 only supports NDHWC on the CPU."));
699     }
700 
701     OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
702     OP_REQUIRES(context, dilation_.size() == 5,
703                 errors::InvalidArgument("Dilation rates field must "
704                                         "specify 5 dimensions"));
705     OP_REQUIRES(context,
706                 (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
707                  GetTensorDim(dilation_, data_format_, 'N') == 1),
708                 errors::InvalidArgument(
709                     "Current implementation does not yet support "
710                     "dilation rates in the batch and depth dimensions."));
711 
712     // TODO(yangzihao): Add CPU version of dilated conv 3D.
713     OP_REQUIRES(context,
714                 (GetTensorDim(dilation_, data_format_, '0') == 1 &&
715                  GetTensorDim(dilation_, data_format_, '1') == 1 &&
716                  GetTensorDim(dilation_, data_format_, '2') == 1),
717                 errors::InvalidArgument(
718                     "Current CPU implementation does not yet support "
719                     "dilation rates larger than 1."));
720 
721     OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
722     OP_REQUIRES(context, stride_.size() == 5,
723                 errors::InvalidArgument("Sliding window strides field must "
724                                         "specify 5 dimensions"));
725     OP_REQUIRES(
726         context,
727         (GetTensorDim(stride_, data_format_, 'C') == 1 &&
728          GetTensorDim(stride_, data_format_, 'N') == 1),
729         errors::InvalidArgument("Current implementation does not yet support "
730                                 "strides in the batch and depth dimensions."));
731     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
732   }
733 
734   void Compute(OpKernelContext* context) override {
735     const Tensor& input = context->input(0);
736     const TensorShape& input_shape = input.shape();
737 
738     const Tensor& out_backprop = context->input(2);
739     const TensorShape& out_backprop_shape = out_backprop.shape();
740 
741     TensorShape filter_shape;
742     if (takes_shape_) {
743       const Tensor& filter_sizes = context->input(1);
744       OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
745                                   filter_sizes.vec<int32>(), &filter_shape));
746     } else {
747       filter_shape = context->input(1).shape();
748     }
749 
750     OP_REQUIRES(context, input_shape.dims() == 5,
751                 errors::InvalidArgument("input tensor must have 5 dimensions"));
752     OP_REQUIRES(
753         context, filter_shape.dims() == 5,
754         errors::InvalidArgument("filter_sizes tensor must have 5 dimensions"));
755     OP_REQUIRES(
756         context, out_backprop_shape.dims() == 5,
757         errors::InvalidArgument("out_backprop tensor must have 5 dimensions"));
758     OP_REQUIRES(
759         context, input_shape.dim_size(4) == filter_shape.dim_size(3),
760         errors::InvalidArgument("input and filter_sizes must have the same "
761                                 "number of channels. Got ",
762                                 input_shape.dim_size(4), " for input and ",
763                                 filter_shape.dim_size(3), " for filter_sizes"));
764     OP_REQUIRES(
765         context, out_backprop_shape.dim_size(4) == filter_shape.dim_size(4),
766         errors::InvalidArgument("out_backprop and filter_sizes must have the "
767                                 "same number of channels. Got ",
768                                 out_backprop_shape.dim_size(4),
769                                 " for out_backprop and ",
770                                 filter_shape.dim_size(4), " for filter_sizes"));
771 
772     ConvBackpropDimensions dims;
773     OP_REQUIRES_OK(context,
774                    ConvBackpropComputeDimensions(
775                        "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
776                        input_shape, filter_shape, out_backprop_shape, stride_,
777                        padding_, data_format_, &dims));
778 
779     Tensor* filter_backprop;
780     OP_REQUIRES_OK(context,
781                    context->allocate_output(0, filter_shape, &filter_backprop));
782 
783     if (input_shape.num_elements() == 0) {
784       filter_backprop->template flat<T>().setZero();
785       return;
786     }
787 
788     functor::CuboidConvolutionBackwardFilter<Device, T>()(
789         context->eigen_device<Device>(),
790         filter_backprop->tensor<T, 5>(),                 // filter_backward
791         input.tensor<T, 5>(),                            // input
792         out_backprop.tensor<T, 5>(),                     // output_backward
793         static_cast<int>(dims.spatial_dims[0].stride),   // stride_planes
794         static_cast<int>(dims.spatial_dims[1].stride),   // stride_rows
795         static_cast<int>(dims.spatial_dims[2].stride));  // stride_cols
796   }
797 
798  private:
799   std::vector<int32> dilation_;
800   std::vector<int32> stride_;
801   Padding padding_;
802   TensorFormat data_format_;
803   bool takes_shape_;
804 
805   TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropFilterOp);
806 };
807 
808 // Custom backprop for filter that explicitly does the work sharding and calls
809 // Eigen only to multiply matrices.
810 template <typename Device, class T>
811 class Conv3DCustomBackpropFilterOp : public OpKernel {
812   // Limit the maximum size of allocated temporary buffer to
813   // kMaxTempAllocationOverhead times the size of the input tensors (input,
814   // filter, out_backprop). If the size of the temporary buffer exceeds this
815   // limit, fallback on Eigen implementation.
816   static constexpr int kMaxTempAllocationOverhead = 25;
817 
818  public:
819   explicit Conv3DCustomBackpropFilterOp(OpKernelConstruction* context)
820       : OpKernel(context),
821         data_format_(FORMAT_NHWC),
822         takes_shape_(type_string().find("V2") != std::string::npos) {
823     // data_format is only available in V2.
824     if (takes_shape_) {
825       string data_format;
826       OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
827       OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
828                   errors::InvalidArgument("Invalid data format"));
829       OP_REQUIRES(
830           context, data_format_ == FORMAT_NHWC,
831           errors::InvalidArgument(
832               "Conv3DBackpropFilterOpV2 only supports NDHWC on the CPU."));
833     }
834 
835     OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
836     OP_REQUIRES(context, dilation_.size() == 5,
837                 errors::InvalidArgument("Dilation rates field must "
838                                         "specify 5 dimensions"));
839     OP_REQUIRES(context,
840                 (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
841                  GetTensorDim(dilation_, data_format_, 'N') == 1),
842                 errors::InvalidArgument(
843                     "Current implementation does not yet support "
844                     "dilation rates in the batch and depth dimensions."));
845 
846     // TODO(yangzihao): Add CPU version of dilated conv 3D.
847     OP_REQUIRES(context,
848                 (GetTensorDim(dilation_, data_format_, '0') == 1 &&
849                  GetTensorDim(dilation_, data_format_, '1') == 1 &&
850                  GetTensorDim(dilation_, data_format_, '2') == 1),
851                 errors::InvalidArgument(
852                     "Current CPU implementation does not yet support "
853                     "dilation rates larger than 1."));
854 
855     OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
856     OP_REQUIRES(context, stride_.size() == 5,
857                 errors::InvalidArgument("Sliding window strides field must "
858                                         "specify 5 dimensions"));
859     OP_REQUIRES(
860         context,
861         (GetTensorDim(stride_, data_format_, 'C') == 1 &&
862          GetTensorDim(stride_, data_format_, 'N') == 1),
863         errors::InvalidArgument("Current implementation does not yet support "
864                                 "strides in the batch and depth dimensions."));
865     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
866   }
867 
868   void Compute(OpKernelContext* context) override {
869     const Tensor& input = context->input(0);
870     const TensorShape& input_shape = input.shape();
871 
872     const Tensor& out_backprop = context->input(2);
873     const TensorShape& out_backprop_shape = out_backprop.shape();
874 
875     TensorShape filter_shape;
876     if (takes_shape_) {
877       const Tensor& filter_sizes = context->input(1);
878       OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
879                                   filter_sizes.vec<int32>(), &filter_shape));
880     } else {
881       filter_shape = context->input(1).shape();
882     }
883 
884     OP_REQUIRES(context, input_shape.dims() == 5,
885                 errors::InvalidArgument("input tensor must have 5 dimensions"));
886     OP_REQUIRES(
887         context, filter_shape.dims() == 5,
888         errors::InvalidArgument("filter_sizes tensor must have 5 dimensions"));
889     OP_REQUIRES(
890         context, out_backprop_shape.dims() == 5,
891         errors::InvalidArgument("out_backprop tensor must have 5 dimensions"));
892     OP_REQUIRES(
893         context, input_shape.dim_size(4) == filter_shape.dim_size(3),
894         errors::InvalidArgument("input and filter_sizes must have the same "
895                                 "number of channels. Got ",
896                                 input_shape.dim_size(4), " for input and ",
897                                 filter_shape.dim_size(3), " for filter_sizes"));
898     OP_REQUIRES(
899         context, out_backprop_shape.dim_size(4) == filter_shape.dim_size(4),
900         errors::InvalidArgument("out_backprop and filter_sizes must have the "
901                                 "same number of channels. Got ",
902                                 out_backprop_shape.dim_size(4),
903                                 " for out_backprop and ",
904                                 filter_shape.dim_size(4), " for filter_sizes"));
905 
906     ConvBackpropDimensions dims;
907     OP_REQUIRES_OK(context,
908                    ConvBackpropComputeDimensions(
909                        "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
910                        input_shape, filter_shape, out_backprop_shape, stride_,
911                        padding_, data_format_, &dims));
912 
913     Tensor* filter_backprop;
914     OP_REQUIRES_OK(context,
915                    context->allocate_output(0, filter_shape, &filter_backprop));
916 
917     if (input_shape.num_elements() == 0) {
918       filter_backprop->template flat<T>().setZero();
919       return;
920     }
921 
922     int64_t top_pad_planes, bottom_pad_planes;
923     int64_t top_pad_rows, bottom_pad_rows;
924     int64_t left_pad_cols, right_pad_cols;
925 
926     OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
927                                 dims.spatial_dims[0].input_size,
928                                 dims.spatial_dims[0].filter_size,
929                                 dims.spatial_dims[0].stride, padding_,
930                                 &dims.spatial_dims[0].output_size,
931                                 &top_pad_planes, &bottom_pad_planes));
932     OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
933                                 dims.spatial_dims[1].input_size,
934                                 dims.spatial_dims[1].filter_size,
935                                 dims.spatial_dims[1].stride, padding_,
936                                 &dims.spatial_dims[1].output_size,
937                                 &top_pad_rows, &bottom_pad_rows));
938     OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
939                                 dims.spatial_dims[2].input_size,
940                                 dims.spatial_dims[2].filter_size,
941                                 dims.spatial_dims[2].stride, padding_,
942                                 &dims.spatial_dims[2].output_size,
943                                 &left_pad_cols, &right_pad_cols));
944 
945     // TODO(ezhulenev): Extract work size and shard estimation to shared
946     // functions in conv_grad_ops, and update 2d convolution backprop.
947 
948     // The total dimension size of each kernel.
949     const int64_t filter_total_size =
950         dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size *
951         dims.spatial_dims[2].filter_size * dims.in_depth;
952     // The output image size is the spatial size of the output.
953     const int64_t output_image_size = dims.spatial_dims[0].output_size *
954                                       dims.spatial_dims[1].output_size *
955                                       dims.spatial_dims[2].output_size;
956 
957     // Shard 'batch' images (volumes) into 'shard_size' groups of images
958     // (volumes) to be fed into the parallel matmul. Calculate 'shard_size' by
959     // dividing the L3 cache size ('target_working_set_size') by the matmul size
960     // of an individual image ('work_unit_size').
961 
962     const auto cache_sizes = Eigen::internal::CacheSizes();
963     const ptrdiff_t l3_cache_size = cache_sizes.m_l3;
964 
965     // TODO(andydavis)
966     // *) Consider reducing 'target_working_set_size' if L3 is shared by
967     //    other concurrently running tensorflow ops.
968     const size_t target_working_set_size = l3_cache_size / sizeof(T);
969 
970     const int64_t size_A = output_image_size * filter_total_size;
971 
972     const int64_t size_B = output_image_size * dims.out_depth;
973 
974     const int64_t size_C = filter_total_size * dims.out_depth;
975 
976     const int64_t work_unit_size = size_A + size_B + size_C;
977 
978     OP_REQUIRES(
979         context, work_unit_size > 0,
980         errors::InvalidArgument("input, filter_sizes and out_backprop tensors "
981                                 "must all have at least 1 element"));
982 
983     const size_t shard_size =
984         (target_working_set_size + work_unit_size - 1) / work_unit_size;
985 
986     // Total number of elements in all the tensors used by this kernel.
987     int64_t total_tensor_elements = input_shape.num_elements() +
988                                     filter_shape.num_elements() +
989                                     out_backprop_shape.num_elements();
990 
991     // Shape of the temporary workspace buffer.
992     TensorShape col_buffer_shape = {static_cast<int64>(shard_size),
993                                     static_cast<int64>(output_image_size),
994                                     static_cast<int64>(filter_total_size)};
995     int64_t col_buffer_elements = col_buffer_shape.num_elements();
996 
997     // If the temporary allocation overhead is too large, fallback on Eigen
998     // implementation which requires much less memory.
999     int64_t col_buffer_overhead = col_buffer_elements / total_tensor_elements;
1000     if (col_buffer_overhead > kMaxTempAllocationOverhead) {
1001       VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropFilterOp: "
1002                  "col_buffer_overhead="
1003               << col_buffer_overhead;
1004 
1005       functor::CuboidConvolutionBackwardFilter<Device, T>()(
1006           context->eigen_device<Device>(),
1007           filter_backprop->tensor<T, 5>(),                 // filter_backward
1008           input.tensor<T, 5>(),                            // input
1009           out_backprop.tensor<T, 5>(),                     // output_backward
1010           static_cast<int>(dims.spatial_dims[0].stride),   // stride_planes
1011           static_cast<int>(dims.spatial_dims[1].stride),   // stride_rows
1012           static_cast<int>(dims.spatial_dims[2].stride));  // stride_cols
1013 
1014       return;
1015     }
1016 
1017     Tensor col_buffer;
1018     OP_REQUIRES_OK(context,
1019                    context->allocate_temp(DataTypeToEnum<T>::value,
1020                                           col_buffer_shape, &col_buffer));
1021 
1022     // The input offset corresponding to a single input image.
1023     const int64_t input_offset =
1024         dims.spatial_dims[0].input_size * dims.spatial_dims[1].input_size *
1025         dims.spatial_dims[2].input_size * dims.in_depth;
1026     // The output offset corresponding to a single output image.
1027     const int64_t output_offset =
1028         dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size *
1029         dims.spatial_dims[2].output_size * dims.out_depth;
1030 
1031     const T* input_data = input.template flat<T>().data();
1032     T* col_buffer_data = col_buffer.template flat<T>().data();
1033     const T* out_backprop_data = out_backprop.template flat<T>().data();
1034     T* filter_backprop_data = filter_backprop->template flat<T>().data();
1035 
1036     typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
1037                              Eigen::Unaligned>
1038         TensorMap;
1039     typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
1040                              Eigen::Unaligned>
1041         ConstTensorMap;
1042 
1043     TensorMap C(filter_backprop_data, filter_total_size, dims.out_depth);
1044     C.setZero();
1045 
1046     // Initialize contraction dims (we need to transpose 'A' below).
1047     Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
1048     contract_dims[0].first = 0;
1049     contract_dims[0].second = 0;
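    // Contracting index 0 of A (output positions) with index 0 of B evaluates
    // C += A^T * B, accumulating the filter gradient across shards.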
1050 
1051     auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
1052 
1053     for (int image_id = 0; image_id < dims.batch_size; image_id += shard_size) {
1054       const int shard_limit =
1055           std::min(static_cast<int>(shard_size),
1056                    static_cast<int>(dims.batch_size) - image_id);
1057 
1058       auto shard = [&input_data, &col_buffer_data, &dims, &top_pad_planes,
1059                     &top_pad_rows, &left_pad_cols, &bottom_pad_planes,
1060                     &bottom_pad_rows, &right_pad_cols, &input_offset,
1061                     &size_A](int64_t start, int64_t limit) {
1062         for (int shard_id = start; shard_id < limit; ++shard_id) {
1063           const T* input_data_shard = input_data + shard_id * input_offset;
1064           T* col_data_shard = col_buffer_data + shard_id * size_A;
1065 
1066           // When we compute the gradient with respect to the filters, we need
1067           // to do im2col to allow gemm-type computation.
1068           Im2col<T>(input_data_shard, dims.in_depth,
1069                     // Input spatial dimensions.
1070                     dims.spatial_dims[0].input_size,  // input planes
1071                     dims.spatial_dims[1].input_size,  // input rows
1072                     dims.spatial_dims[2].input_size,  // input cols
1073                     // Filter spatial dimensions.
1074                     dims.spatial_dims[0].filter_size,  // filter planes
1075                     dims.spatial_dims[1].filter_size,  // filter rows
1076                     dims.spatial_dims[2].filter_size,  // filter cols
1077                     // Spatial padding.
1078                     top_pad_planes, top_pad_rows, left_pad_cols,
1079                     bottom_pad_planes, bottom_pad_rows, right_pad_cols,
1080                     // Spatial striding.
1081                     dims.spatial_dims[0].stride,  // stride planes
1082                     dims.spatial_dims[1].stride,  // stride rows
1083                     dims.spatial_dims[2].stride,  // stride cols
1084                     col_data_shard);
1085         }
1086       };
1087       Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
1088             size_A, shard);
1089 
1090       ConstTensorMap A(col_buffer_data, output_image_size * shard_limit,
1091                        filter_total_size);
1092       ConstTensorMap B(out_backprop_data, output_image_size * shard_limit,
1093                        dims.out_depth);
1094 
1095       // Gradient with respect to filter.
1096       C.device(context->eigen_cpu_device()) += A.contract(B, contract_dims);
1097 
1098       input_data += input_offset * shard_limit;
1099       out_backprop_data += output_offset * shard_limit;
1100     }
1101   }
1102 
1103  private:
1104   std::vector<int32> dilation_;
1105   std::vector<int32> stride_;
1106   Padding padding_;
1107   TensorFormat data_format_;
1108   bool takes_shape_;
1109 
1110   TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropFilterOp);
1111 };
1112 
1113 // The custom backprop filter kernel is 30% - 4x faster when compiled with AVX2
1114 // than the default Eigen implementation (at the cost of ~2x-8x peak memory usage).
1115 
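// Both the unlabeled and the "custom"-labeled registrations below resolve to
// the im2col + GEMM kernel (Conv3DCustomBackpropFilterOp); the "eigen_tensor"
// label selects the default Eigen-based Conv3DBackpropFilterOp instead, so a
// caller can still opt into it explicitly via a kernel label.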
1116 #define REGISTER_CPU_KERNEL(T)                                                \
1117   REGISTER_KERNEL_BUILDER(                                                    \
1118       Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
1119       Conv3DCustomBackpropFilterOp<CPUDevice, T>);                            \
1120   REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")                      \
1121                               .Device(DEVICE_CPU)                             \
1122                               .TypeConstraint<T>("T"),                        \
1123                           Conv3DCustomBackpropFilterOp<CPUDevice, T>);        \
1124   REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter")                        \
1125                               .Device(DEVICE_CPU)                             \
1126                               .Label("custom")                                \
1127                               .TypeConstraint<T>("T"),                        \
1128                           Conv3DCustomBackpropFilterOp<CPUDevice, T>);        \
1129   REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")                      \
1130                               .Device(DEVICE_CPU)                             \
1131                               .Label("custom")                                \
1132                               .TypeConstraint<T>("T"),                        \
1133                           Conv3DCustomBackpropFilterOp<CPUDevice, T>);        \
1134   REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter")                        \
1135                               .Device(DEVICE_CPU)                             \
1136                               .Label("eigen_tensor")                          \
1137                               .TypeConstraint<T>("T"),                        \
1138                           Conv3DBackpropFilterOp<CPUDevice, T>);              \
1139   REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")                      \
1140                               .Device(DEVICE_CPU)                             \
1141                               .Label("eigen_tensor")                          \
1142                               .TypeConstraint<T>("T"),                        \
1143                           Conv3DBackpropFilterOp<CPUDevice, T>);
1144 
1145 TF_CALL_float(REGISTER_CPU_KERNEL);
1146 TF_CALL_double(REGISTER_CPU_KERNEL);
1147 #undef REGISTER_CPU_KERNEL
1148 
1149 // WARNING: Eigen::half is not trivially copyable and can't be used in the
1150 // custom backprop filter kernel because of memcpy and memset in Im2col.
1151 #define REGISTER_CPU_KERNEL(T)                                                \
1152   REGISTER_KERNEL_BUILDER(                                                    \
1153       Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
1154       Conv3DBackpropFilterOp<CPUDevice, T>);                                  \
1155   REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")                      \
1156                               .Device(DEVICE_CPU)                             \
1157                               .TypeConstraint<T>("T"),                        \
1158                           Conv3DBackpropFilterOp<CPUDevice, T>);
1159 
1160 TF_CALL_half(REGISTER_CPU_KERNEL);
1161 #undef REGISTER_CPU_KERNEL
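// Rough Python-level view (illustrative only; x, w, strides, padding are
// placeholders): a filter gradient such as
//   tf.gradients(tf.nn.conv3d(x, w, strides, padding), w)[0]
// lowers to Conv3DBackpropFilterV2, which on CPU dispatches to the kernels
// registered above (custom im2col + GEMM for float/double, Eigen for half).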
1162 
1163 // GPU definitions of both ops.
1164 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1165 // Forward declarations of the functor specializations for GPU.
1166 // This ensures that the custom implementation is used instead of the default
1167 // Eigen one (which is used for CPU).
1168 namespace functor {
1169 #define DECLARE_GPU_SPEC(T)                                           \
1170   template <>                                                         \
1171   void TransformFilter<GPUDevice, T, int, 5>::operator()(             \
1172       const GPUDevice& d, FilterTensorFormat dst_filter_format,       \
1173       typename TTypes<T, 5, int>::ConstTensor in,                     \
1174       typename TTypes<T, 5, int>::Tensor out);                        \
1175   template <>                                                         \
1176   void ReverseTransformFilter<GPUDevice, T, 5>::operator()(           \
1177       const GPUDevice& d, FilterTensorFormat src_filter_format,       \
1178       typename TTypes<T, 5>::ConstTensor in,                          \
1179       typename TTypes<T, 5>::Tensor out);                             \
1180   template <>                                                         \
1181   void PadInput<GPUDevice, T, int, 5>::operator()(                    \
1182       const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
1183       const std::array<int, 3>& padding_left,                         \
1184       const std::array<int, 3>& padding_right,                        \
1185       typename TTypes<T, 5, int>::Tensor out, TensorFormat format,    \
1186       const T& padding_value);
1187 
1188 DECLARE_GPU_SPEC(Eigen::half);
1189 DECLARE_GPU_SPEC(float);
1190 DECLARE_GPU_SPEC(double);
1191 #undef DECLARE_GPU_SPEC
1192 }  // namespace functor
1193 
1194 // A dummy type to group backward data autotune results together.
1195 struct Conv3dBackwardDataAutotuneGroup {
1196   static string name() { return "Conv3dBwdData"; }
1197 };
1198 
1199 typedef AutotuneSingleton<Conv3dBackwardDataAutotuneGroup, ConvParameters,
1200                           se::dnn::AlgorithmConfig>
1201     AutotuneConv3dBwdData;
1202 
1203 template <typename T>
1204 class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
1205  public:
1206   explicit Conv3DBackpropInputOp(OpKernelConstruction* context)
1207       : OpKernel(context),
1208         data_format_(FORMAT_NHWC),
1209         takes_shape_(type_string().find("V2") != std::string::npos) {
1210     // data_format is only available in V2.
1211     if (takes_shape_) {
1212       string data_format;
1213       OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
1214       OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
1215                   errors::InvalidArgument("Invalid data format"));
1216     }
1217     OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
1218     OP_REQUIRES(context, dilation_.size() == 5,
1219                 errors::InvalidArgument("Dilation rates field must "
1220                                         "specify 5 dimensions"));
1221     OP_REQUIRES(context,
1222                 (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
1223                  GetTensorDim(dilation_, data_format_, 'N') == 1),
1224                 errors::InvalidArgument(
1225                     "Current implementation does not yet support "
1226                     "dilation rates in the batch and depth dimensions."));
1227     OP_REQUIRES(
1228         context,
1229         (GetTensorDim(dilation_, data_format_, '0') > 0 &&
1230          GetTensorDim(dilation_, data_format_, '1') > 0 &&
1231          GetTensorDim(dilation_, data_format_, '2') > 0),
1232         errors::InvalidArgument("Dilated rates should be larger than 0."));
1233     OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
1234     OP_REQUIRES(context, stride_.size() == 5,
1235                 errors::InvalidArgument("Sliding window strides field must "
1236                                         "specify 5 dimensions"));
1237     OP_REQUIRES(
1238         context,
1239         (GetTensorDim(stride_, data_format_, 'C') == 1 &&
1240          GetTensorDim(stride_, data_format_, 'N') == 1),
1241         errors::InvalidArgument("Current implementation does not yet support "
1242                                 "strides in the batch and depth dimensions."));
1243     OP_REQUIRES(
1244         context,
1245         (GetTensorDim(stride_, data_format_, '0') > 0 &&
1246          GetTensorDim(stride_, data_format_, '1') > 0 &&
1247          GetTensorDim(stride_, data_format_, '2') > 0),
1248         errors::InvalidArgument("Spatial strides should be larger than 0."));
1249     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
1250     cudnn_use_autotune_ = CudnnUseAutotune();
1251   }
1252   void Compute(OpKernelContext* context) override {
1253     const Tensor& filter = context->input(1);
1254     const TensorShape& filter_shape = filter.shape();
1255 
1256     const Tensor& out_backprop = context->input(2);
1257     const TensorShape& out_backprop_shape = out_backprop.shape();
1258 
1259     TensorShape input_shape;
1260     if (takes_shape_) {
1261       const Tensor& input_sizes = context->input(0);
1262       OP_REQUIRES_OK(context, tensor::MakeShape(input_sizes, &input_shape));
1263     } else {
1264       input_shape = context->input(0).shape();
1265     }
1266 
1267     ConvBackpropDimensions dims;
1268     OP_REQUIRES_OK(context, ConvBackpropComputeDimensionsV2(
1269                                 "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
1270                                 input_shape, filter_shape, out_backprop_shape,
1271                                 dilation_, stride_, padding_,
1272                                 /*explicit_paddings=*/{}, data_format_, &dims));
1273 
1274     Tensor* in_backprop;
1275     OP_REQUIRES_OK(context,
1276                    context->allocate_output(0, input_shape, &in_backprop));
1277 
1278     auto* stream = context->op_device_context()->stream();
1279     OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
1280 
1281     bool is_grouped_convolution = filter_shape.dim_size(3) != dims.in_depth;
1282     if (!is_grouped_convolution && dims.filter_size(0) == 1 &&
1283         dims.filter_size(1) == 1 && dims.filter_size(2) == 1 &&
1284         dims.dilation(0) == 1 && dims.dilation(1) == 1 &&
1285         dims.dilation(2) == 1 && dims.stride(0) == 1 && dims.stride(1) == 1 &&
1286         dims.stride(2) == 1 && data_format_ == FORMAT_NHWC) {
1287       const uint64 m = dims.batch_size * dims.input_size(0) *
1288                        dims.input_size(1) * dims.input_size(2);
1289       const uint64 k = dims.out_depth;
1290       const uint64 n = dims.in_depth;
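      // For a 1x1x1 filter with unit strides/dilations (and no grouping) in
      // NDHWC, the convolution is a pointwise matmul over channels, so the
      // input gradient reduces to a single GEMM in row-major terms:
      //   in_backprop[m, n] = out_backprop[m, k] * filter[n, k]^T,
      // expressed below through cuBLAS' column-major interface.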
1291 
1292       auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
1293                                   out_backprop.template flat<T>().size());
1294       auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
1295                                   filter.template flat<T>().size());
1296       auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
1297                                   in_backprop->template flat<T>().size());
1298 
1299       auto transpose = se::blas::Transpose::kTranspose;
1300       auto no_transpose = se::blas::Transpose::kNoTranspose;
1301 
1302       OP_REQUIRES_OK(
1303           context, stream->ThenBlasGemm(transpose, no_transpose, n, m, k, b_ptr,
1304                                         k, a_ptr, k, &c_ptr, n));
1305       return;
1306     } else if (!is_grouped_convolution &&
1307                dims.filter_size(0) == dims.input_size(0) &&
1308                dims.filter_size(1) == dims.input_size(1) &&
1309                dims.filter_size(2) == dims.input_size(2) &&
1310                padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
1311       const uint64 m = dims.batch_size;
1312       const uint64 k = dims.out_depth;
1313       const uint64 n = dims.input_size(0) * dims.input_size(1) *
1314                        dims.input_size(2) * dims.in_depth;
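      // When the filter covers the entire input with VALID padding, each batch
      // element produces a single 1x1x1 output, so the input gradient is again
      // one GEMM: in_backprop[m, n] = out_backprop[m, k] * filter[n, k]^T with
      // m = batch, k = out_depth, n = flattened filter/input size.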
1315 
1316       auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
1317                                   out_backprop.template flat<T>().size());
1318       auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
1319                                   filter.template flat<T>().size());
1320       auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
1321                                   in_backprop->template flat<T>().size());
1322 
1323       auto transpose = se::blas::Transpose::kTranspose;
1324       auto no_transpose = se::blas::Transpose::kNoTranspose;
1325 
1326       OP_REQUIRES_OK(
1327           context, stream->ThenBlasGemm(transpose, no_transpose, n, m, k, b_ptr,
1328                                         k, a_ptr, k, &c_ptr, n));
1329       return;
1330     }
1331 
1332     int padding_planes = dims.SpatialPadding(padding_, 0);
1333     int padding_rows = dims.SpatialPadding(padding_, 1);
1334     int padding_cols = dims.SpatialPadding(padding_, 2);
1335     const bool planes_odd = (padding_planes % 2 != 0);
1336     const bool rows_odd = (padding_rows % 2 != 0);
1337     const bool cols_odd = (padding_cols % 2 != 0);
1338 
1339     TensorShape compatible_input_shape;
1340     if (rows_odd || cols_odd || planes_odd) {
1341       // cuDNN only supports the same amount of padding on both sides.
1342       compatible_input_shape = {
1343           dims.batch_size,
1344           dims.in_depth,
1345           dims.input_size(0) + planes_odd,
1346           dims.input_size(1) + rows_odd,
1347           dims.input_size(2) + cols_odd,
1348       };
1349     } else {
1350       compatible_input_shape = {dims.batch_size, dims.in_depth,
1351                                 dims.input_size(0), dims.input_size(1),
1352                                 dims.input_size(2)};
1353     }
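    // For SAME padding with an odd total pad (cuDNN needs symmetric padding,
    // see above), the backward-data convolution is computed as if the input
    // were one plane/row/column larger; the extra slice is trimmed off again
    // by the negative-padding PadInput call near the end of Compute().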
1354 
1355     CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
1356         << "Negative paddings: (" << padding_rows << ", " << padding_cols
1357         << ", " << padding_planes << ")";
1358 
1359 #if GOOGLE_CUDA
1360     const bool compute_in_nhwc =
1361         CUDNN_VERSION >= 8000 && DataTypeToEnum<T>::value == DT_HALF;
1362 #else
1363     // The fast NDHWC implementation is a CUDA-only feature.
1364     const bool compute_in_nhwc = false;
1365 #endif
1366     const TensorFormat compute_data_format =
1367         (compute_in_nhwc && data_format_ == FORMAT_NHWC) ? FORMAT_NHWC
1368                                                          : FORMAT_NCHW;
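    // With cuDNN >= 8 and half-precision data we keep the NDHWC layout for the
    // cuDNN call (which tends to be faster for fp16 on recent GPUs); otherwise
    // out_backprop and the filter are transposed to NCDHW below and the result
    // is transposed back to NDHWC at the end.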
1369 
1370     VLOG(3) << "Compute Conv3DBackpropInput with cuDNN:"
1371             << " data_format=" << ToString(data_format_)
1372             << " compute_data_format=" << ToString(compute_data_format);
1373 
1374     constexpr auto kComputeInNHWC =
1375         std::make_tuple(se::dnn::DataLayout::kBatchYXDepth,
1376                         se::dnn::FilterLayout::kOutputYXInput);
1377     constexpr auto kComputeInNCHW =
1378         std::make_tuple(se::dnn::DataLayout::kBatchDepthYX,
1379                         se::dnn::FilterLayout::kOutputInputYX);
1380 
1381     se::dnn::DataLayout compute_data_layout;
1382     se::dnn::FilterLayout filter_layout;
1383 
1384     std::tie(compute_data_layout, filter_layout) =
1385         compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW;
1386 
1387     se::dnn::BatchDescriptor input_desc(3);
1388     input_desc.set_count(dims.batch_size)
1389         .set_spatial_dim(DimIndex::X, compatible_input_shape.dim_size(4))
1390         .set_spatial_dim(DimIndex::Y, compatible_input_shape.dim_size(3))
1391         .set_spatial_dim(DimIndex::Z, compatible_input_shape.dim_size(2))
1392         .set_feature_map_count(dims.in_depth)
1393         .set_layout(compute_data_layout);
1394     se::dnn::BatchDescriptor output_desc(3);
1395     output_desc.set_count(dims.batch_size)
1396         .set_spatial_dim(DimIndex::X, dims.output_size(2))
1397         .set_spatial_dim(DimIndex::Y, dims.output_size(1))
1398         .set_spatial_dim(DimIndex::Z, dims.output_size(0))
1399         .set_feature_map_count(dims.out_depth)
1400         .set_layout(compute_data_layout);
1401     se::dnn::FilterDescriptor filter_desc(3);
1402     filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
1403         .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
1404         .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
1405         .set_input_feature_map_count(filter_shape.dim_size(3))
1406         .set_output_feature_map_count(filter_shape.dim_size(4))
1407         .set_layout(filter_layout);
1408     se::dnn::ConvolutionDescriptor conv_desc(3);
1409     conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
1410         .set_dilation_rate(DimIndex::Y, dims.dilation(1))
1411         .set_dilation_rate(DimIndex::Z, dims.dilation(0))
1412         .set_filter_stride(DimIndex::X, dims.stride(2))
1413         .set_filter_stride(DimIndex::Y, dims.stride(1))
1414         .set_filter_stride(DimIndex::Z, dims.stride(0))
1415         .set_zero_padding(DimIndex::X, padding_cols / 2)
1416         .set_zero_padding(DimIndex::Y, padding_rows / 2)
1417         .set_zero_padding(DimIndex::Z, padding_planes / 2)
1418         .set_group_count(dims.in_depth / filter_shape.dim_size(3));
1419 
1420     // Shape: out, in, z, y, x.
1421     Tensor transformed_filter;
1422     auto dst_format =
1423         compute_data_format == FORMAT_NCHW ? FORMAT_OIHW : FORMAT_OHWI;
1424     TensorShape dst_shape =
1425         dst_format == FORMAT_OIHW
1426             ? TensorShape({filter_shape.dim_size(4), filter_shape.dim_size(3),
1427                            dims.filter_size(0), dims.filter_size(1),
1428                            dims.filter_size(2)})
1429             : TensorShape({filter_shape.dim_size(4), dims.filter_size(0),
1430                            dims.filter_size(1), dims.filter_size(2),
1431                            filter_shape.dim_size(3)});
1432     OP_REQUIRES_OK(context,
1433                    context->allocate_temp(DataTypeToEnum<T>::value, dst_shape,
1434                                           &transformed_filter));
1435 
1436     functor::TransformFilter<GPUDevice, T, int, 5>()(
1437         context->eigen_device<GPUDevice>(), dst_format,
1438         To32Bit(filter.tensor<T, 5>()),
1439         To32Bit(transformed_filter.tensor<T, 5>()));
1440 
1441     // Shape: batch, filters, z, y, x.
1442     Tensor transformed_out_backprop;
1443     if (data_format_ == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
1444       TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
1445                                 dims.output_size(0), dims.output_size(1),
1446                                 dims.output_size(2)};
1447       if (dims.out_depth > 1) {
1448         OP_REQUIRES_OK(context, context->allocate_temp(
1449                                     DataTypeToEnum<T>::value, nchw_shape,
1450                                     &transformed_out_backprop));
1451         functor::NHWCToNCHW<GPUDevice, T, 5>()(
1452             context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(),
1453             transformed_out_backprop.tensor<T, 5>());
1454       } else {
1455         CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
1456       }
1457     } else {
1458       transformed_out_backprop = out_backprop;
1459     }
1460     // Shape: batch, filters, z, y, x.
1461     Tensor pre_transformed_in_backprop;
1462     OP_REQUIRES_OK(context,
1463                    context->allocate_temp(
1464                        DataTypeToEnum<T>::value,
1465                        ShapeFromFormat(compute_data_format,
1466                                        compatible_input_shape.dim_size(0),
1467                                        {{compatible_input_shape.dim_size(2),
1468                                          compatible_input_shape.dim_size(3),
1469                                          compatible_input_shape.dim_size(4)}},
1470                                        compatible_input_shape.dim_size(1)),
1471                        &pre_transformed_in_backprop));
1472 
1473     auto out_backprop_ptr =
1474         AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
1475                        transformed_out_backprop.template flat<T>().size());
1476     auto filter_ptr =
1477         AsDeviceMemory(transformed_filter.template flat<T>().data(),
1478                        transformed_filter.template flat<T>().size());
1479     auto in_backprop_ptr =
1480         AsDeviceMemory(pre_transformed_in_backprop.template flat<T>().data(),
1481                        pre_transformed_in_backprop.template flat<T>().size());
1482 
1483     static int64_t ConvolveBackwardDataScratchSize = GetDnnWorkspaceLimit(
1484         "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32);  // 4GB by default
1485 
1486     const int device_id = stream->parent()->device_ordinal();
1487     // To make sure Conv3DBackpropInputV2 gets the correct dtype, we infer
1488     // the dtype from the 2nd input, i.e., out_backprop.
1489     DataType dtype = context->input(2).dtype();
1490     const ConvParameters conv_parameters = {
1491         dims.batch_size,
1492         dims.in_depth,
1493         {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
1494         compute_data_format,
1495         dims.out_depth,
1496         {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
1497         {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
1498         {{dims.stride(0), dims.stride(1), dims.stride(2)}},
1499         {{padding_planes, padding_rows, padding_cols}},
1500         dtype,
1501         device_id,
1502         conv_desc.group_count()};
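    // conv_parameters is the cache key for AutotuneConv3dBwdData: calls with
    // identical shapes, strides, dilations, padding, dtype and device reuse the
    // previously selected AlgorithmConfig instead of re-profiling.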
1503 
1504     using se::dnn::AlgorithmConfig;
1505     using se::dnn::AlgorithmDesc;
1506     using se::dnn::ProfileResult;
1507 #if TENSORFLOW_USE_ROCM
1508     // cudnn_use_autotune is applicable only to the CUDA flow;
1509     // for ROCm/MIOpen, we need to call GetMIOpenConvolveAlgorithms explicitly
1510     // if we do not have a cached algorithm_config for this conv_parameters.
1511     cudnn_use_autotune_ = true;
1512 #endif
1513     AlgorithmConfig algorithm_config;
1514 
1515     if (cudnn_use_autotune_ && !AutotuneConv3dBwdData::GetInstance()->Find(
1516                                    conv_parameters, &algorithm_config)) {
1517       profiler::ScopedAnnotation trace("cudnn_autotuning");
1518       std::vector<std::unique_ptr<se::dnn::ConvolveExecutionPlan>> plans;
1519 #if GOOGLE_CUDA
1520       std::vector<AlgorithmDesc> algorithms;
1521       std::vector<AlgorithmConfig> configs;
1522       if (CudnnUseFrontend()) {
1523         OP_REQUIRES(context,
1524                     stream->parent()->GetConvolveExecutionPlans(
1525                         se::dnn::ConvolutionKind::BACKWARD_DATA,
1526                         se::dnn::ToDataType<T>::value, stream, input_desc,
1527                         filter_desc, output_desc, conv_desc, &plans),
1528                     errors::Unknown(
1529                         "Failed to get convolution execution plan. This is "
1530                         "probably because cuDNN failed to initialize, so try "
1531                         "looking to see if a warning log message was printed "
1532                         "above."));
1533         for (const auto& plan : plans) {
1534           configs.push_back(AlgorithmConfig(
1535               AlgorithmDesc{plan->getTag(), plan->get_raw_desc()},
1536               plan->getWorkspaceSize()));
1537         }
1538       } else {
1539         OP_REQUIRES(
1540             context,
1541             stream->parent()->GetConvolveBackwardDataAlgorithms(&algorithms),
1542             errors::Unknown(
1543                 "Failed to get convolution execution plan. This is probably "
1544                 "because cuDNN failed to initialize, so try looking to see if "
1545                 "a warning log message was printed above."));
1546         for (const auto& algorithm : algorithms) {
1547           configs.push_back(AlgorithmConfig(algorithm));
1548         }
1549       }
1550 
1551       se::TfAllocatorAdapter tf_allocator_adapter(
1552           context->device()->GetAllocator({}), stream);
1553       se::RedzoneAllocator rz_allocator(stream, &tf_allocator_adapter,
1554                                         se::GpuAsmOpts());
1555       se::DeviceMemory<T> in_backprop_ptr_rz(
1556           WrapRedzoneBestEffort(&rz_allocator, in_backprop_ptr));
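      // The redzone allocators pad the output and scratch buffers with guard
      // bytes so that CheckRedzones() below can detect candidate algorithms
      // that write out of bounds while being profiled.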
1557 
1558       std::vector<tensorflow::AutotuneResult> results;
1559       for (auto& profile_config : configs) {
1560         // TODO(zhengxq): profile each algorithm multiple times for better
1561         // accuracy.
1562         DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
1563                                               context);
1564         se::RedzoneAllocator rz_scratch_allocator(
1565             stream, &tf_allocator_adapter, se::GpuAsmOpts(),
1566             /*memory_limit=*/ConvolveBackwardDataScratchSize);
1567         se::ScratchAllocator* allocator_used =
1568             !RedzoneCheckDisabled()
1569                 ? static_cast<se::ScratchAllocator*>(&rz_scratch_allocator)
1570                 : static_cast<se::ScratchAllocator*>(&scratch_allocator);
1571         ProfileResult profile_result;
1572 
1573         Status cudnn_launch_status;
1574         if (CudnnUseFrontend()) {
1575           cudnn_launch_status = stream->ConvolveBackwardDataWithExecutionPlan(
1576               filter_desc, filter_ptr, output_desc, out_backprop_ptr, conv_desc,
1577               input_desc, &in_backprop_ptr_rz, allocator_used, profile_config,
1578               &profile_result);
1579         } else {
1580           cudnn_launch_status = stream->ConvolveBackwardDataWithAlgorithm(
1581               filter_desc, filter_ptr, output_desc, out_backprop_ptr, conv_desc,
1582               input_desc, &in_backprop_ptr_rz, allocator_used, profile_config,
1583               &profile_result);
1584         }
1585 
1586         if (cudnn_launch_status.ok() && profile_result.is_valid()) {
1587           results.emplace_back();
1588           auto& result = results.back();
1589           if (CudnnUseFrontend()) {
1590             result.mutable_cuda_conv_plan()->set_exec_plan_id(
1591                 profile_config.algorithm()->exec_plan_id());
1592           } else {
1593             result.mutable_conv()->set_algorithm(
1594                 profile_config.algorithm()->algo_id());
1595             result.mutable_conv()->set_tensor_ops_enabled(
1596                 profile_config.algorithm()->tensor_ops_enabled());
1597           }
1598 
1599           result.set_scratch_bytes(
1600               !RedzoneCheckDisabled()
1601                   ? rz_scratch_allocator.TotalAllocatedBytesExcludingRedzones()
1602                   : scratch_allocator.TotalByteSize());
1603           *result.mutable_run_time() = proto_utils::ToDurationProto(
1604               absl::Milliseconds(profile_result.elapsed_time_in_ms()));
1605 
1606           // TODO(george): they don't do results at all??
1607           CheckRedzones(rz_scratch_allocator, &result);
1608           CheckRedzones(rz_allocator, &result);
1609         } else {
1610           // When cuDNN frontend APIs are used, we need to make sure the
1611           // profiling results are a one-to-one mapping of the "plans". So, we
1612           // insert dummy results when the execution fails.
1613           results.emplace_back();
1614           auto& result = results.back();
1615           result.mutable_failure()->set_kind(AutotuneResult::UNKNOWN);
1616           result.mutable_failure()->set_msg(
1617               absl::StrCat("Profiling failure on CUDNN engine: ",
1618                            profile_config.algorithm()->exec_plan_id()));
1619         }
1620       }
1621 #elif TENSORFLOW_USE_ROCM
1622       DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
1623                                             context);
1624       std::vector<ProfileResult> algorithms;
1625       CHECK(stream->parent()->GetMIOpenConvolveAlgorithms(
1626           se::dnn::ConvolutionKind::BACKWARD_DATA,
1627           se::dnn::ToDataType<T>::value, stream, input_desc, in_backprop_ptr,
1628           filter_desc, filter_ptr, output_desc, out_backprop_ptr, conv_desc,
1629           &scratch_allocator, &algorithms));
1630       std::vector<tensorflow::AutotuneResult> results;
1631       for (auto miopen_algorithm : algorithms) {
1632         auto profile_algorithm = miopen_algorithm.algorithm();
1633         ProfileResult profile_result;
1634         auto miopen_launch_status = stream->ConvolveBackwardDataWithAlgorithm(
1635             filter_desc, filter_ptr, output_desc, out_backprop_ptr, conv_desc,
1636             input_desc, &in_backprop_ptr, &scratch_allocator,
1637             AlgorithmConfig(profile_algorithm, miopen_algorithm.scratch_size()),
1638             &profile_result);
1639         if (miopen_launch_status.ok()) {
1640           if (profile_result.is_valid()) {
1641             results.emplace_back();
1642             auto& result = results.back();
1643             result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
1644             result.mutable_conv()->set_tensor_ops_enabled(
1645                 profile_algorithm.tensor_ops_enabled());
1646             result.set_scratch_bytes(scratch_allocator.TotalByteSize());
1647             *result.mutable_run_time() = proto_utils::ToDurationProto(
1648                 absl::Milliseconds(profile_result.elapsed_time_in_ms()));
1649           }
1650         }
1651       }
1652 #endif
1653       LogConvAutotuneResults(se::dnn::ConvolutionKind::BACKWARD_DATA,
1654                              se::dnn::ToDataType<T>::value, in_backprop_ptr,
1655                              filter_ptr, out_backprop_ptr, input_desc,
1656                              filter_desc, output_desc, conv_desc,
1657                              stream->parent(), results);
1658       if (CudnnUseFrontend()) {
1659         OP_REQUIRES_OK(context, BestCudnnConvAlgorithm(results, &plans,
1660                                                        &algorithm_config));
1661       } else {
1662         OP_REQUIRES_OK(context, BestCudnnConvAlgorithm(results, nullptr,
1663                                                        &algorithm_config));
1664       }
1665       AutotuneConv3dBwdData::GetInstance()->Insert(conv_parameters,
1666                                                    algorithm_config);
1667     }
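    // At this point algorithm_config holds either the freshly autotuned choice
    // (now cached under conv_parameters) or the previously cached one; the
    // actual backward-data convolution is launched with it below.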
1668 
1669     Status cudnn_launch_status;
1670     DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
1671                                           context);
1672     if (CudnnUseFrontend()) {
1673       if (algorithm_config.algorithm().has_value()) {
1674         VLOG(4) << "Conv3DBackpropInput Execution Plan: "
1675                 << algorithm_config.algorithm()->exec_plan_id();
1676       } else {
1677         VLOG(4) << "Convolution Autotune has been turned off";
1678       }
1679       cudnn_launch_status = stream->ConvolveBackwardDataWithExecutionPlan(
1680           filter_desc, filter_ptr, output_desc, out_backprop_ptr, conv_desc,
1681           input_desc, &in_backprop_ptr, &scratch_allocator, algorithm_config,
1682           nullptr);
1683     } else {
1684       cudnn_launch_status = stream->ConvolveBackwardDataWithAlgorithm(
1685           filter_desc, filter_ptr, output_desc, out_backprop_ptr, conv_desc,
1686           input_desc, &in_backprop_ptr, &scratch_allocator, algorithm_config,
1687           nullptr);
1688     }
1689 
1690     if (!cudnn_launch_status.ok()) {
1691       context->SetStatus(cudnn_launch_status);
1692     }
1693 
1694     if (rows_odd || cols_odd || planes_odd) {
1695       Tensor in_backprop_remove_padding;
1696       OP_REQUIRES_OK(
1697           context, context->allocate_temp(
1698                        DataTypeToEnum<T>::value,
1699                        ShapeFromFormat(compute_data_format, dims.batch_size,
1700                                        {{dims.input_size(0), dims.input_size(1),
1701                                          dims.input_size(2)}},
1702                                        dims.in_depth),
1703                        &in_backprop_remove_padding));
1704 
1705       // Remove the padding for odd spatial dimensions.
1706       functor::PadInput<GPUDevice, T, int, 5>()(
1707           context->eigen_device<GPUDevice>(),
1708           To32Bit(const_cast<const Tensor&>(pre_transformed_in_backprop)
1709                       .tensor<T, 5>()),
1710           {{0, 0, 0}}, {{-planes_odd, -rows_odd, -cols_odd}},
1711           To32Bit(in_backprop_remove_padding.tensor<T, 5>()),
1712           compute_data_format, T{});
1713 
1714       pre_transformed_in_backprop = in_backprop_remove_padding;
1715     }
1716 
1717     if (data_format_ == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
1718       auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
1719       functor::NCHWToNHWC<GPUDevice, T, 5>()(
1720           context->eigen_device<GPUDevice>(),
1721           toConstTensor(pre_transformed_in_backprop).template tensor<T, 5>(),
1722           in_backprop->tensor<T, 5>());
1723     } else {
1724       *in_backprop = pre_transformed_in_backprop;
1725     }
1726   }
1727 
1728  private:
1729   std::vector<int32> dilation_;
1730   std::vector<int32> stride_;
1731   Padding padding_;
1732   TensorFormat data_format_;
1733   bool takes_shape_;
1734   bool cudnn_use_autotune_;
1735 };
1736 
1737 // A dummy type to group backward filter autotune results together.
1738 struct Conv3dBackwardFilterAutotuneGroup {
1739   static string name() { return "Conv3dBwdFilter"; }
1740 };
1741 
1742 typedef AutotuneSingleton<Conv3dBackwardFilterAutotuneGroup, ConvParameters,
1743                           se::dnn::AlgorithmConfig>
1744     AutotuneConv3dBwdFilter;
1745 
1746 template <typename T>
1747 class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
1748  public:
1749   explicit Conv3DBackpropFilterOp(OpKernelConstruction* context)
1750       : OpKernel(context),
1751         data_format_(FORMAT_NHWC),
1752         takes_shape_(type_string().find("V2") != std::string::npos) {
1753     // data_format is only available in V2.
1754     if (takes_shape_) {
1755       string data_format;
1756       OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
1757       OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
1758                   errors::InvalidArgument("Invalid data format"));
1759     }
1760     OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
1761     OP_REQUIRES(context, dilation_.size() == 5,
1762                 errors::InvalidArgument("Dilation rates field must "
1763                                         "specify 5 dimensions"));
1764     OP_REQUIRES(context,
1765                 (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
1766                  GetTensorDim(dilation_, data_format_, 'N') == 1),
1767                 errors::InvalidArgument(
1768                     "Current implementation does not yet support "
1769                     "dilation rates in the batch and depth dimensions."));
1770     OP_REQUIRES(
1771         context,
1772         (GetTensorDim(dilation_, data_format_, '0') > 0 &&
1773          GetTensorDim(dilation_, data_format_, '1') > 0 &&
1774          GetTensorDim(dilation_, data_format_, '2') > 0),
1775         errors::InvalidArgument("Dilated rates should be larger than 0."));
1776     OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
1777     OP_REQUIRES(context, stride_.size() == 5,
1778                 errors::InvalidArgument("Sliding window strides field must "
1779                                         "specify 5 dimensions"));
1780     OP_REQUIRES(
1781         context,
1782         (GetTensorDim(stride_, data_format_, 'C') == 1 &&
1783          GetTensorDim(stride_, data_format_, 'N') == 1),
1784         errors::InvalidArgument("Current implementation does not yet support "
1785                                 "strides in the batch and depth dimensions."));
1786     OP_REQUIRES(
1787         context,
1788         (GetTensorDim(stride_, data_format_, '0') > 0 &&
1789          GetTensorDim(stride_, data_format_, '1') > 0 &&
1790          GetTensorDim(stride_, data_format_, '2') > 0),
1791         errors::InvalidArgument("Spatial strides should be larger than 0."));
1792     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
1793     cudnn_use_autotune_ = CudnnUseAutotune();
1794   }
1795 
1796   void Compute(OpKernelContext* context) override {
1797     const Tensor& input = context->input(0);
1798     const TensorShape& input_shape = input.shape();
1799 
1800     const Tensor& out_backprop = context->input(2);
1801     const TensorShape& out_backprop_shape = out_backprop.shape();
1802 
1803     TensorShape filter_shape;
1804     if (takes_shape_) {
1805       const Tensor& filter_sizes = context->input(1);
1806       OP_REQUIRES_OK(context, tensor::MakeShape(filter_sizes, &filter_shape));
1807     } else {
1808       filter_shape = context->input(1).shape();
1809     }
1810 
1811     ConvBackpropDimensions dims;
1812     OP_REQUIRES_OK(
1813         context,
1814         ConvBackpropComputeDimensionsV2(
1815             "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3, input_shape,
1816             filter_shape, out_backprop_shape, dilation_, stride_, padding_,
1817             /*explicit_paddings=*/{}, data_format_, &dims));
1818 
1819     Tensor* filter_backprop;
1820     OP_REQUIRES_OK(context,
1821                    context->allocate_output(0, filter_shape, &filter_backprop));
1822 
1823     auto* stream = context->op_device_context()->stream();
1824     OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
1825 
1826     bool is_grouped_convolution = filter_shape.dim_size(3) != dims.in_depth;
1827     if (!is_grouped_convolution && dims.filter_size(1) == 1 &&
1828         dims.filter_size(2) == 1 && dims.filter_size(0) == 1 &&
1829         dims.dilation(2) == 1 && dims.dilation(1) == 1 &&
1830         dims.dilation(0) == 1 && dims.stride(2) == 1 && dims.stride(1) == 1 &&
1831         dims.stride(0) == 1 && data_format_ == FORMAT_NHWC) {
1832       const uint64 m = dims.in_depth;
1833       const uint64 k = dims.batch_size * dims.input_size(1) *
1834                        dims.input_size(2) * dims.input_size(0);
1835       const uint64 n = dims.out_depth;
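      // In row-major terms this computes
      //   filter_backprop[m, n] = input[k, m]^T * out_backprop[k, n],
      // i.e. the per-channel outer products summed over the batch and all
      // spatial positions; the cuBLAS call below is the column-major form.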
1836 
1837       // The shape of output backprop is
1838       //   [batch, out_z, out_y, out_x, out_depth]
1839       // From cublas's perspective, it is: n x k
1840       auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
1841                                   out_backprop.template flat<T>().size());
1842 
1843       // The shape of input is:
1844       //   [batch, in_z, in_y, in_x, in_depth],
1845       // From cublas's perspective, it is: m x k
1846       auto b_ptr = AsDeviceMemory(input.template flat<T>().data(),
1847                                   input.template flat<T>().size());
1848 
1849       // The shape of the filter backprop is:
1850       //   [1, 1, 1, in_depth, out_depth]
1851       // From cublas's perspective, it is: n x m
1852       auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
1853                                   filter_backprop->template flat<T>().size());
1854 
1855       OP_REQUIRES_OK(context,
1856                      stream->ThenBlasGemm(se::blas::Transpose::kNoTranspose,
1857                                           se::blas::Transpose::kTranspose, n, m,
1858                                           k, a_ptr, n, b_ptr, m, &c_ptr, n));
1859       return;
1860     } else if (!is_grouped_convolution &&
1861                dims.filter_size(0) == dims.input_size(0) &&
1862                dims.filter_size(1) == dims.input_size(1) &&
1863                dims.filter_size(2) == dims.input_size(2) &&
1864                padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
1865       const uint64 m = dims.input_size(0) * dims.input_size(1) *
1866                        dims.input_size(2) * dims.in_depth;
1867       const uint64 k = dims.batch_size;
1868       const uint64 n = dims.out_depth;
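      // The filter spans the whole input with VALID padding: each batch element
      // contributes one outer product, so filter_backprop[m, n] =
      // input[k, m]^T * out_backprop[k, n] with k = batch, m = flattened
      // input/filter size, n = out_depth.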
1869 
1870       auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
1871                                   input.template flat<T>().size());
1872       auto b_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
1873                                   out_backprop.template flat<T>().size());
1874       auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
1875                                   filter_backprop->template flat<T>().size());
1876 
1877       OP_REQUIRES_OK(context,
1878                      stream->ThenBlasGemm(se::blas::Transpose::kNoTranspose,
1879                                           se::blas::Transpose::kTranspose, n, m,
1880                                           k, b_ptr, n, a_ptr, m, &c_ptr, n));
1881       return;
1882     }
1883 
1884     int padding_planes = dims.SpatialPadding(padding_, 0);
1885     int padding_rows = dims.SpatialPadding(padding_, 1);
1886     int padding_cols = dims.SpatialPadding(padding_, 2);
1887     const bool planes_odd = (padding_planes % 2 != 0);
1888     const bool rows_odd = (padding_rows % 2 != 0);
1889     const bool cols_odd = (padding_cols % 2 != 0);
1890 
1891     Tensor compatible_input;
1892     if (rows_odd || cols_odd || planes_odd) {
1893       OP_REQUIRES_OK(context,
1894                      context->allocate_temp(
1895                          DataTypeToEnum<T>::value,
1896                          ShapeFromFormat(data_format_, dims.batch_size,
1897                                          {{dims.input_size(0) + planes_odd,
1898                                            dims.input_size(1) + rows_odd,
1899                                            dims.input_size(2) + cols_odd}},
1900                                          dims.in_depth),
1901                          &compatible_input));
1902       functor::PadInput<GPUDevice, T, int, 5>()(
1903           context->template eigen_device<GPUDevice>(),
1904           To32Bit(input.tensor<T, 5>()), {{0, 0, 0}},
1905           {{planes_odd, rows_odd, cols_odd}},
1906           To32Bit(compatible_input.tensor<T, 5>()), data_format_, T{});
1907     } else {
1908       compatible_input = input;
1909     }
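    // Unlike the backward-data kernel above, here the input itself is
    // zero-padded by one plane/row/column (PadInput with positive padding
    // above) so that cuDNN sees symmetric padding; no cropping of the result
    // is needed afterwards.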
1910 
1911     CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
1912         << "Negative paddings: (" << padding_rows << ", " << padding_cols
1913         << ", " << padding_planes << ")";
1914 
1915 #if GOOGLE_CUDA
1916     const bool compute_in_nhwc =
1917         CUDNN_VERSION >= 8000 && DataTypeToEnum<T>::value == DT_HALF;
1918 #else
1919     // The fast NDHWC implementation is a CUDA-only feature.
1920     const bool compute_in_nhwc = false;
1921 #endif
1922     const TensorFormat compute_data_format =
1923         (compute_in_nhwc && data_format_ == FORMAT_NHWC) ? FORMAT_NHWC
1924                                                          : FORMAT_NCHW;
1925 
1926     VLOG(3) << "Compute Conv3DBackpropFilter with cuDNN:"
1927             << " data_format=" << ToString(data_format_)
1928             << " compute_data_format=" << ToString(compute_data_format);
1929 
1930     constexpr auto kComputeInNHWC =
1931         std::make_tuple(se::dnn::DataLayout::kBatchYXDepth,
1932                         se::dnn::FilterLayout::kOutputYXInput);
1933     constexpr auto kComputeInNCHW =
1934         std::make_tuple(se::dnn::DataLayout::kBatchDepthYX,
1935                         se::dnn::FilterLayout::kOutputInputYX);
1936 
1937     se::dnn::DataLayout compute_data_layout;
1938     se::dnn::FilterLayout filter_layout;
1939 
1940     std::tie(compute_data_layout, filter_layout) =
1941         compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW;
1942 
1943     se::dnn::BatchDescriptor input_desc(3);
1944     input_desc.set_count(dims.batch_size)
1945         .set_spatial_dim(DimIndex::X,
1946                          GetTensorDim(compatible_input, data_format_, '2'))
1947         .set_spatial_dim(DimIndex::Y,
1948                          GetTensorDim(compatible_input, data_format_, '1'))
1949         .set_spatial_dim(DimIndex::Z,
1950                          GetTensorDim(compatible_input, data_format_, '0'))
1951         .set_feature_map_count(dims.in_depth)
1952         .set_layout(compute_data_layout);
1953     se::dnn::BatchDescriptor output_desc(3);
1954     output_desc.set_count(dims.batch_size)
1955         .set_spatial_dim(DimIndex::X, dims.output_size(2))
1956         .set_spatial_dim(DimIndex::Y, dims.output_size(1))
1957         .set_spatial_dim(DimIndex::Z, dims.output_size(0))
1958         .set_feature_map_count(dims.out_depth)
1959         .set_layout(compute_data_layout);
1960     se::dnn::FilterDescriptor filter_desc(3);
1961     filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
1962         .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
1963         .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
1964         .set_input_feature_map_count(filter_shape.dim_size(3))
1965         .set_output_feature_map_count(filter_shape.dim_size(4))
1966         .set_layout(filter_layout);
1967     se::dnn::ConvolutionDescriptor conv_desc(3);
1968     conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
1969         .set_dilation_rate(DimIndex::Y, dims.dilation(1))
1970         .set_dilation_rate(DimIndex::Z, dims.dilation(0))
1971         .set_filter_stride(DimIndex::X, dims.stride(2))
1972         .set_filter_stride(DimIndex::Y, dims.stride(1))
1973         .set_filter_stride(DimIndex::Z, dims.stride(0))
1974         .set_zero_padding(DimIndex::X, padding_cols / 2)
1975         .set_zero_padding(DimIndex::Y, padding_rows / 2)
1976         .set_zero_padding(DimIndex::Z, padding_planes / 2)
1977         .set_group_count(dims.in_depth / filter_shape.dim_size(3));
1978 
1979     Tensor pre_transformed_filter_backprop;
1980     auto dst_format =
1981         compute_data_format == FORMAT_NCHW ? FORMAT_OIHW : FORMAT_OHWI;
1982     TensorShape dst_shape =
1983         dst_format == FORMAT_OIHW
1984             ? TensorShape({filter_shape.dim_size(4), filter_shape.dim_size(3),
1985                            dims.filter_size(0), dims.filter_size(1),
1986                            dims.filter_size(2)})
1987             : TensorShape({filter_shape.dim_size(4), dims.filter_size(0),
1988                            dims.filter_size(1), dims.filter_size(2),
1989                            filter_shape.dim_size(3)});
1990     OP_REQUIRES_OK(context,
1991                    context->allocate_temp(DataTypeToEnum<T>::value, dst_shape,
1992                                           &pre_transformed_filter_backprop));
1993 
1994     Tensor transformed_out_backprop;
1995     if (data_format_ == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
1996       VLOG(4) << "Convert the `out_backprop` tensor from NDHWC to NCDHW.";
1997       TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
1998                                 dims.output_size(0), dims.output_size(1),
1999                                 dims.output_size(2)};
2000       OP_REQUIRES_OK(
2001           context, context->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
2002                                           &transformed_out_backprop));
2003       if (dims.out_depth > 1) {
2004         functor::NHWCToNCHW<GPUDevice, T, 5>()(
2005             context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(),
2006             transformed_out_backprop.tensor<T, 5>());
2007       } else {
2008         CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
2009       }
2010     } else {
2011       transformed_out_backprop = out_backprop;
2012     }
2013     Tensor transformed_input;
2014     if (data_format_ == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
2015       VLOG(4) << "Convert the `input` tensor from NDHWC to NCDHW.";
2016       TensorShape nchw_shape = {
2017           dims.batch_size, dims.in_depth, compatible_input.dim_size(1),
2018           compatible_input.dim_size(2), compatible_input.dim_size(3)};
2019       if (dims.in_depth > 1) {
2020         OP_REQUIRES_OK(context,
2021                        context->allocate_temp(DataTypeToEnum<T>::value,
2022                                               nchw_shape, &transformed_input));
2023         functor::NHWCToNCHW<GPUDevice, T, 5>()(
2024             context->eigen_device<GPUDevice>(),
2025             const_cast<const Tensor&>(compatible_input).tensor<T, 5>(),
2026             transformed_input.tensor<T, 5>());
2027       } else {
2028         CHECK(transformed_input.CopyFrom(compatible_input, nchw_shape));
2029       }
2030     } else {
2031       transformed_input = compatible_input;
2032     }
2033 
2034     auto out_backprop_ptr =
2035         AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
2036                        transformed_out_backprop.template flat<T>().size());
2037     auto filter_backprop_ptr = AsDeviceMemory(
2038         pre_transformed_filter_backprop.template flat<T>().data(),
2039         pre_transformed_filter_backprop.template flat<T>().size());
2040     auto input_ptr =
2041         AsDeviceMemory(transformed_input.template flat<T>().data(),
2042                        transformed_input.template flat<T>().size());
2043 
2044     static int64_t ConvolveBackwardFilterScratchSize = GetDnnWorkspaceLimit(
2045         "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32);  // 4GB by default
2046 
2047     const int device_id = stream->parent()->device_ordinal();
2048     DataType dtype = input.dtype();
2049     const ConvParameters conv_parameters = {
2050         dims.batch_size,
2051         dims.in_depth,
2052         {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
2053         compute_data_format,
2054         dims.out_depth,
2055         {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
2056         {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
2057         {{dims.stride(0), dims.stride(1), dims.stride(2)}},
2058         {{padding_planes, padding_rows, padding_cols}},
2059         dtype,
2060         device_id,
2061         conv_desc.group_count()};
2062 
2063     using se::dnn::AlgorithmConfig;
2064     using se::dnn::AlgorithmDesc;
2065     using se::dnn::ProfileResult;
2066 #if TENSORFLOW_USE_ROCM
2067     // cudnn_use_autotune is applicable only to the CUDA flow;
2068     // for ROCm/MIOpen, we need to call GetMIOpenConvolveAlgorithms explicitly
2069     // if we do not have a cached algorithm_config for this conv_parameters.
2070     cudnn_use_autotune_ = true;
2071 #endif
2072 
2073     AlgorithmConfig algorithm_config;
2074 
2075     if (cudnn_use_autotune_ && !AutotuneConv3dBwdFilter::GetInstance()->Find(
2076                                    conv_parameters, &algorithm_config)) {
2077       std::vector<std::unique_ptr<se::dnn::ConvolveExecutionPlan>> plans;
2078 #if GOOGLE_CUDA
2079       std::vector<AlgorithmDesc> algorithms;
2080       std::vector<AlgorithmConfig> configs;
2081       if (CudnnUseFrontend()) {
2082         OP_REQUIRES(context,
2083                     stream->parent()->GetConvolveExecutionPlans(
2084                         se::dnn::ConvolutionKind::BACKWARD_FILTER,
2085                         se::dnn::ToDataType<T>::value, stream, input_desc,
2086                         filter_desc, output_desc, conv_desc, &plans),
2087                     errors::Unknown(
2088                         "Failed to get convolution execution plan. This is "
2089                         "probably because cuDNN failed to initialize, so try "
2090                         "looking to see if a warning log message was printed "
2091                         "above."));
2092         for (const auto& plan : plans) {
2093           configs.push_back(AlgorithmConfig(
2094               AlgorithmDesc{plan->getTag(), plan->get_raw_desc()},
2095               plan->getWorkspaceSize()));
2096         }
2097       } else {
2098         OP_REQUIRES(
2099             context,
2100             stream->parent()->GetConvolveBackwardFilterAlgorithms(&algorithms),
2101             errors::Unknown(
2102                 "Failed to get convolution execution plan. This is probably "
2103                 "because cuDNN failed to initialize, so try looking to see if "
2104                 "a warning log message was printed above."));
2105         for (const auto& algorithm : algorithms) {
2106           configs.push_back(AlgorithmConfig(algorithm));
2107         }
2108       }
2109 
2110       std::vector<tensorflow::AutotuneResult> results;
2111       for (auto& profile_config : configs) {
2112         // TODO(zhengxq): profile each algorithm multiple times for better
2113         // accuracy.
2114         DnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
2115                                               context);
2116         ProfileResult profile_result;
2117         Status cudnn_launch_status;
2118         if (CudnnUseFrontend()) {
2119           cudnn_launch_status = stream->ConvolveBackwardFilterWithExecutionPlan(
2120               input_desc, input_ptr, output_desc, out_backprop_ptr, conv_desc,
2121               filter_desc, &filter_backprop_ptr, &scratch_allocator,
2122               profile_config, &profile_result);
2123         } else {
2124           cudnn_launch_status = stream->ConvolveBackwardFilterWithAlgorithm(
2125               input_desc, input_ptr, output_desc, out_backprop_ptr, conv_desc,
2126               filter_desc, &filter_backprop_ptr, &scratch_allocator,
2127               profile_config, &profile_result);
2128         }
2129 
2130         if (cudnn_launch_status.ok() && profile_result.is_valid()) {
2131           results.emplace_back();
2132           auto& result = results.back();
2133           if (CudnnUseFrontend()) {
2134             result.mutable_cuda_conv_plan()->set_exec_plan_id(
2135                 profile_config.algorithm()->exec_plan_id());
2136           } else {
2137             result.mutable_conv()->set_algorithm(
2138                 profile_config.algorithm()->algo_id());
2139             result.mutable_conv()->set_tensor_ops_enabled(
2140                 profile_config.algorithm()->tensor_ops_enabled());
2141           }
2142 
2143           result.set_scratch_bytes(scratch_allocator.TotalByteSize());
2144           *result.mutable_run_time() = proto_utils::ToDurationProto(
2145               absl::Milliseconds(profile_result.elapsed_time_in_ms()));
2146 
2147         } else if (CudnnUseFrontend()) {
2148           // When cuDNN frontend APIs are used, we need to make sure the
2149           // profiling results are a one-to-one mapping of the "plans". So, we
2150           // insert dummy results when the execution fails.
2151           results.emplace_back();
2152           auto& result = results.back();
2153           result.mutable_failure()->set_kind(AutotuneResult::UNKNOWN);
2154           result.mutable_failure()->set_msg(
2155               absl::StrCat("Profiling failure on CUDNN engine: ",
2156                            profile_config.algorithm()->exec_plan_id()));
2157         }
2158       }
2159 #elif TENSORFLOW_USE_ROCM
2160       DnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
2161                                             context);
2162       std::vector<ProfileResult> algorithms;
2163       CHECK(stream->parent()->GetMIOpenConvolveAlgorithms(
2164           se::dnn::ConvolutionKind::BACKWARD_FILTER,
2165           se::dnn::ToDataType<T>::value, stream, input_desc, input_ptr,
2166           filter_desc, filter_backprop_ptr, output_desc, out_backprop_ptr,
2167           conv_desc, &scratch_allocator, &algorithms));
2168 
2169       std::vector<tensorflow::AutotuneResult> results;
2170       for (auto miopen_algorithm : algorithms) {
        auto profile_algorithm = miopen_algorithm.algorithm();
        ProfileResult profile_result;
        auto cudnn_launch_status = stream->ConvolveBackwardFilterWithAlgorithm(
            input_desc, input_ptr, output_desc, out_backprop_ptr, conv_desc,
            filter_desc, &filter_backprop_ptr, &scratch_allocator,
            AlgorithmConfig(profile_algorithm, miopen_algorithm.scratch_size()),
            &profile_result);
        if (cudnn_launch_status.ok()) {
          if (profile_result.is_valid()) {
            results.emplace_back();
            auto& result = results.back();
            result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
            result.mutable_conv()->set_tensor_ops_enabled(
                profile_algorithm.tensor_ops_enabled());
            result.set_scratch_bytes(scratch_allocator.TotalByteSize());
            *result.mutable_run_time() = proto_utils::ToDurationProto(
                absl::Milliseconds(profile_result.elapsed_time_in_ms()));
          }
        }
      }
#endif
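      // Log all autotuning results, pick the best algorithm (or execution
      // plan), and cache the choice for this set of convolution parameters.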
      LogConvAutotuneResults(se::dnn::ConvolutionKind::BACKWARD_FILTER,
                             se::dnn::ToDataType<T>::value, input_ptr,
                             filter_backprop_ptr, out_backprop_ptr, input_desc,
                             filter_desc, output_desc, conv_desc,
                             stream->parent(), results);
      if (CudnnUseFrontend()) {
        OP_REQUIRES_OK(context, BestCudnnConvAlgorithm(results, &plans,
                                                       &algorithm_config));
      } else {
        Status s = BestCudnnConvAlgorithm(results, nullptr, &algorithm_config);
#if GOOGLE_CUDA
        if (s.code() == error::NOT_FOUND) {
          size_t version = cudnnGetVersion();
          // For cuDNN 8.0.3 and 8.0.4, no cudnnConvolutionBwdFilterAlgo_t will
          // work in certain cases. In such cases we improve the error message.
          // This is fixed in cuDNN 8.0.5. For more context, see:
          // https://github.com/tensorflow/tensorflow/issues/46589
          if (version == 8003 || version == 8004) {
            std::string version_str = (version == 8003 ? "8.0.3" : "8.0.4");
            s = errors::NotFound(
                "No algorithm worked! Please try upgrading to cuDNN 8.0.5. You "
                "are using cuDNN ",
                version_str, ", which has a bug causing this error.");
          }
        }
#endif
        OP_REQUIRES_OK(context, s);
      }
      AutotuneConv3dBwdFilter::GetInstance()->Insert(conv_parameters,
                                                     algorithm_config);
    }

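    // Launch the backward-filter convolution with the configuration selected
    // above (or the default one when autotuning is disabled).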
    Status cudnn_launch_status;
    DnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
                                          context);
    if (CudnnUseFrontend()) {
      if (algorithm_config.algorithm().has_value()) {
        VLOG(4) << "Conv3DBackpropFilter Execution Plan: "
                << algorithm_config.algorithm()->exec_plan_id();
      } else {
        VLOG(4) << "Convolution Autotune has been turned off";
      }
      cudnn_launch_status = stream->ConvolveBackwardFilterWithExecutionPlan(
          input_desc, input_ptr, output_desc, out_backprop_ptr, conv_desc,
          filter_desc, &filter_backprop_ptr, &scratch_allocator,
          algorithm_config, nullptr);
    } else {
      cudnn_launch_status = stream->ConvolveBackwardFilterWithAlgorithm(
          input_desc, input_ptr, output_desc, out_backprop_ptr, conv_desc,
          filter_desc, &filter_backprop_ptr, &scratch_allocator,
          algorithm_config, nullptr);
    }

    if (!cudnn_launch_status.ok()) {
      context->SetStatus(cudnn_launch_status);
    }

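    // The gradient was computed in the transformed filter layout (dst_format);
    // convert it back into the 5-D filter layout of the op's output.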
    auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
    functor::ReverseTransformFilter<GPUDevice, T, 5>()(
        context->eigen_device<GPUDevice>(), /*src_filter_format=*/dst_format,
        toConstTensor(pre_transformed_filter_backprop).template tensor<T, 5>(),
        filter_backprop->tensor<T, 5>());
  }

 private:
  std::vector<int32> dilation_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
  bool takes_shape_;
  bool cudnn_use_autotune_;
};

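// Registers the GPU kernels for the 3-D convolution backprop ops. The V2
// variants take the input/filter shape as an additional host-memory tensor.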
#define REGISTER_GPU_KERNEL(T)                                                \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("Conv3DBackpropInput").Device(DEVICE_GPU).TypeConstraint<T>("T"),  \
      Conv3DBackpropInputOp<GPUDevice, T>);                                   \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2")                       \
                              .Device(DEVICE_GPU)                             \
                              .TypeConstraint<T>("T")                         \
                              .HostMemory("input_sizes"),                     \
                          Conv3DBackpropInputOp<GPUDevice, T>);               \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("Conv3DBackpropFilter").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      Conv3DBackpropFilterOp<GPUDevice, T>);                                  \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")                      \
                              .Device(DEVICE_GPU)                             \
                              .TypeConstraint<T>("T")                         \
                              .HostMemory("filter_sizes"),                    \
                          Conv3DBackpropFilterOp<GPUDevice, T>);
TF_CALL_half(REGISTER_GPU_KERNEL);
TF_CALL_float(REGISTER_GPU_KERNEL);
TF_CALL_double(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL
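// Illustrative only: these kernels are reached when taking the gradient of a
// 3-D convolution on GPU. A minimal Python sketch (public TF API only; none
// of the names below are defined in this file):
//
//   import tensorflow as tf
//   x = tf.random.normal([1, 8, 8, 8, 4])                # NDHWC input
//   f = tf.Variable(tf.random.normal([3, 3, 3, 4, 8]))   # DHWIO filter
//   with tf.GradientTape() as tape:
//     y = tf.nn.conv3d(x, f, strides=[1, 1, 1, 1, 1], padding="SAME")
//   df = tape.gradient(y, f)  # dispatches to Conv3DBackpropFilterV2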

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow