1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // See docs in ../ops/nn_ops.cc.
17
18 #define USE_EIGEN_TENSOR
19 #define EIGEN_USE_THREADS
20
21 #include "tensorflow/core/kernels/conv_grad_ops.h"
22
23 #include <algorithm>
24 #include <vector>
25
26 #include "absl/base/dynamic_annotations.h"
27 #include "tensorflow/core/framework/numeric_op.h"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/framework/register_types.h"
30 #include "tensorflow/core/framework/tensor.h"
31 #include "tensorflow/core/framework/tensor_shape.h"
32 #include "tensorflow/core/framework/tensor_slice.h"
33 #include "tensorflow/core/kernels/conv_2d.h"
34 #ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
35 #include "tensorflow/core/kernels/xsmm_conv2d.h"
36 #endif
37 #include "tensorflow/core/kernels/ops_util.h"
38 #include "tensorflow/core/lib/core/errors.h"
39 #include "tensorflow/core/lib/gtl/array_slice.h"
40 #include "tensorflow/core/platform/logging.h"
41 #include "tensorflow/core/platform/macros.h"
42 #include "tensorflow/core/util/padding.h"
43 #include "tensorflow/core/util/tensor_format.h"
44 #include "tensorflow/core/util/use_cudnn.h"
45 #include "tensorflow/core/util/work_sharder.h"
46
47 #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
48 #include "tensorflow/core/kernels/eigen_contraction_kernel.h"
49 #endif
50
51 #if GOOGLE_CUDA
52 #include "tensorflow/core/kernels/conv_ops_gpu.h"
53 #include "tensorflow/core/platform/stream_executor.h"
54 #include "tensorflow/core/protobuf/autotuning.pb.h"
55 #include "tensorflow/core/util/proto/proto_utils.h"
56 #endif // GOOGLE_CUDA
57
58 namespace {
59
60 // Returns, in 'im_data' (assumed to be zero-initialized), the image patch in
61 // storage order (height, width, depth), constructed from patches in 'col_data',
62 // which is required to be in storage order (out_height * out_width,
63 // filter_height, filter_width, in_depth). Implementation by Yangqing Jia (jiayq).
64 template <typename T>
65 void Col2im(const T* col_data, const int depth, const int height,
66 const int width, const int filter_h, const int filter_w,
67 const int pad_t, const int pad_l, const int pad_b, const int pad_r,
68 const int stride_h, const int stride_w, T* im_data) {
69 int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
70 int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
71 int h_pad = -pad_t;
72 for (int h = 0; h < height_col; ++h) {
73 int w_pad = -pad_l;
74 for (int w = 0; w < width_col; ++w) {
75 T* im_patch_data = im_data + (h_pad * width + w_pad) * depth;
76 for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
77 for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
78 if (ih >= 0 && ih < height && iw >= 0 && iw < width) {
79 // TODO(andydavis) Vectorize this loop (if compiler does not).
80 for (int i = 0; i < depth; ++i) {
81 im_patch_data[i] += col_data[i];
82 }
83 }
84 im_patch_data += depth;
85 col_data += depth;
86 }
87         // Advance to the start of the next image row within this patch.
88 im_patch_data += depth * (width - filter_w);
89 }
90 w_pad += stride_w;
91 }
92 h_pad += stride_h;
93 }
94 }
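
// A minimal illustrative sketch (not used by the kernels below): calling
// Col2im for a hypothetical 4x4 single-channel image with a 3x3 filter,
// stride 1 and no padding. With these sizes height_col = width_col =
// (4 + 0 + 0 - 3) / 1 + 1 = 2, so 'col_data' holds 2 * 2 patches of
// 3 * 3 * 1 values each, and overlapping patch entries are accumulated into
// the zero-initialized 'im_data'. The helper name Col2imTinyExample is made
// up for this sketch.
template <typename T>
void Col2imTinyExample(const T* col_data /* 36 values */,
                       T* im_data /* 16 values, zero-initialized */) {
  Col2im<T>(col_data, /*depth=*/1, /*height=*/4, /*width=*/4,
            /*filter_h=*/3, /*filter_w=*/3, /*pad_t=*/0, /*pad_l=*/0,
            /*pad_b=*/0, /*pad_r=*/0, /*stride_h=*/1, /*stride_w=*/1, im_data);
}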
95
96 } // namespace
97
98 namespace tensorflow {
99
100 typedef Eigen::ThreadPoolDevice CPUDevice;
101 typedef Eigen::GpuDevice GPUDevice;
102
103 // The fast versions use Eigen computations directly. They are only enabled
104 // for CPU for now since nvcc times out when trying to compile them.
105 // TODO(yangke): enable them for GPUs when we have a faster compiler.
106
107 template <typename T>
108 struct LaunchConv2DBackpropInputOp<CPUDevice, T> {
109   void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
110 const Tensor& out_backprop, const Tensor& filter,
111 int row_dilation, int col_dilation, int row_stride,
112 int col_stride, const Padding& padding,
113 const std::vector<int64>& explicit_paddings,
114 Tensor* in_backprop, TensorFormat data_format) {
115 const CPUDevice& d = ctx->eigen_device<CPUDevice>();
116 functor::SpatialConvolutionBackwardInput<CPUDevice, T>()(
117 d, in_backprop->tensor<T, 4>(), filter.tensor<T, 4>(),
118 out_backprop.tensor<T, 4>(), row_stride, col_stride,
119 /*row_dilation=*/1, /*col_dilation=*/1);
120 }
121 };
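
// Note: the Eigen path above computes the input gradient as a transposed
// convolution: conceptually, out_backprop (zero-expanded by the stride) is
// convolved with the spatially flipped filter whose input and output channel
// dimensions are swapped. The custom CPU kernel below reaches the same result
// with an explicit matmul + Col2im scatter per image.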
122
123 #ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
124 template <typename Device, class T>
125 struct LaunchXsmmBackwardInputConvolution {
126   bool operator()(OpKernelContext* context, const Device& d,
127 typename TTypes<T, 4>::Tensor input_backward,
128 typename TTypes<T, 4>::ConstTensor kernel,
129 typename TTypes<T, 4>::ConstTensor output_backward,
130 int input_rows, int input_cols, int row_stride,
131 int col_stride, int pad_h, int pad_w,
132 TensorFormat data_format) const {
133 return false;
134 }
135 };
136
137 template <>
138 struct LaunchXsmmBackwardInputConvolution<CPUDevice, float> {
139   bool operator()(OpKernelContext* context, const CPUDevice& d,
140 typename TTypes<float, 4>::Tensor input_backward,
141 typename TTypes<float, 4>::ConstTensor kernel,
142 typename TTypes<float, 4>::ConstTensor output_backward,
143 int input_rows, int input_cols, int row_stride,
144 int col_stride, int pad_h, int pad_w,
145 TensorFormat data_format) const {
146 auto batch = input_backward.dimension(0);
147 auto in_depth = input_backward.dimension(3);
148 auto out_depth = output_backward.dimension(3);
149 auto filter_rows = kernel.dimension(0);
150 auto filter_cols = kernel.dimension(1);
151 auto num_threads =
152 context->device()->tensorflow_cpu_worker_threads()->num_threads;
153 // See libxsmm_dnn.h for this struct definition.
154 libxsmm_dnn_conv_desc desc;
155 desc.N = batch;
156 desc.C = in_depth;
157 desc.H = input_rows;
158 desc.W = input_cols;
159 desc.K = out_depth;
160 desc.R = filter_rows;
161 desc.S = filter_cols;
162 desc.u = row_stride;
163 desc.v = col_stride;
164 desc.pad_h = pad_h;
165 desc.pad_w = pad_w;
166 desc.pad_h_in = 0;
167 desc.pad_w_in = 0;
168 desc.pad_h_out = 0;
169 desc.pad_w_out = 0;
170 desc.threads = num_threads;
171 desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
172 desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
173 desc.filter_format =
174 LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM; // LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
175 desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
176 desc.options = LIBXSMM_DNN_CONV_OPTION_WU_EXT_FILTER_REDUCE_OVERWRITE;
177 desc.datatype = LIBXSMM_DNN_DATATYPE_F32;
178
179 auto input_ptr = input_backward.data();
180 auto filter_ptr = kernel.data();
181 auto output_ptr = output_backward.data();
182
183 bool success = functor::XsmmBkwInputConv2D<CPUDevice, float>()(
184 context, desc, input_ptr, filter_ptr, output_ptr);
185 return success;
186 }
187 };
188 #endif
189
190 template <typename T>
191 struct Conv2DCustomBackpropInputMatMulFunctor {
192 using MatrixMap = Eigen::Map<
193 Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
194 using ConstMatrixMap = Eigen::Map<
195 const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
196
197   void operator()(OpKernelContext* ctx, const T* out_data, const T* filter_data,
198 const int filter_total_size, const int output_image_size,
199 const int dims_out_depth, T* im2col_buf) {
200 // Compute gradient into 'im2col_buf'.
201 MatrixMap C(im2col_buf, output_image_size, filter_total_size);
202
203 ConstMatrixMap A(out_data, output_image_size, dims_out_depth);
204 ConstMatrixMap B(filter_data, filter_total_size, dims_out_depth);
205
206 C.noalias() = A * B.transpose();
207 }
208 };
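
// In the functor above, A is one image's slice of out_backprop viewed as
// [output_image_size x dims_out_depth], B is the filter viewed as
// [filter_total_size x dims_out_depth] (filter_total_size = filter_rows *
// filter_cols * in_depth), and C is the [output_image_size x
// filter_total_size] im2col buffer that Col2im later scatters back into the
// input gradient.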
209
210 #if defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL)
211 template <>
212 struct Conv2DCustomBackpropInputMatMulFunctor<float> {
213 using T = float;
214
215   void operator()(OpKernelContext* ctx, const T* out_data, const T* filter_data,
216 const int filter_total_size, const int output_image_size,
217 const int dims_out_depth, T* im2col_buf) {
218 // Inputs are in RowMajor order, we "cheat" by swapping the LHS and RHS:
219 // RowMajor: C = A * B
220 // ColMajor: C^T = B^T * A^T
221 //
222 // Dimension names:
223 // out_image_size -> ois
224 // filter_total_size -> fts
225 // dims_out_depth -> dod
226 //
227 // RowMajor:
228 // im2col = out_data * filter_data^T
229 // [ois x fts] = [ois x dod] * [fts x dod]^T
230 //
231 // ColMajor:
232 // im2col^T = filter_data * out_data^T
233     //   [fts x ois]  = [fts x dod] * [dod x ois]
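    //
    // As a hypothetical concrete example: a 3x3 filter over 8 input channels
    // gives fts = 3 * 3 * 8 = 72; with dod = 16 output channels and a 5x5
    // output (ois = 25), the sgemm below computes a column-major [72 x 25]
    // result, which occupies the same memory as the row-major [25 x 72]
    // im2col buffer (m = 72, n = 25, k = 16, ldA = ldB = 16, ldC = 72).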
234
235 const int m = filter_total_size;
236 const int n = output_image_size;
237 const int k = dims_out_depth; // contraction dim
238
239 const char transposeA = 'T'; // sgemm(A) == filter_data
240 const char transposeB = 'N'; // sgemm(B) == out_data
241
242 const int ldA = dims_out_depth;
243 const int ldB = dims_out_depth;
244 const int ldC = filter_total_size;
245
246 const float alpha = 1.0;
247 const float beta = 0.0;
248
249 // mkldnn_sgemm code can't be instrumented with msan.
250 ANNOTATE_MEMORY_IS_INITIALIZED(
251 im2col_buf, filter_total_size * output_image_size * sizeof(T));
252
253 mkldnn_status_t st =
254 mkldnn_sgemm(&transposeA, &transposeB, &m, &n, &k, &alpha, filter_data,
255 &ldA, out_data, &ldB, &beta, im2col_buf, &ldC);
256
257 OP_REQUIRES(
258 ctx, st == 0,
259 errors::Internal("Failed to call mkldnn_sgemm. Error code: ", st));
260 }
261 };
262 #endif
263
264 // Based on implementation written by Yangqing Jia (jiayq).
265 template <typename Device, class T>
266 class Conv2DCustomBackpropInputOp : public OpKernel {
267 public:
268   explicit Conv2DCustomBackpropInputOp(OpKernelConstruction* context)
269 : OpKernel(context) {
270 string data_format;
271 OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
272 OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
273 errors::InvalidArgument("Invalid data format"));
274 OP_REQUIRES(context, data_format_ == FORMAT_NHWC,
275 errors::InvalidArgument(
276 "Conv2DCustomBackpropInputOp only supports NHWC."));
277 OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
278 OP_REQUIRES(context, strides_.size() == 4,
279 errors::InvalidArgument("Sliding window strides field must "
280 "specify 4 dimensions"));
281 OP_REQUIRES(
282 context, (strides_[0] == 1 && strides_[3] == 1),
283 errors::InvalidArgument("Current implementation does not yet support "
284 "strides in the batch and depth dimensions."));
285 OP_REQUIRES(context, strides_[1] > 0 && strides_[2] > 0,
286 errors::InvalidArgument(
287 "Row and column strides should be larger than 0."));
288 OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
289 OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
290 OP_REQUIRES(context, dilations_.size() == 4,
291 errors::InvalidArgument("Sliding window dilations field must "
292 "specify 4 dimensions"));
293 OP_REQUIRES(context, (dilations_[0] == 1 && dilations_[3] == 1),
294 errors::InvalidArgument(
295 "Current implementation does not yet support "
296 "dilations in the batch and depth dimensions."));
297 // TODO(yangzihao): Add a CPU implementation for dilated convolution.
298 OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1),
299 errors::InvalidArgument(
300 "Current libxsmm and customized CPU implementations do "
301 "not yet support dilation rates larger than 1."));
302 OP_REQUIRES(
303 context, padding_ != Padding::EXPLICIT,
304 errors::Unimplemented("Current CPU implementation does not support "
305 "EXPLICIT padding yet."));
306 std::vector<int64> explicit_paddings;
307 OP_REQUIRES_OK(context,
308 context->GetAttr("explicit_paddings", &explicit_paddings));
309 OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings,
310 /*num_dims=*/4, data_format_));
311 }
312
313   void Compute(OpKernelContext* context) override {
314 const Tensor& input_sizes = context->input(0);
315 const Tensor& filter = context->input(1);
316 const Tensor& out_backprop = context->input(2);
317 OP_REQUIRES(
318 context, TensorShapeUtils::IsVector(input_sizes.shape()),
319 errors::InvalidArgument(
320 "Conv2DBackpropInput: input_sizes input must be 1-dim, not ",
321 input_sizes.dims()));
322 TensorShape input_shape;
323 OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
324 input_sizes.vec<int32>(), &input_shape));
325
326 ConvBackpropDimensions dims;
327 OP_REQUIRES_OK(context,
328 ConvBackpropComputeDimensions(
329 "Conv2DCustomBackpropInput", /*num_spatial_dims=*/2,
330 input_shape, filter.shape(), out_backprop.shape(),
331 strides_, padding_, data_format_, &dims));
332
333 Tensor* in_backprop = nullptr;
334 OP_REQUIRES_OK(context,
335 context->allocate_output(0, input_shape, &in_backprop));
336
337 // If there is nothing to compute, return.
338 if (input_shape.num_elements() == 0) {
339 return;
340 }
341
342 // TODO(andydavis) Consider moving code shared with
343 // Conv2DCustomBackpropFilterOp into a shared helper function.
344 #if defined TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS && \
345 defined TENSORFLOW_USE_LIBXSMM_BACKWARD_CONVOLUTIONS
346 int64 pad_top, pad_bottom;
347 int64 pad_left, pad_right;
348 OP_REQUIRES_OK(
349 context,
350 GetWindowedOutputSizeVerbose(
351 dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
352 dims.spatial_dims[0].stride, padding_,
353 &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom));
354 OP_REQUIRES_OK(
355 context,
356 GetWindowedOutputSizeVerbose(
357 dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
358 dims.spatial_dims[1].stride, padding_,
359 &dims.spatial_dims[1].output_size, &pad_left, &pad_right));
360
361 if (pad_left == pad_right && pad_top == pad_bottom) {
362 if (LaunchXsmmBackwardInputConvolution<Device, T>()(
363 context, context->eigen_device<Device>(),
364 in_backprop->tensor<T, 4>(), filter.tensor<T, 4>(),
365 out_backprop.tensor<T, 4>(), dims.spatial_dims[0].input_size,
366 dims.spatial_dims[1].input_size,
367 static_cast<int>(dims.spatial_dims[0].stride),
368 static_cast<int>(dims.spatial_dims[1].stride),
369 static_cast<int>(pad_top), static_cast<int>(pad_left),
370 data_format_)) {
371 return;
372 }
373 }
374 #else
375 int64 pad_top, pad_bottom;
376 int64 pad_left, pad_right;
377 #endif
378 OP_REQUIRES_OK(
379 context,
380 GetWindowedOutputSizeVerbose(
381 dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
382 dims.spatial_dims[0].stride, padding_,
383 &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom));
384 OP_REQUIRES_OK(
385 context,
386 GetWindowedOutputSizeVerbose(
387 dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
388 dims.spatial_dims[1].stride, padding_,
389 &dims.spatial_dims[1].output_size, &pad_left, &pad_right));
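
    // For reference (standard TensorFlow padding conventions): with VALID
    // padding the output size is ceil((input - filter + 1) / stride) and all
    // paddings above are zero; with SAME padding the output size is
    // ceil(input / stride) and the total padding is
    // max((output - 1) * stride + filter - input, 0), split so that
    // pad_top <= pad_bottom and pad_left <= pad_right.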
390
391 // The total dimension size of each kernel.
392 const int filter_total_size = dims.spatial_dims[0].filter_size *
393 dims.spatial_dims[1].filter_size *
394 dims.in_depth;
395 // The output image size is the spatial size of the output.
396 const int output_image_size =
397 dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size;
398
399 // TODO(andydavis) Get L2/L3 cache sizes from device.
400 const size_t l2_cache_size = 256LL << 10;
401 const size_t l3_cache_size = 30LL << 20;
402
403 // Use L3 cache size as target working set size.
404 const size_t target_working_set_size = l3_cache_size / sizeof(T);
405
406 // Calculate size of matrices involved in MatMul: C = A x B.
407 const size_t size_A = output_image_size * dims.out_depth;
408
409 const size_t size_B = filter_total_size * dims.out_depth;
410
411 const size_t size_C = output_image_size * filter_total_size;
412
413 const size_t work_unit_size = size_A + size_B + size_C;
414
415 auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
416
417 // Calculate per-thread work unit size.
418 const size_t thread_work_unit_size =
419 work_unit_size / worker_threads.num_threads;
420
421 // Set minimum per-thread work unit size to size of L2 cache.
422 const size_t min_thread_work_unit_size = l2_cache_size / sizeof(T);
423
424 // Use parallel tensor contractions if there is no batching, or if the
425 // minimum per-thread work unit size threshold has been exceeded.
426 // Otherwise, revert to multiple single-threaded matmul ops running in
427 // parallel to keep all threads busy.
428 // TODO(andydavis) Explore alternatives to branching the code in this way
429 // (i.e. run multiple, parallel tensor contractions in another thread pool).
430 const bool use_parallel_contraction =
431 dims.batch_size == 1 ||
432 thread_work_unit_size >= min_thread_work_unit_size;
433
434 const size_t shard_size =
435 use_parallel_contraction
436 ? 1
437 : (target_working_set_size + work_unit_size - 1) / work_unit_size;
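
    // As a hypothetical illustration of the sharding arithmetic above: for
    // float (sizeof(T) == 4) the L3 target is 7864320 elements; a 28x28
    // output with 64 output channels and a 5x5x32 filter gives
    // size_A = 784 * 64, size_B = 800 * 64 and size_C = 784 * 800, so
    // work_unit_size = 728576 elements and shard_size =
    // (7864320 + 728576 - 1) / 728576 = 11 images per shard on the
    // single-threaded-matmul path.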
438
439 Tensor col_buffer;
440 OP_REQUIRES_OK(context,
441 context->allocate_temp(
442 DataTypeToEnum<T>::value,
443 TensorShape({static_cast<int64>(shard_size),
444 static_cast<int64>(output_image_size),
445 static_cast<int64>(filter_total_size)}),
446 &col_buffer));
447
448 // The input offset corresponding to a single input image.
449 const int input_offset = dims.spatial_dims[0].input_size *
450 dims.spatial_dims[1].input_size * dims.in_depth;
451 // The output offset corresponding to a single output image.
452 const int output_offset = dims.spatial_dims[0].output_size *
453 dims.spatial_dims[1].output_size * dims.out_depth;
454
455 const T* filter_data = filter.template flat<T>().data();
456 T* col_buffer_data = col_buffer.template flat<T>().data();
457 const T* out_backprop_data = out_backprop.template flat<T>().data();
458
459 auto in_backprop_flat = in_backprop->template flat<T>();
460 T* input_backprop_data = in_backprop_flat.data();
461 in_backprop_flat.device(context->eigen_device<Device>()) =
462 in_backprop_flat.constant(T(0));
463
464 if (use_parallel_contraction) {
465 typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
466 Eigen::Unaligned>
467 TensorMap;
468 typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
469 Eigen::Unaligned>
470 ConstTensorMap;
471
472 // Initialize contraction dims (we need to transpose 'B' below).
473 Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
474 contract_dims[0].first = 1;
475 contract_dims[0].second = 1;
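      // contract_dims == {1, 1} contracts dimension 1 of A (out_depth) with
      // dimension 1 of B (out_depth), i.e. C = A * B^T, matching the
      // single-threaded functor above.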
476
477 for (int image_id = 0; image_id < dims.batch_size; ++image_id) {
478 // Compute gradient into col_buffer.
479 TensorMap C(col_buffer_data, output_image_size, filter_total_size);
480
481 ConstTensorMap A(out_backprop_data + output_offset * image_id,
482 output_image_size, dims.out_depth);
483 ConstTensorMap B(filter_data, filter_total_size, dims.out_depth);
484
485 C.device(context->eigen_cpu_device()) = A.contract(B, contract_dims);
486
487 Col2im<T>(
488 col_buffer_data, dims.in_depth, dims.spatial_dims[0].input_size,
489 dims.spatial_dims[1].input_size, dims.spatial_dims[0].filter_size,
490 dims.spatial_dims[1].filter_size, pad_top, pad_left, pad_bottom,
491 pad_right, dims.spatial_dims[0].stride, dims.spatial_dims[1].stride,
492 input_backprop_data);
493
494 input_backprop_data += input_offset;
495 }
496 } else {
497 for (int image_id = 0; image_id < dims.batch_size;
498 image_id += shard_size) {
499 const int shard_limit =
500 std::min(static_cast<int>(shard_size),
501 static_cast<int>(dims.batch_size) - image_id);
502
503 auto shard = [&context, &dims, &pad_top, &pad_left, &pad_bottom,
504 &pad_right, &output_image_size, &filter_total_size,
505 &input_backprop_data, &col_buffer_data,
506 &out_backprop_data, &filter_data, &input_offset,
507 &output_offset, &size_C](int64 start, int64 limit) {
508 for (int shard_id = start; shard_id < limit; ++shard_id) {
509 T* im2col_buf = col_buffer_data + shard_id * size_C;
510 T* input_data = input_backprop_data + shard_id * input_offset;
511 const T* out_data = out_backprop_data + shard_id * output_offset;
512
513 Conv2DCustomBackpropInputMatMulFunctor<T>()(
514 context, out_data, filter_data, filter_total_size,
515 output_image_size, dims.out_depth, im2col_buf);
516
517 Col2im<T>(im2col_buf, dims.in_depth,
518 dims.spatial_dims[0].input_size,
519 dims.spatial_dims[1].input_size,
520 dims.spatial_dims[0].filter_size,
521 dims.spatial_dims[1].filter_size, pad_top, pad_left,
522 pad_bottom, pad_right, dims.spatial_dims[0].stride,
523 dims.spatial_dims[1].stride, input_data);
524 }
525 };
526 Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
527 work_unit_size, shard);
528
529 input_backprop_data += input_offset * shard_limit;
530 out_backprop_data += output_offset * shard_limit;
531 }
532 }
533 }
534
535 private:
536 std::vector<int32> dilations_;
537 std::vector<int32> strides_;
538 Padding padding_;
539 TensorFormat data_format_;
540
541 TF_DISALLOW_COPY_AND_ASSIGN(Conv2DCustomBackpropInputOp);
542 };
543
544 #define REGISTER_CPU_KERNELS(T) \
545 REGISTER_KERNEL_BUILDER( \
546 Name("Conv2DBackpropInput").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
547 Conv2DCustomBackpropInputOp<CPUDevice, T>); \
548 REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput") \
549 .Device(DEVICE_CPU) \
550 .Label("custom") \
551 .TypeConstraint<T>("T"), \
552 Conv2DCustomBackpropInputOp<CPUDevice, T>);
553
554 TF_CALL_half(REGISTER_CPU_KERNELS);
555 TF_CALL_float(REGISTER_CPU_KERNELS);
556 TF_CALL_double(REGISTER_CPU_KERNELS);
557 #undef REGISTER_CPU_KERNELS
558
559 // To be used inside depthwise_conv_grad_op.cc.
560 template struct LaunchConv2DBackpropInputOp<CPUDevice, Eigen::half>;
561 template struct LaunchConv2DBackpropInputOp<CPUDevice, float>;
562 template struct LaunchConv2DBackpropInputOp<CPUDevice, double>;
563
564 // GPU definitions.
565 #if GOOGLE_CUDA
566 // The slow version (but compiles for GPU)
567
568 // A dummy type to group backward-data convolution autotune results together.
569 struct ConvBackwardDataAutoTuneGroup {
570   static string name() { return "ConvBwdData"; }
571 };
572 typedef AutoTuneSingleton<ConvBackwardDataAutoTuneGroup, ConvParameters,
573 se::dnn::AlgorithmConfig>
574 AutoTuneConvBwdData;
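
// AutoTuneConvBwdData caches the best AlgorithmConfig found for a given
// ConvParameters key (shapes, strides, dilations, paddings, dtype and device),
// so the profiling loop in the launcher below runs only once per distinct
// configuration.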
575
576 // Backprop for input.
577 template <typename Device, class T>
578 class Conv2DSlowBackpropInputOp : public OpKernel {
579 public:
580   explicit Conv2DSlowBackpropInputOp(OpKernelConstruction* context)
581 : OpKernel(context) {
582 string data_format;
583 OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
584 OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
585 errors::InvalidArgument("Invalid data format"));
586 OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
587 OP_REQUIRES(context, strides_.size() == 4,
588 errors::InvalidArgument("Sliding window strides field must "
589 "specify 4 dimensions"));
590 int stride_n = GetTensorDim(strides_, data_format_, 'N');
591 int stride_c = GetTensorDim(strides_, data_format_, 'C');
592 int stride_h = GetTensorDim(strides_, data_format_, 'H');
593 int stride_w = GetTensorDim(strides_, data_format_, 'W');
594 OP_REQUIRES(
595 context, (stride_n == 1 && stride_c == 1),
596 errors::InvalidArgument("Current implementation does not yet support "
597 "strides in the batch and depth dimensions."));
598 OP_REQUIRES(context, stride_h > 0 && stride_w > 0,
599 errors::InvalidArgument(
600 "Row and column strides should be larger than 0."));
601 OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
602 OP_REQUIRES(context, dilations_.size() == 4,
603 errors::InvalidArgument("Sliding window dilations field must "
604 "specify 4 dimensions"));
605 int dilation_n = GetTensorDim(dilations_, data_format_, 'N');
606 int dilation_c = GetTensorDim(dilations_, data_format_, 'C');
607 int dilation_h = GetTensorDim(dilations_, data_format_, 'H');
608 int dilation_w = GetTensorDim(dilations_, data_format_, 'W');
609 OP_REQUIRES(context, (dilation_n == 1 && dilation_c == 1),
610 errors::InvalidArgument(
611 "Current implementation does not yet support "
612 "dilations in the batch and depth dimensions."));
613 OP_REQUIRES(
614 context, dilation_h > 0 && dilation_w > 0,
615 errors::InvalidArgument("Dilated rates should be larger than 0."));
616 OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
617 use_cudnn_ &= CanUseCudnn();
618 cudnn_use_autotune_ = CudnnUseAutotune();
619 OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
620 if (!std::is_same<Device, GPUDevice>::value) {
621 OP_REQUIRES(
622 context, padding_ != Padding::EXPLICIT,
623 errors::Unimplemented("Current CPU implementation does not support "
624 "EXPLICIT padding yet."));
625 }
626 OP_REQUIRES_OK(context,
627 context->GetAttr("explicit_paddings", &explicit_paddings_));
628 OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
629 /*num_dims=*/4, data_format_));
630 }
631
632   void Compute(OpKernelContext* context) override {
633 const Tensor& input_sizes = context->input(0);
634 const Tensor& filter = context->input(1);
635 const Tensor& out_backprop = context->input(2);
636 OP_REQUIRES(
637 context, TensorShapeUtils::IsVector(input_sizes.shape()),
638 errors::InvalidArgument(
639 "Conv2DBackpropInput: input_sizes input must be 1-dim, not ",
640 input_sizes.dims()));
641 TensorShape input_shape;
642 OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
643 input_sizes.vec<int32>(), &input_shape));
644
645 Tensor* in_backprop = nullptr;
646 OP_REQUIRES_OK(context,
647 context->allocate_output(0, input_shape, &in_backprop));
648
649 // If there is nothing to compute, return.
650 if (input_shape.num_elements() == 0) {
651 return;
652 }
653
654 // For now we take the stride from the second and third dimensions only (we
655 // do not support striding on the batch or depth dimension).
656 const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
657 const int stride_cols = GetTensorDim(strides_, data_format_, 'W');
658 const int dilation_rows = GetTensorDim(dilations_, data_format_, 'H');
659 const int dilation_cols = GetTensorDim(dilations_, data_format_, 'W');
660
661 launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop, filter,
662 dilation_rows, dilation_cols, stride_rows, stride_cols, padding_,
663 explicit_paddings_, in_backprop, data_format_);
664 }
665
666 private:
667 std::vector<int32> dilations_;
668 std::vector<int32> strides_;
669 Padding padding_;
670 std::vector<int64> explicit_paddings_;
671 bool use_cudnn_;
672 TensorFormat data_format_;
673 LaunchConv2DBackpropInputOp<Device, T> launcher_;
674 bool cudnn_use_autotune_;
675
676 TF_DISALLOW_COPY_AND_ASSIGN(Conv2DSlowBackpropInputOp);
677 };
678
679 template <typename T>
680 void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
681 OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
682 const Tensor& out_backprop, const Tensor& filter, int row_dilation,
683 int col_dilation, int row_stride, int col_stride, const Padding& padding,
684 const std::vector<int64>& explicit_paddings, Tensor* in_backprop,
685 TensorFormat data_format) {
686 using se::dnn::AlgorithmConfig;
687 using se::dnn::AlgorithmDesc;
688 using se::dnn::ProfileResult;
689
690 std::vector<int32> strides(4, 1);
691 std::vector<int32> dilations(4, 1);
692 auto input_h = GetTensorDimIndex(data_format, 'H');
693 auto input_w = GetTensorDimIndex(data_format, 'W');
694 strides[input_h] = row_stride;
695 strides[input_w] = col_stride;
696 dilations[input_h] = row_dilation;
697 dilations[input_w] = col_dilation;
698 TensorShape input_shape = in_backprop->shape();
699
700 const TensorShape& filter_shape = filter.shape();
701 ConvBackpropDimensions dims;
702 OP_REQUIRES_OK(
703 ctx, ConvBackpropComputeDimensionsV2(
704 "Conv2DSlowBackpropInput", /*num_spatial_dims=*/2, input_shape,
705 filter_shape, out_backprop.shape(), dilations, strides, padding,
706 explicit_paddings, data_format, &dims));
707
708 int64 padding_top = -1, padding_bottom = -1;
709 int64 padding_left = -1, padding_right = -1;
710 if (padding == EXPLICIT) {
711 GetExplicitPaddingForDim(explicit_paddings, data_format, 'H', &padding_top,
712 &padding_bottom);
713 GetExplicitPaddingForDim(explicit_paddings, data_format, 'W', &padding_left,
714 &padding_right);
715 }
716 int64 expected_out_rows, expected_out_cols;
717   // The function is guaranteed to succeed because we already checked that the
718   // output and padding are valid.
719 TF_CHECK_OK(GetWindowedOutputSizeVerboseV2(
720 dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
721 row_dilation, row_stride, padding, &expected_out_rows, &padding_top,
722 &padding_bottom));
723 DCHECK_EQ(dims.spatial_dims[0].output_size, expected_out_rows);
724 TF_CHECK_OK(GetWindowedOutputSizeVerboseV2(
725 dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
726 col_dilation, col_stride, padding, &expected_out_cols, &padding_left,
727 &padding_right));
728 DCHECK_EQ(dims.spatial_dims[1].output_size, expected_out_cols);
729
730 auto* stream = ctx->op_device_context()->stream();
731 OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));
732
733 if (!use_cudnn) {
734 ctx->SetStatus(errors::Unimplemented(
735 "Conv2DBackpropInput for GPU is not currently supported "
736 "without cudnn"));
737 return;
738 }
739
740 // If the filter in-depth (filter_shape.dim_size(2)) is 1 and smaller than the
741 // input depth, it's a depthwise convolution. More generally, if the filter
742 // in-depth divides but is smaller than the input depth, it is a grouped
743 // convolution.
744 bool is_grouped_convolution = filter_shape.dim_size(2) != dims.in_depth;
745 if (dims.spatial_dims[0].filter_size == 1 &&
746 dims.spatial_dims[1].filter_size == 1 && !is_grouped_convolution &&
747 dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 &&
748 data_format == FORMAT_NHWC && (padding == VALID || padding == SAME)) {
749 // 1x1 filter, so call cublas directly.
750 const uint64 m = dims.batch_size * dims.spatial_dims[0].input_size *
751 dims.spatial_dims[1].input_size;
752 const uint64 k = dims.out_depth;
753 const uint64 n = dims.in_depth;
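
    // In row-major terms this computes in_backprop[m x n] =
    // out_backprop[m x k] * filter[n x k]^T: every output-gradient pixel is
    // mapped back to input channels through the transposed 1x1 filter. The
    // cuBLAS call below expresses the same product in column-major form by
    // swapping the operand roles.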
754
755 auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
756 out_backprop.template flat<T>().size());
757 auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
758 filter.template flat<T>().size());
759 auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
760 in_backprop->template flat<T>().size());
761
762 auto transpose = se::blas::Transpose::kTranspose;
763 auto no_transpose = se::blas::Transpose::kNoTranspose;
764
765 bool blas_launch_status =
766 stream
767 ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
768 a_ptr, k, 0.0f, &c_ptr, n)
769 .ok();
770 if (!blas_launch_status) {
771 ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
772 ", n=", n, ", k=", k));
773 }
774 return;
775 } else if (dims.spatial_dims[0].filter_size ==
776 dims.spatial_dims[0].input_size &&
777 dims.spatial_dims[1].filter_size ==
778 dims.spatial_dims[1].input_size &&
779 !is_grouped_convolution && padding == VALID &&
780 data_format == FORMAT_NHWC) {
781 // The input data and filter have the same height/width, and we are not
782 // using grouped convolution, so call cublas directly.
783 const uint64 m = dims.batch_size;
784 const uint64 k = dims.out_depth;
785 const uint64 n = dims.spatial_dims[0].input_size *
786 dims.spatial_dims[1].input_size * dims.in_depth;
787
788 auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
789 out_backprop.template flat<T>().size());
790 auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
791 filter.template flat<T>().size());
792 auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
793 in_backprop->template flat<T>().size());
794
795 auto transpose = se::blas::Transpose::kTranspose;
796 auto no_transpose = se::blas::Transpose::kNoTranspose;
797
798 bool blas_launch_status =
799 stream
800 ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
801 a_ptr, k, 0.0f, &c_ptr, n)
802 .ok();
803 if (!blas_launch_status) {
804 ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
805 ", n=", n, ", k=", k));
806 }
807 return;
808 }
809
810 const int64 common_padding_rows = std::min(padding_top, padding_bottom);
811 const int64 common_padding_cols = std::min(padding_left, padding_right);
812 TensorShape compatible_input_shape;
813 if (padding_top != padding_bottom || padding_left != padding_right) {
814     // Pad the input in the same way we did during the forward pass, so that
815     // cuDNN receives the same input during the backward pass as it did during
816     // the forward pass.
817 const int64 padding_rows_diff = std::abs(padding_bottom - padding_top);
818 const int64 padding_cols_diff = std::abs(padding_right - padding_left);
819 const int64 new_in_rows =
820 dims.spatial_dims[0].input_size + padding_rows_diff;
821 const int64 new_in_cols =
822 dims.spatial_dims[1].input_size + padding_cols_diff;
823 compatible_input_shape = ShapeFromFormat(
824 data_format, dims.batch_size, new_in_rows, new_in_cols, dims.in_depth);
825 } else {
826 compatible_input_shape = input_shape;
827 }
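
  // For example, with EXPLICIT padding of (top, bottom) = (0, 1) the common
  // (symmetric) padding passed to cuDNN is 0, the input is treated as one row
  // taller (padding_rows_diff = 1), and the extra row of the resulting
  // gradient is sliced back off by the PadInput call further below.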
828
829 CHECK(common_padding_rows >= 0 && common_padding_cols >= 0) // Crash OK
830 << "Negative row or col paddings: (" << common_padding_rows << ", "
831 << common_padding_cols << ")";
832 se::dnn::BatchDescriptor input_desc;
833 input_desc.set_count(dims.batch_size)
834 .set_height(GetTensorDim(compatible_input_shape, data_format, 'H'))
835 .set_width(GetTensorDim(compatible_input_shape, data_format, 'W'))
836 .set_feature_map_count(dims.in_depth)
837 .set_layout(se::dnn::DataLayout::kBatchDepthYX);
838 se::dnn::BatchDescriptor output_desc;
839 output_desc.set_count(dims.batch_size)
840 .set_height(dims.spatial_dims[0].output_size)
841 .set_width(dims.spatial_dims[1].output_size)
842 .set_feature_map_count(dims.out_depth)
843 .set_layout(se::dnn::DataLayout::kBatchDepthYX);
844 se::dnn::FilterDescriptor filter_desc;
845 filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size)
846 .set_input_filter_width(dims.spatial_dims[1].filter_size)
847 .set_input_feature_map_count(filter_shape.dim_size(2))
848 .set_output_feature_map_count(filter_shape.dim_size(3));
849 se::dnn::ConvolutionDescriptor conv_desc;
850 conv_desc.set_vertical_dilation_rate(dims.spatial_dims[0].dilation)
851 .set_horizontal_dilation_rate(dims.spatial_dims[1].dilation)
852 .set_vertical_filter_stride(dims.spatial_dims[0].stride)
853 .set_horizontal_filter_stride(dims.spatial_dims[1].stride)
854 .set_zero_padding_height(common_padding_rows)
855 .set_zero_padding_width(common_padding_cols)
856 .set_group_count(dims.in_depth / filter_shape.dim_size(2));
857
858 // NOTE(keveman):
859 // cuDNN only supports the following layouts :
860 // Input : B x D x R x C
861 // Filter : OD x ID x R x C
862 // Whereas, we have
863 // Input : B x R x C x D
864 // Filter : R x C x ID x OD
865 // TransformFilter performs (R x C x ID x OD) => (OD x ID x R x C)
866 // The first TransformDepth performs
867 // (B x R x C x D) => (B x D x R x C).
868 // Since the tensor returned from cuDNN is B x D x R x C also,
869 // the second TransformDepth performs
870 // (B x D x R x C) => (B x R x C x D).
871 Tensor transformed_filter;
872 OP_REQUIRES_OK(
873 ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
874 TensorShape({dims.out_depth, dims.in_depth,
875 dims.spatial_dims[0].filter_size,
876 dims.spatial_dims[1].filter_size}),
877 &transformed_filter));
878
879 functor::TransformFilter<GPUDevice, T, int, 4>()(
880 ctx->eigen_device<GPUDevice>(), FORMAT_OIHW,
881 To32Bit(filter.tensor<T, 4>()),
882 To32Bit(transformed_filter.tensor<T, 4>()));
883
884 Tensor transformed_out_backprop;
885 if (data_format == FORMAT_NHWC) {
886 TensorShape nchw_shape = ShapeFromFormat(
887 FORMAT_NCHW, dims.batch_size, dims.spatial_dims[0].output_size,
888 dims.spatial_dims[1].output_size, dims.out_depth);
889 if (dims.out_depth > 1) {
890 OP_REQUIRES_OK(ctx,
891 ctx->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
892 &transformed_out_backprop));
893 functor::NHWCToNCHW<GPUDevice, T, 4>()(
894 ctx->eigen_device<GPUDevice>(), out_backprop.tensor<T, 4>(),
895 transformed_out_backprop.tensor<T, 4>());
896 } else {
897 // If depth <= 1, then just reshape.
898 CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
899 }
900 } else {
901 transformed_out_backprop = out_backprop;
902 }
903
904 Tensor pre_transformed_in_backprop;
905 OP_REQUIRES_OK(
906 ctx, ctx->allocate_temp(
907 DataTypeToEnum<T>::value,
908 ShapeFromFormat(
909 FORMAT_NCHW,
910 GetTensorDim(compatible_input_shape, data_format, 'N'),
911 GetTensorDim(compatible_input_shape, data_format, 'H'),
912 GetTensorDim(compatible_input_shape, data_format, 'W'),
913 GetTensorDim(compatible_input_shape, data_format, 'C')),
914 &pre_transformed_in_backprop));
915
916 auto out_backprop_ptr =
917 AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
918 transformed_out_backprop.template flat<T>().size());
919 auto filter_ptr =
920 AsDeviceMemory(transformed_filter.template flat<T>().data(),
921 transformed_filter.template flat<T>().size());
922 auto in_backprop_ptr =
923 AsDeviceMemory(pre_transformed_in_backprop.template flat<T>().data(),
924 pre_transformed_in_backprop.template flat<T>().size());
925
926 static int64 ConvolveBackwardDataScratchSize = GetDnnWorkspaceLimit(
927 "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB by default
928 );
929 DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize, ctx);
930 int device_id = stream->parent()->device_ordinal();
931 DataType dtype = out_backprop.dtype();
932 ConvParameters conv_parameters = {
933 dims.batch_size, // batch
934 dims.in_depth, // in_depths
935 {{input_desc.height(), // in_rows
936 input_desc.width()}}, // in_cols
937 FORMAT_NCHW, // compute_data_format
938 dims.out_depth, // out_depths
939 {{dims.spatial_dims[0].filter_size, // filter_rows
940 dims.spatial_dims[1].filter_size, // filter_cols
941 filter_shape.dim_size(2)}}, // filter_depths
942 {{dims.spatial_dims[0].dilation, // dilation_rows
943 dims.spatial_dims[1].dilation}}, // dilation_cols
944 {{dims.spatial_dims[0].stride, // stride_rows
945 dims.spatial_dims[1].stride}}, // stride_cols
946 {{common_padding_rows, // padding_rows
947 common_padding_cols}}, // padding_cols
948 dtype, // tensor data type
949 device_id, // device_id
950 };
951 AlgorithmConfig algorithm_config;
952 if (cudnn_use_autotune && !AutoTuneConvBwdData::GetInstance()->Find(
953 conv_parameters, &algorithm_config)) {
954 std::vector<AlgorithmDesc> algorithms;
955 CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
956 conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(stream->parent()),
957 &algorithms));
958 std::vector<tensorflow::AutotuneResult> results;
959 for (auto profile_algorithm : algorithms) {
960       // TODO(zhengxq): profile each algorithm multiple times for better
961       // accuracy.
962 DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
963 ctx);
964 ProfileResult profile_result;
965 bool cudnn_launch_status =
966 stream
967 ->ThenConvolveBackwardDataWithAlgorithm(
968 filter_desc, filter_ptr, output_desc, out_backprop_ptr,
969 conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
970 AlgorithmConfig(profile_algorithm), &profile_result)
971 .ok();
972 if (cudnn_launch_status) {
973 if (profile_result.is_valid()) {
974 results.emplace_back();
975 auto& result = results.back();
976 result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
977 result.mutable_conv()->set_tensor_ops_enabled(
978 profile_algorithm.tensor_ops_enabled());
979 result.mutable_success()->set_scratch_bytes(
980 scratch_allocator.TotalByteSize());
981 *result.mutable_success()->mutable_run_time() =
982 proto_utils::ToDurationProto(
983 absl::Milliseconds(profile_result.elapsed_time_in_ms()));
984 }
985 }
986 }
987 LogConvAutotuneResults(ctx->op_kernel().def(), pre_transformed_in_backprop,
988 transformed_filter, transformed_out_backprop,
989 stream->parent(), results);
990 OP_REQUIRES_OK(ctx, BestCudnnConvAlgorithm(results, &algorithm_config));
991 AutoTuneConvBwdData::GetInstance()->Insert(conv_parameters,
992 algorithm_config);
993 }
994 bool cudnn_launch_status =
995 stream
996 ->ThenConvolveBackwardDataWithAlgorithm(
997 filter_desc, filter_ptr, output_desc, out_backprop_ptr, conv_desc,
998 input_desc, &in_backprop_ptr, &scratch_allocator,
999 algorithm_config, nullptr)
1000 .ok();
1001
1002 if (!cudnn_launch_status) {
1003 ctx->SetStatus(errors::Internal(
1004 "cuDNN Backward Data function launch failure : input shape(",
1005 input_shape.DebugString(), ") filter shape(",
1006 filter_shape.DebugString(), ")"));
1007 return;
1008 }
1009
1010 if (padding_top != padding_bottom || padding_left != padding_right) {
1011 Tensor in_backprop_remove_padding;
1012 OP_REQUIRES_OK(
1013 ctx, ctx->allocate_temp(
1014 DataTypeToEnum<T>::value,
1015 ShapeFromFormat(FORMAT_NCHW,
1016 GetTensorDim(input_shape, data_format, 'N'),
1017 GetTensorDim(input_shape, data_format, 'H'),
1018 GetTensorDim(input_shape, data_format, 'W'),
1019 GetTensorDim(input_shape, data_format, 'C')),
1020 &in_backprop_remove_padding));
1021
1022 // Remove the padding that was added to the input shape above.
1023 const int64 input_pad_top = padding_top - common_padding_rows;
1024 const int64 input_pad_bottom = padding_bottom - common_padding_rows;
1025 const int64 input_pad_left = padding_left - common_padding_cols;
1026 const int64 input_pad_right = padding_right - common_padding_cols;
1027 functor::PadInput<GPUDevice, T, int, 4>()(
1028 ctx->template eigen_device<GPUDevice>(),
1029 To32Bit(const_cast<const Tensor&>(pre_transformed_in_backprop)
1030 .tensor<T, 4>()),
1031 {{static_cast<int>(-input_pad_top), static_cast<int>(-input_pad_left)}},
1032 {{static_cast<int>(-input_pad_bottom),
1033 static_cast<int>(-input_pad_right)}},
1034 To32Bit(in_backprop_remove_padding.tensor<T, 4>()), FORMAT_NCHW);
1035
1036 pre_transformed_in_backprop = in_backprop_remove_padding;
1037 }
1038
1039 if (data_format == FORMAT_NHWC) {
1040 auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
1041 functor::NCHWToNHWC<GPUDevice, T, 4>()(
1042 ctx->eigen_device<GPUDevice>(),
1043 toConstTensor(pre_transformed_in_backprop).template tensor<T, 4>(),
1044 in_backprop->tensor<T, 4>());
1045 } else {
1046 *in_backprop = pre_transformed_in_backprop;
1047 }
1048 }
1049
1050 // Forward declarations of the functor specializations for GPU.
1051 namespace functor {
1052 #define DECLARE_GPU_SPEC(T) \
1053 template <> \
1054 void ShuffleAndReverse<GPUDevice, T, 4, int>::operator()( \
1055 const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor input, \
1056 const Eigen::DSizes<int, 4>& order, \
1057 const Eigen::array<bool, 4>& reverse_dims, \
1058 typename TTypes<T, 4, int>::Tensor output); \
1059 extern template struct ShuffleAndReverse<GPUDevice, T, 4, int>; \
1060 template <> \
1061 void InflatePadAndShuffle<GPUDevice, T, 4, int>::operator()( \
1062 const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor input, \
1063 const Eigen::DSizes<int, 4>& strides, \
1064 const Eigen::array<Eigen::IndexPair<int>, 4>& pad_dims, \
1065 const Eigen::DSizes<int, 4>& order, \
1066 typename TTypes<T, 4, int>::Tensor output); \
1067 extern template struct InflatePadAndShuffle<GPUDevice, T, 4, int>; \
1068 template <> \
1069 void TransformFilter<GPUDevice, T, int, 4>::operator()( \
1070 const GPUDevice& d, FilterTensorFormat dst_filter_format, \
1071 typename TTypes<T, 4, int>::ConstTensor in, \
1072 typename TTypes<T, 4, int>::Tensor out); \
1073 extern template struct TransformFilter<GPUDevice, T, int, 4>; \
1074 template <> \
1075 void TransformDepth<GPUDevice, T, int>::operator()( \
1076 const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
1077 const Eigen::DSizes<int, 4>& shuffle, \
1078 typename TTypes<T, 4, int>::Tensor out); \
1079 extern template struct TransformDepth<GPUDevice, T, int>; \
1080 template <> \
1081 void PadInput<GPUDevice, T, int, 4>::operator()( \
1082 const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
1083 const std::array<int, 2>& padding_left, \
1084 const std::array<int, 2>& padding_right, \
1085 typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
1086 extern template struct PadInput<GPUDevice, T, int, 4>;
1087
1088 DECLARE_GPU_SPEC(float);
1089 DECLARE_GPU_SPEC(Eigen::half);
1090 DECLARE_GPU_SPEC(double);
1091 #undef DECLARE_GPU_SPEC
1092 } // namespace functor
1093
1094 REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")
1095 .Device(DEVICE_GPU)
1096 .TypeConstraint<double>("T")
1097 .HostMemory("input_sizes"),
1098 Conv2DSlowBackpropInputOp<GPUDevice, double>);
1099 REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")
1100 .Device(DEVICE_GPU)
1101 .TypeConstraint<float>("T")
1102 .HostMemory("input_sizes"),
1103 Conv2DSlowBackpropInputOp<GPUDevice, float>);
1104 REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")
1105 .Device(DEVICE_GPU)
1106 .TypeConstraint<Eigen::half>("T")
1107 .HostMemory("input_sizes"),
1108 Conv2DSlowBackpropInputOp<GPUDevice, Eigen::half>);
1109
1110 // To be used inside depthwise_conv_grad_op.cc.
1111 // TODO(reedwm): Move this and the definition to depthwise_conv_grad_op.cc.
1112 template struct LaunchConv2DBackpropInputOp<GPUDevice, float>;
1113 template struct LaunchConv2DBackpropInputOp<GPUDevice, Eigen::half>;
1114 template struct LaunchConv2DBackpropInputOp<GPUDevice, double>;
1115
1116 #endif // GOOGLE_CUDA
1117
1118 } // namespace tensorflow
1119