/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/conv_grad_ops.h"

#include <algorithm>
#include <vector>

#include "absl/base/dynamic_annotations.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_2d.h"
#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
#include "tensorflow/core/kernels/xsmm_conv2d.h"
#endif
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#if GOOGLE_CUDA
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/protobuf/autotuning.pb.h"
#include "tensorflow/core/util/proto/proto_utils.h"
#endif  // GOOGLE_CUDA

namespace {

// Returns in 'im_data' (assumed to be zero-initialized) the image patch in
// storage order (height, width, depth), constructed from patches in
// 'col_data', which is required to be in storage order
// (out_height * out_width, filter_height, filter_width, in_depth).
// Implementation by Yangqing Jia (jiayq).
template <typename T>
void Col2im(const T* col_data, const int depth, const int height,
            const int width, const int filter_h, const int filter_w,
            const int pad_t, const int pad_l, const int pad_b, const int pad_r,
            const int stride_h, const int stride_w, T* im_data) {
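  // height_col/width_col are the number of patch positions, i.e. the spatial
  // output dimensions implied by the padded input, filter size and stride.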
  int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
  int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
  int h_pad = -pad_t;
  for (int h = 0; h < height_col; ++h) {
    int w_pad = -pad_l;
    for (int w = 0; w < width_col; ++w) {
      T* im_patch_data = im_data + (h_pad * width + w_pad) * depth;
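      // im_patch_data addresses the (filter_h, filter_w) window anchored at
      // (h_pad, w_pad); window positions that fall into the padding region
      // are skipped by the bounds check below.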
      for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
        for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
          if (ih >= 0 && ih < height && iw >= 0 && iw < width) {
            // TODO(andydavis) Vectorize this loop (if compiler does not).
            for (int i = 0; i < depth; ++i) {
              im_patch_data[i] += col_data[i];
            }
          }
          im_patch_data += depth;
          col_data += depth;
        }
        // Skip over the columns of this image row that the filter window does
        // not cover.
        im_patch_data += depth * (width - filter_w);
      }
      w_pad += stride_w;
    }
    h_pad += stride_h;
  }
}

}  // namespace

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// The fast versions using eigen computations directly. They are only enabled
// for CPU for now since nvcc times out when trying to compile them.
// TODO(yangke): enable them for GPUs when we have a faster compiler.

template <typename T>
struct LaunchConv2DBackpropInputOp<CPUDevice, T> {
  void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
                  const Tensor& out_backprop, const Tensor& filter,
                  int row_dilation, int col_dilation, int row_stride,
                  int col_stride, const Padding& padding,
                  const std::vector<int64>& explicit_paddings,
                  Tensor* in_backprop, TensorFormat data_format) {
    const CPUDevice& d = ctx->eigen_device<CPUDevice>();
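    // NOTE: the Eigen kernel below is always invoked with dilation 1; the
    // row_dilation/col_dilation arguments are not forwarded.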
    functor::SpatialConvolutionBackwardInput<CPUDevice, T>()(
        d, in_backprop->tensor<T, 4>(), filter.tensor<T, 4>(),
        out_backprop.tensor<T, 4>(), row_stride, col_stride,
        /*row_dilation=*/1, /*col_dilation=*/1);
  }
};

#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
template <typename Device, class T>
struct LaunchXsmmBackwardInputConvolution {
  bool operator()(OpKernelContext* context, const Device& d,
                  typename TTypes<T, 4>::Tensor input_backward,
                  typename TTypes<T, 4>::ConstTensor kernel,
                  typename TTypes<T, 4>::ConstTensor output_backward,
                  int input_rows, int input_cols, int row_stride,
                  int col_stride, int pad_h, int pad_w,
                  TensorFormat data_format) const {
    return false;
  }
};

template <>
struct LaunchXsmmBackwardInputConvolution<CPUDevice, float> {
  bool operator()(OpKernelContext* context, const CPUDevice& d,
                  typename TTypes<float, 4>::Tensor input_backward,
                  typename TTypes<float, 4>::ConstTensor kernel,
                  typename TTypes<float, 4>::ConstTensor output_backward,
                  int input_rows, int input_cols, int row_stride,
                  int col_stride, int pad_h, int pad_w,
                  TensorFormat data_format) const {
    auto batch = input_backward.dimension(0);
    auto in_depth = input_backward.dimension(3);
    auto out_depth = output_backward.dimension(3);
    auto filter_rows = kernel.dimension(0);
    auto filter_cols = kernel.dimension(1);
    auto num_threads =
        context->device()->tensorflow_cpu_worker_threads()->num_threads;
    // See libxsmm_dnn.h for this struct definition.
    libxsmm_dnn_conv_desc desc;
    desc.N = batch;
    desc.C = in_depth;
    desc.H = input_rows;
    desc.W = input_cols;
    desc.K = out_depth;
    desc.R = filter_rows;
    desc.S = filter_cols;
    desc.u = row_stride;
    desc.v = col_stride;
    desc.pad_h = pad_h;
    desc.pad_w = pad_w;
    desc.pad_h_in = 0;
    desc.pad_w_in = 0;
    desc.pad_h_out = 0;
    desc.pad_w_out = 0;
    desc.threads = num_threads;
    desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
    desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
    desc.filter_format =
        LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;  // LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
    desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
    desc.options = LIBXSMM_DNN_CONV_OPTION_WU_EXT_FILTER_REDUCE_OVERWRITE;
    desc.datatype = LIBXSMM_DNN_DATATYPE_F32;

    auto input_ptr = input_backward.data();
    auto filter_ptr = kernel.data();
    auto output_ptr = output_backward.data();

    bool success = functor::XsmmBkwInputConv2D<CPUDevice, float>()(
        context, desc, input_ptr, filter_ptr, output_ptr);
    return success;
  }
};
#endif

template <typename T>
struct Conv2DCustomBackpropInputMatMulFunctor {
  using MatrixMap = Eigen::Map<
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
  using ConstMatrixMap = Eigen::Map<
      const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;

  void operator()(OpKernelContext* ctx, const T* out_data, const T* filter_data,
                  const int filter_total_size, const int output_image_size,
                  const int dims_out_depth, T* im2col_buf) {
    // Compute gradient into 'im2col_buf'.
    MatrixMap C(im2col_buf, output_image_size, filter_total_size);

    ConstMatrixMap A(out_data, output_image_size, dims_out_depth);
    ConstMatrixMap B(filter_data, filter_total_size, dims_out_depth);

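    // C = A * B^T:
    //   [output_image_size x filter_total_size] =
    //       [output_image_size x dims_out_depth] *
    //       [dims_out_depth x filter_total_size]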
    C.noalias() = A * B.transpose();
  }
};

#if defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL)
template <>
struct Conv2DCustomBackpropInputMatMulFunctor<float> {
  using T = float;

  void operator()(OpKernelContext* ctx, const T* out_data, const T* filter_data,
                  const int filter_total_size, const int output_image_size,
                  const int dims_out_depth, T* im2col_buf) {
    // Inputs are in RowMajor order, so we "cheat" by swapping the LHS and RHS:
    //   RowMajor: C   = A   * B
    //   ColMajor: C^T = B^T * A^T
    //
    // Dimension names:
    //   out_image_size    -> ois
    //   filter_total_size -> fts
    //   dims_out_depth    -> dod
    //
    // RowMajor:
    //   im2col      = out_data    * filter_data^T
    //   [ois x fts] = [ois x dod] * [fts x dod]^T
    //
    // ColMajor:
    //   im2col^T    = filter_data *  out_data^T
    //   [fts x ois] = [fts x dod] * [dod x ois]

    const int m = filter_total_size;
    const int n = output_image_size;
    const int k = dims_out_depth;  // contraction dim

    const char transposeA = 'T';  // sgemm(A) == filter_data
    const char transposeB = 'N';  // sgemm(B) == out_data

    const int ldA = dims_out_depth;
    const int ldB = dims_out_depth;
    const int ldC = filter_total_size;

    const float alpha = 1.0;
    const float beta = 0.0;

    // mkldnn_sgemm code can't be instrumented with msan.
    ANNOTATE_MEMORY_IS_INITIALIZED(
        im2col_buf, filter_total_size * output_image_size * sizeof(T));

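    // mkldnn_sgemm follows the BLAS/Fortran sgemm calling convention:
    // column-major operands, with all scalar parameters passed by pointer.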
    mkldnn_status_t st =
        mkldnn_sgemm(&transposeA, &transposeB, &m, &n, &k, &alpha, filter_data,
                     &ldA, out_data, &ldB, &beta, im2col_buf, &ldC);

    OP_REQUIRES(
        ctx, st == 0,
        errors::Internal("Failed to call mkldnn_sgemm. Error code: ", st));
  }
};
#endif

// Based on implementation written by Yangqing Jia (jiayq).
template <typename Device, class T>
class Conv2DCustomBackpropInputOp : public OpKernel {
 public:
  explicit Conv2DCustomBackpropInputOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES(context, data_format_ == FORMAT_NHWC,
                errors::InvalidArgument(
                    "Conv2DCustomBackpropInputOp only supports NHWC."));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(
        context, (strides_[0] == 1 && strides_[3] == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES(context, strides_[1] > 0 && strides_[2] > 0,
                errors::InvalidArgument(
                    "Row and column strides should be larger than 0."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
    OP_REQUIRES(context, dilations_.size() == 4,
                errors::InvalidArgument("Sliding window dilations field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(context, (dilations_[0] == 1 && dilations_[3] == 1),
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilations in the batch and depth dimensions."));
    // TODO(yangzihao): Add a CPU implementation for dilated convolution.
    OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1),
                errors::InvalidArgument(
                    "Current libxsmm and customized CPU implementations do "
                    "not yet support dilation rates larger than 1."));
    OP_REQUIRES(
        context, padding_ != Padding::EXPLICIT,
        errors::Unimplemented("Current CPU implementation does not support "
                              "EXPLICIT padding yet."));
    std::vector<int64> explicit_paddings;
    OP_REQUIRES_OK(context,
                   context->GetAttr("explicit_paddings", &explicit_paddings));
    OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings,
                                              /*num_dims=*/4, data_format_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input_sizes = context->input(0);
    const Tensor& filter = context->input(1);
    const Tensor& out_backprop = context->input(2);
    OP_REQUIRES(
        context, TensorShapeUtils::IsVector(input_sizes.shape()),
        errors::InvalidArgument(
            "Conv2DBackpropInput: input_sizes input must be 1-dim, not ",
            input_sizes.dims()));
    TensorShape input_shape;
    OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                input_sizes.vec<int32>(), &input_shape));

    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(context,
                   ConvBackpropComputeDimensions(
                       "Conv2DCustomBackpropInput", /*num_spatial_dims=*/2,
                       input_shape, filter.shape(), out_backprop.shape(),
                       strides_, padding_, data_format_, &dims));

    Tensor* in_backprop = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_shape, &in_backprop));

    // If there is nothing to compute, return.
    if (input_shape.num_elements() == 0) {
      return;
    }

// TODO(andydavis) Consider moving code shared with
// Conv2DCustomBackpropFilterOp into a shared helper function.
#if defined TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS && \
    defined TENSORFLOW_USE_LIBXSMM_BACKWARD_CONVOLUTIONS
    int64 pad_top, pad_bottom;
    int64 pad_left, pad_right;
    OP_REQUIRES_OK(
        context,
        GetWindowedOutputSizeVerbose(
            dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
            dims.spatial_dims[0].stride, padding_,
            &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom));
    OP_REQUIRES_OK(
        context,
        GetWindowedOutputSizeVerbose(
            dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
            dims.spatial_dims[1].stride, padding_,
            &dims.spatial_dims[1].output_size, &pad_left, &pad_right));

    if (pad_left == pad_right && pad_top == pad_bottom) {
      if (LaunchXsmmBackwardInputConvolution<Device, T>()(
              context, context->eigen_device<Device>(),
              in_backprop->tensor<T, 4>(), filter.tensor<T, 4>(),
              out_backprop.tensor<T, 4>(), dims.spatial_dims[0].input_size,
              dims.spatial_dims[1].input_size,
              static_cast<int>(dims.spatial_dims[0].stride),
              static_cast<int>(dims.spatial_dims[1].stride),
              static_cast<int>(pad_top), static_cast<int>(pad_left),
              data_format_)) {
        return;
      }
    }
#else
    int64 pad_top, pad_bottom;
    int64 pad_left, pad_right;
#endif
    OP_REQUIRES_OK(
        context,
        GetWindowedOutputSizeVerbose(
            dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
            dims.spatial_dims[0].stride, padding_,
            &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom));
    OP_REQUIRES_OK(
        context,
        GetWindowedOutputSizeVerbose(
            dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
            dims.spatial_dims[1].stride, padding_,
            &dims.spatial_dims[1].output_size, &pad_left, &pad_right));

    // The total dimension size of each kernel.
    const int filter_total_size = dims.spatial_dims[0].filter_size *
                                  dims.spatial_dims[1].filter_size *
                                  dims.in_depth;
    // The output image size is the spatial size of the output.
    const int output_image_size =
        dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size;

    // TODO(andydavis) Get L2/L3 cache sizes from device.
    const size_t l2_cache_size = 256LL << 10;
    const size_t l3_cache_size = 30LL << 20;

    // Use L3 cache size as target working set size.
    const size_t target_working_set_size = l3_cache_size / sizeof(T);

    // Calculate size of matrices involved in MatMul: C = A x B.
    const size_t size_A = output_image_size * dims.out_depth;

    const size_t size_B = filter_total_size * dims.out_depth;

    const size_t size_C = output_image_size * filter_total_size;

    const size_t work_unit_size = size_A + size_B + size_C;

    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());

    // Calculate per-thread work unit size.
    const size_t thread_work_unit_size =
        work_unit_size / worker_threads.num_threads;

    // Set minimum per-thread work unit size to size of L2 cache.
    const size_t min_thread_work_unit_size = l2_cache_size / sizeof(T);

    // Use parallel tensor contractions if there is no batching, or if the
    // minimum per-thread work unit size threshold has been exceeded.
    // Otherwise, revert to multiple single-threaded matmul ops running in
    // parallel to keep all threads busy.
    // TODO(andydavis) Explore alternatives to branching the code in this way
    // (i.e. run multiple, parallel tensor contractions in another thread pool).
    const bool use_parallel_contraction =
        dims.batch_size == 1 ||
        thread_work_unit_size >= min_thread_work_unit_size;

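    // shard_size is the number of images whose A/B/C matrices together fit in
    // the L3-sized target working set; with parallel contractions each image
    // is handled on its own (shard_size == 1).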
    const size_t shard_size =
        use_parallel_contraction
            ? 1
            : (target_working_set_size + work_unit_size - 1) / work_unit_size;

    Tensor col_buffer;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(
                       DataTypeToEnum<T>::value,
                       TensorShape({static_cast<int64>(shard_size),
                                    static_cast<int64>(output_image_size),
                                    static_cast<int64>(filter_total_size)}),
                       &col_buffer));

    // The input offset corresponding to a single input image.
    const int input_offset = dims.spatial_dims[0].input_size *
                             dims.spatial_dims[1].input_size * dims.in_depth;
    // The output offset corresponding to a single output image.
    const int output_offset = dims.spatial_dims[0].output_size *
                              dims.spatial_dims[1].output_size * dims.out_depth;

    const T* filter_data = filter.template flat<T>().data();
    T* col_buffer_data = col_buffer.template flat<T>().data();
    const T* out_backprop_data = out_backprop.template flat<T>().data();

    auto in_backprop_flat = in_backprop->template flat<T>();
    T* input_backprop_data = in_backprop_flat.data();
    in_backprop_flat.device(context->eigen_device<Device>()) =
        in_backprop_flat.constant(T(0));

    if (use_parallel_contraction) {
      typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
                               Eigen::Unaligned>
          TensorMap;
      typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
                               Eigen::Unaligned>
          ConstTensorMap;

      // Initialize contraction dims (we need to transpose 'B' below).
      Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
      contract_dims[0].first = 1;
      contract_dims[0].second = 1;
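      // Contracting dimension 1 of A (out_depth) with dimension 1 of B
      // (out_depth) computes C = A * B^T, matching the matmul functor above.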

      for (int image_id = 0; image_id < dims.batch_size; ++image_id) {
        // Compute gradient into col_buffer.
        TensorMap C(col_buffer_data, output_image_size, filter_total_size);

        ConstTensorMap A(out_backprop_data + output_offset * image_id,
                         output_image_size, dims.out_depth);
        ConstTensorMap B(filter_data, filter_total_size, dims.out_depth);

        C.device(context->eigen_cpu_device()) = A.contract(B, contract_dims);

        Col2im<T>(
            col_buffer_data, dims.in_depth, dims.spatial_dims[0].input_size,
            dims.spatial_dims[1].input_size, dims.spatial_dims[0].filter_size,
            dims.spatial_dims[1].filter_size, pad_top, pad_left, pad_bottom,
            pad_right, dims.spatial_dims[0].stride, dims.spatial_dims[1].stride,
            input_backprop_data);

        input_backprop_data += input_offset;
      }
    } else {
      for (int image_id = 0; image_id < dims.batch_size;
           image_id += shard_size) {
        const int shard_limit =
            std::min(static_cast<int>(shard_size),
                     static_cast<int>(dims.batch_size) - image_id);

        auto shard = [&context, &dims, &pad_top, &pad_left, &pad_bottom,
                      &pad_right, &output_image_size, &filter_total_size,
                      &input_backprop_data, &col_buffer_data,
                      &out_backprop_data, &filter_data, &input_offset,
                      &output_offset, &size_C](int64 start, int64 limit) {
          for (int shard_id = start; shard_id < limit; ++shard_id) {
            T* im2col_buf = col_buffer_data + shard_id * size_C;
            T* input_data = input_backprop_data + shard_id * input_offset;
            const T* out_data = out_backprop_data + shard_id * output_offset;

            Conv2DCustomBackpropInputMatMulFunctor<T>()(
                context, out_data, filter_data, filter_total_size,
                output_image_size, dims.out_depth, im2col_buf);

            Col2im<T>(im2col_buf, dims.in_depth,
                      dims.spatial_dims[0].input_size,
                      dims.spatial_dims[1].input_size,
                      dims.spatial_dims[0].filter_size,
                      dims.spatial_dims[1].filter_size, pad_top, pad_left,
                      pad_bottom, pad_right, dims.spatial_dims[0].stride,
                      dims.spatial_dims[1].stride, input_data);
          }
        };
        Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
              work_unit_size, shard);

        input_backprop_data += input_offset * shard_limit;
        out_backprop_data += output_offset * shard_limit;
      }
    }
  }

 private:
  std::vector<int32> dilations_;
  std::vector<int32> strides_;
  Padding padding_;
  TensorFormat data_format_;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv2DCustomBackpropInputOp);
};

#define REGISTER_CPU_KERNELS(T)                                              \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("Conv2DBackpropInput").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      Conv2DCustomBackpropInputOp<CPUDevice, T>);                            \
  REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")                        \
                              .Device(DEVICE_CPU)                            \
                              .Label("custom")                               \
                              .TypeConstraint<T>("T"),                       \
                          Conv2DCustomBackpropInputOp<CPUDevice, T>);

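// The custom (im2col + GEMM) kernel above is registered both as the default
// CPU kernel for Conv2DBackpropInput and under the "custom" label.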
TF_CALL_half(REGISTER_CPU_KERNELS);
TF_CALL_float(REGISTER_CPU_KERNELS);
TF_CALL_double(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS

// To be used inside depthwise_conv_grad_op.cc.
template struct LaunchConv2DBackpropInputOp<CPUDevice, Eigen::half>;
template struct LaunchConv2DBackpropInputOp<CPUDevice, float>;
template struct LaunchConv2DBackpropInputOp<CPUDevice, double>;

// GPU definitions.
#if GOOGLE_CUDA
// The slow version (but compiles for GPU)

// A dummy type to group forward backward data autotune results together.
struct ConvBackwardDataAutoTuneGroup {
  static string name() { return "ConvBwdData"; }
};
typedef AutoTuneSingleton<ConvBackwardDataAutoTuneGroup, ConvParameters,
                          se::dnn::AlgorithmConfig>
    AutoTuneConvBwdData;

// Backprop for input.
template <typename Device, class T>
class Conv2DSlowBackpropInputOp : public OpKernel {
 public:
  explicit Conv2DSlowBackpropInputOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    int stride_n = GetTensorDim(strides_, data_format_, 'N');
    int stride_c = GetTensorDim(strides_, data_format_, 'C');
    int stride_h = GetTensorDim(strides_, data_format_, 'H');
    int stride_w = GetTensorDim(strides_, data_format_, 'W');
    OP_REQUIRES(
        context, (stride_n == 1 && stride_c == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES(context, stride_h > 0 && stride_w > 0,
                errors::InvalidArgument(
                    "Row and column strides should be larger than 0."));
    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
    OP_REQUIRES(context, dilations_.size() == 4,
                errors::InvalidArgument("Sliding window dilations field must "
                                        "specify 4 dimensions"));
    int dilation_n = GetTensorDim(dilations_, data_format_, 'N');
    int dilation_c = GetTensorDim(dilations_, data_format_, 'C');
    int dilation_h = GetTensorDim(dilations_, data_format_, 'H');
    int dilation_w = GetTensorDim(dilations_, data_format_, 'W');
    OP_REQUIRES(context, (dilation_n == 1 && dilation_c == 1),
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilations in the batch and depth dimensions."));
    OP_REQUIRES(
        context, dilation_h > 0 && dilation_w > 0,
        errors::InvalidArgument("Dilated rates should be larger than 0."));
    OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
    use_cudnn_ &= CanUseCudnn();
    cudnn_use_autotune_ = CudnnUseAutotune();
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    if (!std::is_same<Device, GPUDevice>::value) {
      OP_REQUIRES(
          context, padding_ != Padding::EXPLICIT,
          errors::Unimplemented("Current CPU implementation does not support "
                                "EXPLICIT padding yet."));
    }
    OP_REQUIRES_OK(context,
                   context->GetAttr("explicit_paddings", &explicit_paddings_));
    OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
                                              /*num_dims=*/4, data_format_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input_sizes = context->input(0);
    const Tensor& filter = context->input(1);
    const Tensor& out_backprop = context->input(2);
    OP_REQUIRES(
        context, TensorShapeUtils::IsVector(input_sizes.shape()),
        errors::InvalidArgument(
            "Conv2DBackpropInput: input_sizes input must be 1-dim, not ",
            input_sizes.dims()));
    TensorShape input_shape;
    OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                input_sizes.vec<int32>(), &input_shape));

    Tensor* in_backprop = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_shape, &in_backprop));

    // If there is nothing to compute, return.
    if (input_shape.num_elements() == 0) {
      return;
    }

    // For now we take the stride from the second and third dimensions only (we
    // do not support striding on the batch or depth dimension).
    const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
    const int stride_cols = GetTensorDim(strides_, data_format_, 'W');
    const int dilation_rows = GetTensorDim(dilations_, data_format_, 'H');
    const int dilation_cols = GetTensorDim(dilations_, data_format_, 'W');

    launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop, filter,
              dilation_rows, dilation_cols, stride_rows, stride_cols, padding_,
              explicit_paddings_, in_backprop, data_format_);
  }

 private:
  std::vector<int32> dilations_;
  std::vector<int32> strides_;
  Padding padding_;
  std::vector<int64> explicit_paddings_;
  bool use_cudnn_;
  TensorFormat data_format_;
  LaunchConv2DBackpropInputOp<Device, T> launcher_;
  bool cudnn_use_autotune_;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv2DSlowBackpropInputOp);
};

template <typename T>
void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
    OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
    const Tensor& out_backprop, const Tensor& filter, int row_dilation,
    int col_dilation, int row_stride, int col_stride, const Padding& padding,
    const std::vector<int64>& explicit_paddings, Tensor* in_backprop,
    TensorFormat data_format) {
  using se::dnn::AlgorithmConfig;
  using se::dnn::AlgorithmDesc;
  using se::dnn::ProfileResult;

  std::vector<int32> strides(4, 1);
  std::vector<int32> dilations(4, 1);
  auto input_h = GetTensorDimIndex(data_format, 'H');
  auto input_w = GetTensorDimIndex(data_format, 'W');
  strides[input_h] = row_stride;
  strides[input_w] = col_stride;
  dilations[input_h] = row_dilation;
  dilations[input_w] = col_dilation;
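  // strides/dilations are per-dimension, length-4 vectors; only the spatial
  // (H/W) entries are set, the batch and depth entries stay at 1.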
  TensorShape input_shape = in_backprop->shape();

  const TensorShape& filter_shape = filter.shape();
  ConvBackpropDimensions dims;
  OP_REQUIRES_OK(
      ctx, ConvBackpropComputeDimensionsV2(
               "Conv2DSlowBackpropInput", /*num_spatial_dims=*/2, input_shape,
               filter_shape, out_backprop.shape(), dilations, strides, padding,
               explicit_paddings, data_format, &dims));

  int64 padding_top = -1, padding_bottom = -1;
  int64 padding_left = -1, padding_right = -1;
  if (padding == EXPLICIT) {
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'H', &padding_top,
                             &padding_bottom);
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'W', &padding_left,
                             &padding_right);
  }
  int64 expected_out_rows, expected_out_cols;
  // The function is guaranteed to succeed because we checked the output and
  // padding was valid earlier.
  TF_CHECK_OK(GetWindowedOutputSizeVerboseV2(
      dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
      row_dilation, row_stride, padding, &expected_out_rows, &padding_top,
      &padding_bottom));
  DCHECK_EQ(dims.spatial_dims[0].output_size, expected_out_rows);
  TF_CHECK_OK(GetWindowedOutputSizeVerboseV2(
      dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
      col_dilation, col_stride, padding, &expected_out_cols, &padding_left,
      &padding_right));
  DCHECK_EQ(dims.spatial_dims[1].output_size, expected_out_cols);

  auto* stream = ctx->op_device_context()->stream();
  OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

  if (!use_cudnn) {
    ctx->SetStatus(errors::Unimplemented(
        "Conv2DBackpropInput for GPU is not currently supported "
        "without cudnn"));
    return;
  }

  // If the filter in-depth (filter_shape.dim_size(2)) is 1 and smaller than the
  // input depth, it's a depthwise convolution. More generally, if the filter
  // in-depth divides but is smaller than the input depth, it is a grouped
  // convolution.
  bool is_grouped_convolution = filter_shape.dim_size(2) != dims.in_depth;
  if (dims.spatial_dims[0].filter_size == 1 &&
      dims.spatial_dims[1].filter_size == 1 && !is_grouped_convolution &&
      dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 &&
      data_format == FORMAT_NHWC && (padding == VALID || padding == SAME)) {
    // 1x1 filter, so call cublas directly.
    const uint64 m = dims.batch_size * dims.spatial_dims[0].input_size *
                     dims.spatial_dims[1].input_size;
    const uint64 k = dims.out_depth;
    const uint64 n = dims.in_depth;

    auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                out_backprop.template flat<T>().size());
    auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
                                filter.template flat<T>().size());
    auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
                                in_backprop->template flat<T>().size());

    auto transpose = se::blas::Transpose::kTranspose;
    auto no_transpose = se::blas::Transpose::kNoTranspose;

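    // Column-major GEMM with swapped operands: computes
    // in_backprop[m x n] = out_backprop[m x k] * filter[n x k]^T.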
    bool blas_launch_status =
        stream
            ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
                           a_ptr, k, 0.0f, &c_ptr, n)
            .ok();
    if (!blas_launch_status) {
      ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                      ", n=", n, ", k=", k));
    }
    return;
  } else if (dims.spatial_dims[0].filter_size ==
                 dims.spatial_dims[0].input_size &&
             dims.spatial_dims[1].filter_size ==
                 dims.spatial_dims[1].input_size &&
             !is_grouped_convolution && padding == VALID &&
             data_format == FORMAT_NHWC) {
    // The input data and filter have the same height/width, and we are not
    // using grouped convolution, so call cublas directly.
    const uint64 m = dims.batch_size;
    const uint64 k = dims.out_depth;
    const uint64 n = dims.spatial_dims[0].input_size *
                     dims.spatial_dims[1].input_size * dims.in_depth;

    auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                out_backprop.template flat<T>().size());
    auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
                                filter.template flat<T>().size());
    auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
                                in_backprop->template flat<T>().size());

    auto transpose = se::blas::Transpose::kTranspose;
    auto no_transpose = se::blas::Transpose::kNoTranspose;

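    // Same operand-swapping scheme as the 1x1 case above; here the filter is
    // treated as a [rows * cols * in_depth, out_depth] matrix.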
    bool blas_launch_status =
        stream
            ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
                           a_ptr, k, 0.0f, &c_ptr, n)
            .ok();
    if (!blas_launch_status) {
      ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                      ", n=", n, ", k=", k));
    }
    return;
  }

  const int64 common_padding_rows = std::min(padding_top, padding_bottom);
  const int64 common_padding_cols = std::min(padding_left, padding_right);
  TensorShape compatible_input_shape;
  if (padding_top != padding_bottom || padding_left != padding_right) {
    // Pad the input in the same way we did during the forward pass, so that
    // cuDNN receives the same input during the backward pass function as it did
    // during the forward pass function.
    const int64 padding_rows_diff = std::abs(padding_bottom - padding_top);
    const int64 padding_cols_diff = std::abs(padding_right - padding_left);
    const int64 new_in_rows =
        dims.spatial_dims[0].input_size + padding_rows_diff;
    const int64 new_in_cols =
        dims.spatial_dims[1].input_size + padding_cols_diff;
    compatible_input_shape = ShapeFromFormat(
        data_format, dims.batch_size, new_in_rows, new_in_cols, dims.in_depth);
  } else {
    compatible_input_shape = input_shape;
  }

  CHECK(common_padding_rows >= 0 && common_padding_cols >= 0)  // Crash OK
      << "Negative row or col paddings: (" << common_padding_rows << ", "
      << common_padding_cols << ")";
  se::dnn::BatchDescriptor input_desc;
  input_desc.set_count(dims.batch_size)
      .set_height(GetTensorDim(compatible_input_shape, data_format, 'H'))
      .set_width(GetTensorDim(compatible_input_shape, data_format, 'W'))
      .set_feature_map_count(dims.in_depth)
      .set_layout(se::dnn::DataLayout::kBatchDepthYX);
  se::dnn::BatchDescriptor output_desc;
  output_desc.set_count(dims.batch_size)
      .set_height(dims.spatial_dims[0].output_size)
      .set_width(dims.spatial_dims[1].output_size)
      .set_feature_map_count(dims.out_depth)
      .set_layout(se::dnn::DataLayout::kBatchDepthYX);
  se::dnn::FilterDescriptor filter_desc;
  filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size)
      .set_input_filter_width(dims.spatial_dims[1].filter_size)
      .set_input_feature_map_count(filter_shape.dim_size(2))
      .set_output_feature_map_count(filter_shape.dim_size(3));
  se::dnn::ConvolutionDescriptor conv_desc;
  conv_desc.set_vertical_dilation_rate(dims.spatial_dims[0].dilation)
      .set_horizontal_dilation_rate(dims.spatial_dims[1].dilation)
      .set_vertical_filter_stride(dims.spatial_dims[0].stride)
      .set_horizontal_filter_stride(dims.spatial_dims[1].stride)
      .set_zero_padding_height(common_padding_rows)
      .set_zero_padding_width(common_padding_cols)
      .set_group_count(dims.in_depth / filter_shape.dim_size(2));
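  // A group count of in_depth / filter-in-depth greater than 1 selects a
  // grouped (or depthwise) convolution; it is 1 for a regular convolution.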

  // NOTE(keveman):
  // cuDNN only supports the following layouts :
  // Input  : B x D x R x C
  // Filter : OD x ID x R x C
  // Whereas, we have
  // Input  : B x R x C x D
  // Filter : R x C x ID x OD
  // TransformFilter performs (R x C x ID x OD) => (OD x ID x R x C)
  // The first TransformDepth performs
  // (B x R x C x D) => (B x D x R x C).
  // Since the tensor returned from cuDNN is B x D x R x C also,
  // the second TransformDepth performs
  // (B x D x R x C) => (B x R x C x D).
  Tensor transformed_filter;
  OP_REQUIRES_OK(
      ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                              TensorShape({dims.out_depth, dims.in_depth,
                                           dims.spatial_dims[0].filter_size,
                                           dims.spatial_dims[1].filter_size}),
                              &transformed_filter));

  functor::TransformFilter<GPUDevice, T, int, 4>()(
      ctx->eigen_device<GPUDevice>(), FORMAT_OIHW,
      To32Bit(filter.tensor<T, 4>()),
      To32Bit(transformed_filter.tensor<T, 4>()));

  Tensor transformed_out_backprop;
  if (data_format == FORMAT_NHWC) {
    TensorShape nchw_shape = ShapeFromFormat(
        FORMAT_NCHW, dims.batch_size, dims.spatial_dims[0].output_size,
        dims.spatial_dims[1].output_size, dims.out_depth);
    if (dims.out_depth > 1) {
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
                                        &transformed_out_backprop));
      functor::NHWCToNCHW<GPUDevice, T, 4>()(
          ctx->eigen_device<GPUDevice>(), out_backprop.tensor<T, 4>(),
          transformed_out_backprop.tensor<T, 4>());
    } else {
      // If depth <= 1, then just reshape.
      CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
    }
  } else {
    transformed_out_backprop = out_backprop;
  }

  Tensor pre_transformed_in_backprop;
  OP_REQUIRES_OK(
      ctx, ctx->allocate_temp(
               DataTypeToEnum<T>::value,
               ShapeFromFormat(
                   FORMAT_NCHW,
                   GetTensorDim(compatible_input_shape, data_format, 'N'),
                   GetTensorDim(compatible_input_shape, data_format, 'H'),
                   GetTensorDim(compatible_input_shape, data_format, 'W'),
                   GetTensorDim(compatible_input_shape, data_format, 'C')),
               &pre_transformed_in_backprop));

  auto out_backprop_ptr =
      AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
                     transformed_out_backprop.template flat<T>().size());
  auto filter_ptr =
      AsDeviceMemory(transformed_filter.template flat<T>().data(),
                     transformed_filter.template flat<T>().size());
  auto in_backprop_ptr =
      AsDeviceMemory(pre_transformed_in_backprop.template flat<T>().data(),
                     pre_transformed_in_backprop.template flat<T>().size());

  static int64 ConvolveBackwardDataScratchSize = GetDnnWorkspaceLimit(
      "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB by default
  );
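  // The workspace limit is resolved by GetDnnWorkspaceLimit from the
  // TF_CUDNN_WORKSPACE_LIMIT_IN_MB setting, with 1LL << 32 bytes (4GB) as the
  // fallback when it is not set.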
  DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize, ctx);
  int device_id = stream->parent()->device_ordinal();
  DataType dtype = out_backprop.dtype();
  ConvParameters conv_parameters = {
      dims.batch_size,                     // batch
      dims.in_depth,                       // in_depths
      {{input_desc.height(),               // in_rows
        input_desc.width()}},              // in_cols
      FORMAT_NCHW,                         // compute_data_format
      dims.out_depth,                      // out_depths
      {{dims.spatial_dims[0].filter_size,  // filter_rows
        dims.spatial_dims[1].filter_size,  // filter_cols
        filter_shape.dim_size(2)}},        // filter_depths
      {{dims.spatial_dims[0].dilation,     // dilation_rows
        dims.spatial_dims[1].dilation}},   // dilation_cols
      {{dims.spatial_dims[0].stride,       // stride_rows
        dims.spatial_dims[1].stride}},     // stride_cols
      {{common_padding_rows,               // padding_rows
        common_padding_cols}},             // padding_cols
      dtype,                               // tensor data type
      device_id,                           // device_id
  };
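  // conv_parameters is the autotune cache key: when cudnn_use_autotune is set
  // and no entry exists for this configuration yet, each candidate
  // backward-data algorithm is profiled once below and the best one is cached.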
  AlgorithmConfig algorithm_config;
  if (cudnn_use_autotune && !AutoTuneConvBwdData::GetInstance()->Find(
                                conv_parameters, &algorithm_config)) {
    std::vector<AlgorithmDesc> algorithms;
    CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
        conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(stream->parent()),
        &algorithms));
    std::vector<tensorflow::AutotuneResult> results;
    for (auto profile_algorithm : algorithms) {
      // TODO(zhengxq): profile each algorithm multiple times for better
      // accuracy.
      DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
                                            ctx);
      ProfileResult profile_result;
      bool cudnn_launch_status =
          stream
              ->ThenConvolveBackwardDataWithAlgorithm(
                  filter_desc, filter_ptr, output_desc, out_backprop_ptr,
                  conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
                  AlgorithmConfig(profile_algorithm), &profile_result)
              .ok();
      if (cudnn_launch_status) {
        if (profile_result.is_valid()) {
          results.emplace_back();
          auto& result = results.back();
          result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
          result.mutable_conv()->set_tensor_ops_enabled(
              profile_algorithm.tensor_ops_enabled());
          result.mutable_success()->set_scratch_bytes(
              scratch_allocator.TotalByteSize());
          *result.mutable_success()->mutable_run_time() =
              proto_utils::ToDurationProto(
                  absl::Milliseconds(profile_result.elapsed_time_in_ms()));
        }
      }
    }
    LogConvAutotuneResults(ctx->op_kernel().def(), pre_transformed_in_backprop,
                           transformed_filter, transformed_out_backprop,
                           stream->parent(), results);
    OP_REQUIRES_OK(ctx, BestCudnnConvAlgorithm(results, &algorithm_config));
    AutoTuneConvBwdData::GetInstance()->Insert(conv_parameters,
                                               algorithm_config);
  }
  bool cudnn_launch_status =
      stream
          ->ThenConvolveBackwardDataWithAlgorithm(
              filter_desc, filter_ptr, output_desc, out_backprop_ptr, conv_desc,
              input_desc, &in_backprop_ptr, &scratch_allocator,
              algorithm_config, nullptr)
          .ok();

  if (!cudnn_launch_status) {
    ctx->SetStatus(errors::Internal(
        "cuDNN Backward Data function launch failure : input shape(",
        input_shape.DebugString(), ") filter shape(",
        filter_shape.DebugString(), ")"));
    return;
  }

  if (padding_top != padding_bottom || padding_left != padding_right) {
    Tensor in_backprop_remove_padding;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(
                 DataTypeToEnum<T>::value,
                 ShapeFromFormat(FORMAT_NCHW,
                                 GetTensorDim(input_shape, data_format, 'N'),
                                 GetTensorDim(input_shape, data_format, 'H'),
                                 GetTensorDim(input_shape, data_format, 'W'),
                                 GetTensorDim(input_shape, data_format, 'C')),
                 &in_backprop_remove_padding));

    // Remove the padding that was added to the input shape above.
    const int64 input_pad_top = padding_top - common_padding_rows;
    const int64 input_pad_bottom = padding_bottom - common_padding_rows;
    const int64 input_pad_left = padding_left - common_padding_cols;
    const int64 input_pad_right = padding_right - common_padding_cols;
    functor::PadInput<GPUDevice, T, int, 4>()(
        ctx->template eigen_device<GPUDevice>(),
        To32Bit(const_cast<const Tensor&>(pre_transformed_in_backprop)
                    .tensor<T, 4>()),
        {{static_cast<int>(-input_pad_top), static_cast<int>(-input_pad_left)}},
        {{static_cast<int>(-input_pad_bottom),
          static_cast<int>(-input_pad_right)}},
        To32Bit(in_backprop_remove_padding.tensor<T, 4>()), FORMAT_NCHW);

    pre_transformed_in_backprop = in_backprop_remove_padding;
  }

  if (data_format == FORMAT_NHWC) {
    auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
    functor::NCHWToNHWC<GPUDevice, T, 4>()(
        ctx->eigen_device<GPUDevice>(),
        toConstTensor(pre_transformed_in_backprop).template tensor<T, 4>(),
        in_backprop->tensor<T, 4>());
  } else {
    *in_backprop = pre_transformed_in_backprop;
  }
}

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                              \
  template <>                                                            \
  void ShuffleAndReverse<GPUDevice, T, 4, int>::operator()(              \
      const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor input, \
      const Eigen::DSizes<int, 4>& order,                                \
      const Eigen::array<bool, 4>& reverse_dims,                         \
      typename TTypes<T, 4, int>::Tensor output);                        \
  extern template struct ShuffleAndReverse<GPUDevice, T, 4, int>;        \
  template <>                                                            \
  void InflatePadAndShuffle<GPUDevice, T, 4, int>::operator()(           \
      const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor input, \
      const Eigen::DSizes<int, 4>& strides,                              \
      const Eigen::array<Eigen::IndexPair<int>, 4>& pad_dims,            \
      const Eigen::DSizes<int, 4>& order,                                \
      typename TTypes<T, 4, int>::Tensor output);                        \
  extern template struct InflatePadAndShuffle<GPUDevice, T, 4, int>;     \
  template <>                                                            \
  void TransformFilter<GPUDevice, T, int, 4>::operator()(                \
      const GPUDevice& d, FilterTensorFormat dst_filter_format,          \
      typename TTypes<T, 4, int>::ConstTensor in,                        \
      typename TTypes<T, 4, int>::Tensor out);                           \
  extern template struct TransformFilter<GPUDevice, T, int, 4>;          \
  template <>                                                            \
  void TransformDepth<GPUDevice, T, int>::operator()(                    \
      const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,    \
      const Eigen::DSizes<int, 4>& shuffle,                              \
      typename TTypes<T, 4, int>::Tensor out);                           \
  extern template struct TransformDepth<GPUDevice, T, int>;              \
  template <>                                                            \
  void PadInput<GPUDevice, T, int, 4>::operator()(                       \
      const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,    \
      const std::array<int, 2>& padding_left,                            \
      const std::array<int, 2>& padding_right,                           \
      typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
  extern template struct PadInput<GPUDevice, T, int, 4>;

DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
}  // namespace functor

REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<double>("T")
                            .HostMemory("input_sizes"),
                        Conv2DSlowBackpropInputOp<GPUDevice, double>);
REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T")
                            .HostMemory("input_sizes"),
                        Conv2DSlowBackpropInputOp<GPUDevice, float>);
REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T")
                            .HostMemory("input_sizes"),
                        Conv2DSlowBackpropInputOp<GPUDevice, Eigen::half>);

// To be used inside depthwise_conv_grad_op.cc.
// TODO(reedwm): Move this and the definition to depthwise_conv_grad_op.cc.
template struct LaunchConv2DBackpropInputOp<GPUDevice, float>;
template struct LaunchConv2DBackpropInputOp<GPUDevice, Eigen::half>;
template struct LaunchConv2DBackpropInputOp<GPUDevice, double>;

#endif  // GOOGLE_CUDA

}  // namespace tensorflow