/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#ifndef TENSORFLOW_CORE_KERNELS_CONV_GRAD_INPUT_OPS_H_
#define TENSORFLOW_CORE_KERNELS_CONV_GRAD_INPUT_OPS_H_

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#include <algorithm>
#include <limits>
#include <vector>

#include "absl/base/dynamic_annotations.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
#include "tensorflow/core/kernels/xsmm_conv2d.h"
#endif
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/util/proto/proto_utils.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
#include "tensorflow/stream_executor/gpu/gpu_asm_opts.h"
#include "tensorflow/stream_executor/gpu/redzone_allocator.h"
#include "tensorflow/stream_executor/tf_allocator_adapter.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Returns in 'im_data' (assumed to be zero-initialized) the image in storage
// order (height, width, depth), constructed from patches in 'col_data', which
// is required to be in storage order (out_height * out_width, filter_height,
// filter_width, in_depth).  Implementation by Yangqing Jia (jiayq).
template <typename T>
void Col2im(const T* col_data, const int depth, const int height,
            const int width, const int filter_h, const int filter_w,
            const int pad_t, const int pad_l, const int pad_b, const int pad_r,
            const int stride_h, const int stride_w, T* __restrict im_data) {
  int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
  int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
  int h_pad = -pad_t;
  for (int h = 0; h < height_col; ++h) {
    int w_pad = -pad_l;
    for (int w = 0; w < width_col; ++w) {
      T* im_patch_data = im_data + (h_pad * width + w_pad) * depth;
      for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
        for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
          if (ih >= 0 && ih < height && iw >= 0 && iw < width) {
            for (int i = 0; i < depth; ++i) {
              im_patch_data[i] += col_data[i];
            }
          }
          im_patch_data += depth;
          col_data += depth;
        }
        // Jump over remaining number of depth.
        im_patch_data += depth * (width - filter_w);
      }
      w_pad += stride_w;
    }
    h_pad += stride_h;
  }
}
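
// Editorial note with a hypothetical worked example (numbers are not from the
// original source): for height = width = 4, depth = 1, a 3x3 filter, no
// padding and stride 1, height_col = width_col = (4 - 3) / 1 + 1 = 2, so
// 'col_data' holds 2 * 2 = 4 patches of 3 * 3 * 1 = 9 values each, and Col2im
// accumulates each patch back into the 4x4 image at its (h_pad, w_pad) offset,
// summing values where neighboring patches overlap.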

// Computes backprop input using Eigen::SpatialConvolutionBackwardInput on CPU
// and GPU (for int32 only).
template <typename Device, typename T>
struct LaunchConv2DBackpropInputOpImpl {
  void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
                  const Tensor& out_backprop, const Tensor& filter,
                  int row_dilation, int col_dilation, int row_stride,
                  int col_stride, const Padding& padding,
                  const std::vector<int64>& explicit_paddings,
                  Tensor* in_backprop, TensorFormat data_format) {
    std::vector<int32> strides(4, 1);
    std::vector<int32> dilations(4, 1);

    auto input_h = GetTensorDimIndex(data_format, 'H');
    auto input_w = GetTensorDimIndex(data_format, 'W');
    strides[input_h] = row_stride;
    strides[input_w] = col_stride;
    dilations[input_h] = row_dilation;
    dilations[input_w] = col_dilation;

    const TensorShape& input_shape = in_backprop->shape();
    const TensorShape& filter_shape = filter.shape();

    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(
        ctx, ConvBackpropComputeDimensionsV2(
                 "Conv2DBackpropInput", /*num_spatial_dims=*/2, input_shape,
                 filter_shape, out_backprop.shape(), dilations, strides,
                 padding, explicit_paddings, data_format, &dims));

    int64_t padding_top = -1, padding_bottom = -1;
    int64_t padding_left = -1, padding_right = -1;
    if (padding == EXPLICIT) {
      GetExplicitPaddingForDim(explicit_paddings, data_format, 'H',
                               &padding_top, &padding_bottom);
      GetExplicitPaddingForDim(explicit_paddings, data_format, 'W',
                               &padding_left, &padding_right);
    }

    int64_t expected_out_rows, expected_out_cols;
    // The function is guaranteed to succeed because we checked the output and
    // padding was valid earlier.
    TF_CHECK_OK(GetWindowedOutputSizeVerboseV2(
        dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
        row_dilation, row_stride, padding, &expected_out_rows, &padding_top,
        &padding_bottom));
    DCHECK_EQ(dims.spatial_dims[0].output_size, expected_out_rows);

    TF_CHECK_OK(GetWindowedOutputSizeVerboseV2(
        dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
        col_dilation, col_stride, padding, &expected_out_cols, &padding_left,
        &padding_right));
    DCHECK_EQ(dims.spatial_dims[1].output_size, expected_out_cols);

    if (std::is_same<Device, GPUDevice>::value) {
      int64_t size = 1;
#define REQUIRES_32BIT(x)                                                   \
  size *= x;                                                                \
  OP_REQUIRES(ctx,                                                          \
              FastBoundsCheck(x, std::numeric_limits<int32>::max()) &&      \
                  FastBoundsCheck(size, std::numeric_limits<int32>::max()), \
              errors::InvalidArgument("Tensor too large"))

      REQUIRES_32BIT(in_backprop->dim_size(0));
      REQUIRES_32BIT(in_backprop->dim_size(1) + padding_top + padding_bottom);
      REQUIRES_32BIT(in_backprop->dim_size(2) + padding_left + padding_right);
      REQUIRES_32BIT(in_backprop->dim_size(3));
#undef REQUIRES_32BIT
    }
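    // Editorial note (an assumption, not from the original source): the checks
    // above reject tensors whose padded spatial extents, or their running
    // product, do not fit in int32; the GPU (int32-only) Eigen path presumably
    // relies on 32-bit indexing, so larger inputs cannot be handled here.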

    auto in_backprop_t = in_backprop->tensor<T, 4>();
    auto out_backprop_t = out_backprop.tensor<T, 4>();
    auto filter_t = filter.tensor<T, 4>();

    // WARNING: Need to swap row/col, padding_top/padding_left, and
    // padding_bottom/padding_right when calling Eigen. Eigen expects tensors
    // in NWHC format, but Tensorflow uses NHWC.

    if (padding != EXPLICIT) {
      // If padding was not explicitly defined, Eigen spatial convolution
      // backward input will infer correct forward paddings from input tensors.
      functor::SpatialConvolutionBackwardInputFunc<Device, T>()(
          ctx->eigen_device<Device>(), in_backprop_t, filter_t, out_backprop_t,
          col_stride, row_stride, col_dilation, row_dilation);
    } else {
      functor::SpatialConvolutionBackwardInputWithExplicitPaddingFunc<Device,
                                                                      T>()(
          ctx->eigen_device<Device>(), in_backprop_t, filter_t, out_backprop_t,
          in_backprop_t.dimension(2) + (padding_left + padding_right),
          in_backprop_t.dimension(1) + (padding_top + padding_bottom),
          col_stride, row_stride, col_dilation, row_dilation, padding_top,
          padding_left);
    }
  }
};

// Computes backprop input using Eigen::SpatialConvolutionBackwardInput on CPU.
template <typename T>
struct LaunchConv2DBackpropInputOp<CPUDevice, T> {
  void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
                  const Tensor& out_backprop, const Tensor& filter,
                  int row_dilation, int col_dilation, int row_stride,
                  int col_stride, const Padding& padding,
                  const std::vector<int64>& explicit_paddings,
                  Tensor* in_backprop, TensorFormat data_format) {
    LaunchConv2DBackpropInputOpImpl<CPUDevice, T> launcher;
    launcher(ctx, use_cudnn, cudnn_use_autotune, out_backprop, filter,
             row_dilation, col_dilation, row_stride, col_stride, padding,
             explicit_paddings, in_backprop, data_format);
  }
};

#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
template <typename Device, class T>
struct LaunchXsmmBackwardInputConvolution {
  bool operator()(OpKernelContext* context, const Device& d,
                  typename TTypes<T, 4>::Tensor input_backward,
                  typename TTypes<T, 4>::ConstTensor kernel,
                  typename TTypes<T, 4>::ConstTensor output_backward,
                  int input_rows, int input_cols, int row_stride,
                  int col_stride, int pad_h, int pad_w,
                  TensorFormat data_format) const {
    return false;
  }
};

template <>
struct LaunchXsmmBackwardInputConvolution<CPUDevice, float> {
  bool operator()(OpKernelContext* context, const CPUDevice& d,
                  typename TTypes<float, 4>::Tensor input_backward,
                  typename TTypes<float, 4>::ConstTensor kernel,
                  typename TTypes<float, 4>::ConstTensor output_backward,
                  int input_rows, int input_cols, int row_stride,
                  int col_stride, int pad_h, int pad_w,
                  TensorFormat data_format) const {
    auto batch = input_backward.dimension(0);
    auto in_depth = input_backward.dimension(3);
    auto out_depth = output_backward.dimension(3);
    auto filter_rows = kernel.dimension(0);
    auto filter_cols = kernel.dimension(1);
    auto num_threads =
        context->device()->tensorflow_cpu_worker_threads()->num_threads;
    // See libxsmm_dnn.h for this struct definition.
    libxsmm_dnn_conv_desc desc;
    desc.N = batch;
    desc.C = in_depth;
    desc.H = input_rows;
    desc.W = input_cols;
    desc.K = out_depth;
    desc.R = filter_rows;
    desc.S = filter_cols;
    desc.u = row_stride;
    desc.v = col_stride;
    desc.pad_h = pad_h;
    desc.pad_w = pad_w;
    desc.pad_h_in = 0;
    desc.pad_w_in = 0;
    desc.pad_h_out = 0;
    desc.pad_w_out = 0;
    desc.threads = num_threads;
    desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
    desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
    desc.filter_format =
        LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;  // LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
    desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
    desc.options = LIBXSMM_DNN_CONV_OPTION_OVERWRITE;
    desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
    desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
    auto input_ptr = input_backward.data();
    auto filter_ptr = kernel.data();
    auto output_ptr = output_backward.data();

    bool success = functor::XsmmBkwInputConv2D<CPUDevice, float>()(
        context, desc, input_ptr, filter_ptr, output_ptr);
    return success;
  }
};
#endif

template <typename T>
struct Conv2DCustomBackpropInputMatMulFunctor {
  using MatrixMap = Eigen::Map<
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
  using ConstMatrixMap = Eigen::Map<
      const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;

  void operator()(OpKernelContext* ctx, const T* out_data, const T* filter_data,
                  const int filter_total_size, const int output_image_size,
                  const int dims_out_depth, T* im2col_buf) {
    // Compute gradient into 'im2col_buf'.
    MatrixMap C(im2col_buf, output_image_size, filter_total_size);

    ConstMatrixMap A(out_data, output_image_size, dims_out_depth);
    ConstMatrixMap B(filter_data, filter_total_size, dims_out_depth);

    C.noalias() = A * B.transpose();
  }
};
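
// Editorial note with hypothetical numbers (not from the original source):
// for a 3x3 filter with in_depth = 8, out_depth = 16 and a 14x14 output,
// filter_total_size = 3 * 3 * 8 = 72 and output_image_size = 14 * 14 = 196,
// so A is [196 x 16], B is [72 x 16], and C = A * B^T is the [196 x 72]
// im2col gradient buffer that Col2im later scatters back into the input
// gradient.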

#if defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL)
template <>
struct Conv2DCustomBackpropInputMatMulFunctor<float> {
  using T = float;

  void operator()(OpKernelContext* ctx, const T* out_data, const T* filter_data,
                  const int filter_total_size, const int output_image_size,
                  const int dims_out_depth, T* im2col_buf) {
    // Inputs are in RowMajor order.
    //   im2col      = out_data    * filter_data^T
    //   [ois x fts] = [ois x dod] * [fts x dod]^T
    //
    // Dimension names:
    //   out_image_size    -> ois
    //   filter_total_size -> fts
    //   dims_out_depth    -> dod

    const int m = output_image_size;
    const int n = filter_total_size;
    const int k = dims_out_depth;  // contraction dim
    const char transposeA = 'N';  // sgemm(A) == out_data
    const char transposeB = 'T';  // sgemm(B) == filter_data

    const int ldA = dims_out_depth;
    const int ldB = dims_out_depth;
    const int ldC = filter_total_size;

    const float alpha = 1.0;
    const float beta = 0.0;

    // dnnl_sgemm code can't be instrumented with msan.
    ANNOTATE_MEMORY_IS_INITIALIZED(
        im2col_buf, filter_total_size * output_image_size * sizeof(T));

    dnnl_status_t st =
        dnnl_sgemm(transposeA, transposeB, m, n, k, alpha, out_data, ldA,
                   filter_data, ldB, beta, im2col_buf, ldC);

    OP_REQUIRES(
        ctx, st == 0,
        errors::Internal("Failed to call dnnl_sgemm. Error code: ", st));
  }
};
#endif

template <typename Device, class T>
class Conv2DBackpropInputOp : public OpKernel {
 public:
  explicit Conv2DBackpropInputOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));

    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    int stride_n = GetTensorDim(strides_, data_format_, 'N');
    int stride_c = GetTensorDim(strides_, data_format_, 'C');
    int stride_h = GetTensorDim(strides_, data_format_, 'H');
    int stride_w = GetTensorDim(strides_, data_format_, 'W');
    OP_REQUIRES(
        context, (stride_n == 1 && stride_c == 1),
        errors::Unimplemented("Current implementation does not yet support "
                              "strides in the batch and depth dimensions."));
    OP_REQUIRES(context, stride_h > 0 && stride_w > 0,
                errors::InvalidArgument(
                    "Row and column strides should be larger than 0."));

    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
    OP_REQUIRES(context, dilations_.size() == 4,
                errors::InvalidArgument("Sliding window dilations field must "
                                        "specify 4 dimensions"));
    int dilation_n = GetTensorDim(dilations_, data_format_, 'N');
    int dilation_c = GetTensorDim(dilations_, data_format_, 'C');
    int dilation_h = GetTensorDim(dilations_, data_format_, 'H');
    int dilation_w = GetTensorDim(dilations_, data_format_, 'W');
    OP_REQUIRES(
        context, (dilation_n == 1 && dilation_c == 1),
        errors::Unimplemented("Current implementation does not yet support "
                              "dilations in the batch and depth dimensions."));
    OP_REQUIRES(
        context, dilation_h > 0 && dilation_w > 0,
        errors::InvalidArgument("Dilated rates should be larger than 0."));

    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("explicit_paddings", &explicit_paddings_));
    OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
                                              /*num_dims=*/4, data_format_));

    OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
    cudnn_use_autotune_ = CudnnUseAutotune();

    if (std::is_same<Device, CPUDevice>::value ||
        std::is_same<T, int32>::value) {
      OP_REQUIRES(
          context, data_format_ == FORMAT_NHWC,
          errors::InvalidArgument("Conv2DBackpropInputOp [CPU or GPU(int32)] "
                                  "only supports NHWC data format."));

      // TODO(yangzihao): Add a CPU implementation for dilated convolution.
      OP_REQUIRES(
          context, (dilation_h == 1 && dilation_w == 1),
          errors::InvalidArgument(
              "Conv2DBackpropInputOp [CPU or GPU(int32)] not yet support "
              "dilation rates larger than 1."));
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input_sizes = context->input(0);
    const Tensor& filter = context->input(1);
    const Tensor& out_backprop = context->input(2);

    TensorShape input_shape;
    OP_REQUIRES_OK(context,
                   Conv2DBackpropComputeInputShape(input_sizes, filter.shape(),
                                                   out_backprop.shape(),
                                                   data_format_, &input_shape));

    Tensor* in_backprop = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_shape, &in_backprop));

    // If there is nothing to compute, return.
    if (input_shape.num_elements() == 0) {
      return;
    }

    // For now we take the stride from the second and third dimensions only (we
    // do not support striding on the batch or depth dimension).
    const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
    const int stride_cols = GetTensorDim(strides_, data_format_, 'W');
    const int dilation_rows = GetTensorDim(dilations_, data_format_, 'H');
    const int dilation_cols = GetTensorDim(dilations_, data_format_, 'W');

    VLOG(2) << "Conv2DBackpropInput:"
            << " input: " << input_shape.DebugString()
            << " filter:" << filter.shape().DebugString()
            << " out_backprop: " << out_backprop.shape().DebugString()
            << " strides: [" << stride_rows << ", " << stride_cols << "]"
            << " dilations: [" << dilation_rows << ", " << dilation_cols << "]";

    LaunchConv2DBackpropInputOp<Device, T> launch;
    launch(context, use_cudnn_, cudnn_use_autotune_, out_backprop, filter,
           dilation_rows, dilation_cols, stride_rows, stride_cols, padding_,
           explicit_paddings_, in_backprop, data_format_);
  }

 private:
  std::vector<int32> dilations_;
  std::vector<int32> strides_;
  TensorFormat data_format_;
  Padding padding_;
  std::vector<int64> explicit_paddings_;

  bool use_cudnn_ = false;
  bool cudnn_use_autotune_ = false;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv2DBackpropInputOp);
};

// Based on implementation written by Yangqing Jia (jiayq).
template <typename Device, class T>
class Conv2DCustomBackpropInputOp : public OpKernel {
 public:
  explicit Conv2DCustomBackpropInputOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES(context, data_format_ == FORMAT_NHWC,
                errors::InvalidArgument(
                    "Conv2DCustomBackpropInputOp only supports NHWC."));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(
        context, (strides_[0] == 1 && strides_[3] == 1),
        errors::Unimplemented("Current implementation does not yet support "
                              "strides in the batch and depth dimensions."));
    OP_REQUIRES(context, strides_[1] > 0 && strides_[2] > 0,
                errors::InvalidArgument(
                    "Row and column strides should be larger than 0."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
    OP_REQUIRES(context, dilations_.size() == 4,
                errors::InvalidArgument("Sliding window dilations field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(
        context, (dilations_[0] == 1 && dilations_[3] == 1),
        errors::Unimplemented("Current implementation does not yet support "
                              "dilations in the batch and depth dimensions."));
    // TODO(yangzihao): Add a CPU implementation for dilated convolution.
    OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1),
                errors::InvalidArgument(
                    "Current libxsmm and customized CPU implementations do "
                    "not yet support dilation rates larger than 1."));
    OP_REQUIRES_OK(context,
                   context->GetAttr("explicit_paddings", &explicit_paddings_));
    OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
                                              /*num_dims=*/4, data_format_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input_sizes = context->input(0);
    const Tensor& filter = context->input(1);
    const Tensor& out_backprop = context->input(2);

    TensorShape input_shape;
    OP_REQUIRES_OK(context,
                   Conv2DBackpropComputeInputShape(input_sizes, filter.shape(),
                                                   out_backprop.shape(),
                                                   data_format_, &input_shape));

    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(context,
                   ConvBackpropComputeDimensionsV2(
                       "Conv2DCustomBackpropInput", /*num_spatial_dims=*/2,
                       input_shape, filter.shape(), out_backprop.shape(),
                       /*dilations=*/{1, 1, 1, 1}, strides_, padding_,
                       explicit_paddings_, data_format_, &dims));

    OP_REQUIRES(context, dims.in_depth == filter.shape().dim_size(2),
                errors::InvalidArgument("Computed input depth ", dims.in_depth,
                                        " doesn't match filter input depth ",
                                        filter.shape().dim_size(2)));
    OP_REQUIRES(
        context, dims.out_depth == filter.shape().dim_size(3),
        errors::InvalidArgument("Computed output depth ", dims.out_depth,
                                " doesn't match filter output depth ",
                                filter.shape().dim_size(3)));

    Tensor* in_backprop = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_shape, &in_backprop));

    // If there is nothing to compute, return.
    if (input_shape.num_elements() == 0) {
      return;
    }

// TODO(ezhulenev): Remove custom kernel and move XSMM support to
// LaunchConv2DBackpropInputOp functor.
#if defined TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS && \
    defined TENSORFLOW_USE_LIBXSMM_BACKWARD_CONVOLUTIONS
    int64 pad_top, pad_bottom;
    int64 pad_left, pad_right;
    OP_REQUIRES_OK(
        context,
        GetWindowedOutputSizeVerbose(
            dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
            dims.spatial_dims[0].stride, padding_,
            &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom));
    OP_REQUIRES_OK(
        context,
        GetWindowedOutputSizeVerbose(
            dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
            dims.spatial_dims[1].stride, padding_,
            &dims.spatial_dims[1].output_size, &pad_left, &pad_right));

    if (pad_left == pad_right && pad_top == pad_bottom) {
      if (LaunchXsmmBackwardInputConvolution<Device, T>()(
              context, context->eigen_device<Device>(),
              in_backprop->tensor<T, 4>(), filter.tensor<T, 4>(),
              out_backprop.tensor<T, 4>(), dims.spatial_dims[0].input_size,
              dims.spatial_dims[1].input_size,
              static_cast<int>(dims.spatial_dims[0].stride),
              static_cast<int>(dims.spatial_dims[1].stride),
              static_cast<int>(pad_top), static_cast<int>(pad_left),
              data_format_)) {
        return;
      }
    }
#else
    int64_t pad_top, pad_bottom;
    int64_t pad_left, pad_right;
#endif
    if (padding_ == Padding::EXPLICIT) {
      pad_top = explicit_paddings_[2];
      pad_bottom = explicit_paddings_[3];
      pad_left = explicit_paddings_[4];
      pad_right = explicit_paddings_[5];
    }
    OP_REQUIRES_OK(
        context,
        GetWindowedOutputSizeVerbose(
            dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
            dims.spatial_dims[0].stride, padding_,
            &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom));
    OP_REQUIRES_OK(
        context,
        GetWindowedOutputSizeVerbose(
            dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
            dims.spatial_dims[1].stride, padding_,
            &dims.spatial_dims[1].output_size, &pad_left, &pad_right));

    // The total dimension size of each kernel.
    const int filter_total_size = dims.spatial_dims[0].filter_size *
                                  dims.spatial_dims[1].filter_size *
                                  dims.in_depth;
    // The output image size is the spatial size of the output.
    const int output_image_size =
        dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size;

    // TODO(andydavis) Get L2/L3 cache sizes from device.
    const size_t l2_cache_size = 256LL << 10;
    const size_t l3_cache_size = 30LL << 20;

    // Use L3 cache size as target working set size.
    const size_t target_working_set_size = l3_cache_size / sizeof(T);
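    // Editorial note: for float (4-byte T) this target is 30 MiB / 4 B =
    // 7,864,320 elements; the L2-based per-thread minimum used below is
    // 256 KiB / 4 B = 65,536 elements.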

    // Calculate size of matrices involved in MatMul: C = A x B.
    const size_t size_A = output_image_size * dims.out_depth;

    const size_t size_B = filter_total_size * dims.out_depth;

    const size_t size_C = output_image_size * filter_total_size;

    const size_t work_unit_size = size_A + size_B + size_C;

    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());

    // Calculate per-thread work unit size.
    const size_t thread_work_unit_size =
        work_unit_size / worker_threads.num_threads;

    // Set minimum per-thread work unit size to size of L2 cache.
    const size_t min_thread_work_unit_size = l2_cache_size / sizeof(T);

    // Use parallel tensor contractions if there is no batching, or if the
    // minimum per-thread work unit size threshold has been exceeded.
    // Otherwise, revert to multiple single-threaded matmul ops running in
    // parallel to keep all threads busy.
    // TODO(andydavis) Explore alternatives to branching the code in this way
    // (i.e. run multiple, parallel tensor contractions in another thread pool).
    const bool use_parallel_contraction =
        dims.batch_size == 1 ||
        thread_work_unit_size >= min_thread_work_unit_size;

    OP_REQUIRES(
        context, work_unit_size > 0,
        errors::InvalidArgument("input, filter_sizes and out_backprop tensors "
                                "must all have at least 1 element"));

    const size_t shard_size =
        use_parallel_contraction
            ? 1
            : (target_working_set_size + work_unit_size - 1) / work_unit_size;
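    // Editorial note: in the sharded (non-parallel-contraction) case,
    // shard_size is ceil(target_working_set_size / work_unit_size), i.e. the
    // number of images whose combined A/B/C working set stays near the L3
    // target; the col_buffer allocated below holds that many im2col buffers.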

    Tensor col_buffer;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(
                       DataTypeToEnum<T>::value,
                       TensorShape({static_cast<int64>(shard_size),
                                    static_cast<int64>(output_image_size),
                                    static_cast<int64>(filter_total_size)}),
                       &col_buffer));

    // The input offset corresponding to a single input image.
    const int input_offset = dims.spatial_dims[0].input_size *
                             dims.spatial_dims[1].input_size * dims.in_depth;
    // The output offset corresponding to a single output image.
    const int output_offset = dims.spatial_dims[0].output_size *
                              dims.spatial_dims[1].output_size * dims.out_depth;

    const T* filter_data = filter.template flat<T>().data();
    T* col_buffer_data = col_buffer.template flat<T>().data();
    const T* out_backprop_data = out_backprop.template flat<T>().data();

    auto in_backprop_flat = in_backprop->template flat<T>();
    T* input_backprop_data = in_backprop_flat.data();
    in_backprop_flat.device(context->eigen_device<Device>()) =
        in_backprop_flat.constant(T(0));

    if (use_parallel_contraction) {
      typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
                               Eigen::Unaligned>
          TensorMap;
      typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
                               Eigen::Unaligned>
          ConstTensorMap;

      // Initialize contraction dims (we need to transpose 'B' below).
      Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
      contract_dims[0].first = 1;
      contract_dims[0].second = 1;

      for (int image_id = 0; image_id < dims.batch_size; ++image_id) {
        // Compute gradient into col_buffer.
        TensorMap C(col_buffer_data, output_image_size, filter_total_size);

        ConstTensorMap A(out_backprop_data + output_offset * image_id,
                         output_image_size, dims.out_depth);
        ConstTensorMap B(filter_data, filter_total_size, dims.out_depth);

        C.device(context->eigen_cpu_device()) = A.contract(B, contract_dims);

        Col2im<T>(
            col_buffer_data, dims.in_depth, dims.spatial_dims[0].input_size,
            dims.spatial_dims[1].input_size, dims.spatial_dims[0].filter_size,
            dims.spatial_dims[1].filter_size, pad_top, pad_left, pad_bottom,
            pad_right, dims.spatial_dims[0].stride, dims.spatial_dims[1].stride,
            input_backprop_data);

        input_backprop_data += input_offset;
      }
    } else {
      for (int image_id = 0; image_id < dims.batch_size;
           image_id += shard_size) {
        const int shard_limit =
            std::min(static_cast<int>(shard_size),
                     static_cast<int>(dims.batch_size) - image_id);

        auto shard = [&context, &dims, &pad_top, &pad_left, &pad_bottom,
                      &pad_right, &output_image_size, &filter_total_size,
                      &input_backprop_data, &col_buffer_data,
                      &out_backprop_data, &filter_data, &input_offset,
                      &output_offset, &size_C](int64_t start, int64_t limit) {
          for (int shard_id = start; shard_id < limit; ++shard_id) {
            T* im2col_buf = col_buffer_data + shard_id * size_C;
            T* input_data = input_backprop_data + shard_id * input_offset;
            const T* out_data = out_backprop_data + shard_id * output_offset;

            Conv2DCustomBackpropInputMatMulFunctor<T>()(
                context, out_data, filter_data, filter_total_size,
                output_image_size, dims.out_depth, im2col_buf);

            Col2im<T>(im2col_buf, dims.in_depth,
                      dims.spatial_dims[0].input_size,
                      dims.spatial_dims[1].input_size,
                      dims.spatial_dims[0].filter_size,
                      dims.spatial_dims[1].filter_size, pad_top, pad_left,
                      pad_bottom, pad_right, dims.spatial_dims[0].stride,
                      dims.spatial_dims[1].stride, input_data);
          }
        };
        Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
              work_unit_size, shard);

        input_backprop_data += input_offset * shard_limit;
        out_backprop_data += output_offset * shard_limit;
      }
    }
  }

 private:
  std::vector<int32> dilations_;
  std::vector<int32> strides_;
  Padding padding_;
  std::vector<int64> explicit_paddings_;
  TensorFormat data_format_;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv2DCustomBackpropInputOp);
};

// TODO(ezhulenev): Add a cost model to switch between custom/Eigen ops.
#define DEFAULT_CONV_2D_BACKPROP_CPU_OP Conv2DCustomBackpropInputOp

#define REGISTER_CONV_2D_BACKPROP_CPU_KERNELS(T)                             \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("Conv2DBackpropInput").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      DEFAULT_CONV_2D_BACKPROP_CPU_OP<CPUDevice, T>);                        \
  REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")                        \
                              .Device(DEVICE_CPU)                            \
                              .Label("custom")                               \
                              .TypeConstraint<T>("T"),                       \
                          Conv2DCustomBackpropInputOp<CPUDevice, T>);        \
  REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")                        \
                              .Device(DEVICE_CPU)                            \
                              .Label("eigen_tensor")                         \
                              .TypeConstraint<T>("T"),                       \
                          Conv2DBackpropInputOp<CPUDevice, T>);
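
// Editorial note (an assumption about kernel-label selection, not stated in
// this file): with the registrations above, a node without a kernel label
// gets DEFAULT_CONV_2D_BACKPROP_CPU_OP (the custom im2col/GEMM kernel), while
// labeling the node "custom" or "eigen_tensor" selects the custom or
// Eigen-based kernel explicitly.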

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_CONV_GRAD_INPUT_OPS_H_