/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_SPATIAL_CONVOLUTIONS_H_
#define TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_SPATIAL_CONVOLUTIONS_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

namespace Eigen {

/** SpatialConvolutionBackwardInput
 * \ingroup CXX11_NeuralNetworks_Module
 *
 * \brief Computes the backprop for the input of a 2D convolution.
 *
 * The output_backward parameter is expected to be a tensor with a rank of 3 or
 * more (output depth, height, width, and optionally others).
 * The kernel parameter is expected to be a 4D tensor (filters, channels,
 * kernel_height, kernel_width).
 * The output_backward and the kernel must both be in col-major layout. The
 * result will also be in col-major layout.
 *
 * If row_in_stride or col_in_stride is greater than 1, the convolution is
 * applied with holes (aka atrous convolution), sampling every row_in_stride-th
 * row and col_in_stride-th column of the input.
 *
 * The result can be assigned to a tensor of rank equal to the rank of the
 * output_backward. The dimensions of the result will be channels (the input
 * depth), height, width (and others if applicable). A usage sketch follows
 * the function definition below.
 *
 * It is possible to swap the order of the width and height dimensions provided
 * that the same order is used in the input, the kernel, and the output.
 */
#ifdef EIGEN_HAS_INDEX_LIST
typedef IndexList<type2index<0>, type2index<0>, type2index<1>, type2index<1>>
    ReverseColMajor;
typedef IndexList<type2index<1>, type2index<1>, type2index<0>, type2index<0>>
    ReverseRowMajor;
#else
typedef array<bool, 4> ReverseColMajor;
typedef array<bool, 4> ReverseRowMajor;
#endif

template <typename OutputBackward, typename Kernel>
EIGEN_ALWAYS_INLINE static const typename internal::conditional<
    internal::traits<OutputBackward>::Layout == ColMajor,
    TensorReshapingOp<
        const DSizes<typename internal::traits<OutputBackward>::Index,
                     internal::traits<OutputBackward>::NumDimensions>,
        const TensorContractionOp<
            const array<
                IndexPair<typename internal::traits<OutputBackward>::Index>, 1>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<OutputBackward>::Index,
                             2>,
                const Eigen::TensorForcedEvalOp<const TensorShufflingOp<
                    const array<
                        typename internal::traits<OutputBackward>::Index, 4>,
                    const Eigen::TensorForcedEvalOp<const TensorReverseOp<
                        const ReverseColMajor, const Kernel>>>>>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<OutputBackward>::Index,
                             2>,
                const TensorImagePatchOp<Dynamic, Dynamic,
                                         const OutputBackward>>>>,
    TensorReshapingOp<
        const DSizes<typename internal::traits<OutputBackward>::Index,
                     internal::traits<OutputBackward>::NumDimensions>,
        const TensorContractionOp<
            const array<
                IndexPair<typename internal::traits<OutputBackward>::Index>, 1>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<OutputBackward>::Index,
                             2>,
                const TensorImagePatchOp<Dynamic, Dynamic,
                                         const OutputBackward>>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<OutputBackward>::Index,
                             2>,
                const Eigen::TensorForcedEvalOp<const TensorShufflingOp<
                    const array<
                        typename internal::traits<OutputBackward>::Index, 4>,
                    const Eigen::TensorForcedEvalOp<const TensorReverseOp<
                        const ReverseRowMajor, const Kernel>>>>>>>>::type
SpatialConvolutionBackwardInput(
    const Kernel& kernel, const OutputBackward& output_backward,
    typename internal::traits<OutputBackward>::Index inputRows,
    typename internal::traits<OutputBackward>::Index inputCols,
    const DenseIndex row_stride = 1, const DenseIndex col_stride = 1,
    const DenseIndex row_in_stride = 1, const DenseIndex col_in_stride = 1) {
  typedef typename internal::traits<OutputBackward>::Index TensorIndex;
  typedef typename internal::traits<OutputBackward>::Scalar OutScalar;
  TensorRef<Tensor<typename internal::traits<Kernel>::Scalar,
                   internal::traits<Kernel>::NumDimensions,
                   internal::traits<Kernel>::Layout, TensorIndex>>
      kern(kernel);
  TensorRef<Tensor<OutScalar, internal::traits<OutputBackward>::NumDimensions,
                   internal::traits<OutputBackward>::Layout, TensorIndex>>
      out(output_backward);

  EIGEN_STATIC_ASSERT(internal::traits<Kernel>::Layout ==
                          internal::traits<OutputBackward>::Layout,
                      YOU_MADE_A_PROGRAMMING_MISTAKE);

  static const bool isColMajor =
      (internal::traits<OutputBackward>::Layout == ColMajor);

  static const int NumDims = internal::traits<OutputBackward>::NumDimensions;

  // Number of filters to apply. This is the same as the output depth of the
  // result
  const TensorIndex kernelFilters =
      isColMajor ? kern.dimensions()[0] : kern.dimensions()[3];
  // Number of channels. This is the same as the input depth.
  const TensorIndex kernelChannels =
      isColMajor ? kern.dimensions()[1] : kern.dimensions()[2];
  const TensorIndex kernelRows =
      isColMajor ? kern.dimensions()[2] : kern.dimensions()[1];
  const TensorIndex kernelCols =
      isColMajor ? kern.dimensions()[3] : kern.dimensions()[0];

  // This is the effective kernel size, taking into account the (*_in_stride -
  // 1) zero-values inserted between consecutive kernel elements in atrous
  // convolution
  const TensorIndex kernelRowsEff =
      kernelRows + (kernelRows - 1) * (row_in_stride - 1);
  const TensorIndex kernelColsEff =
      kernelCols + (kernelCols - 1) * (col_in_stride - 1);
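  // For illustration (hypothetical values, not from the original source):
  // with kernelRows = 3 and row_in_stride = 2, the dilated kernel spans
  // kernelRowsEff = 3 + (3 - 1) * (2 - 1) = 5 input rows.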

  const TensorIndex outputRows = isColMajor
                                     ? output_backward.dimension(1)
                                     : output_backward.dimension(NumDims - 2);
  const TensorIndex outputCols = isColMajor
                                     ? output_backward.dimension(2)
                                     : output_backward.dimension(NumDims - 3);

  // Computing the forward padding
  const TensorIndex forward_pad_top = numext::maxi<Index>(
      0, ((outputRows - 1) * row_stride + kernelRowsEff - inputRows) / 2);
  const TensorIndex forward_pad_left = numext::maxi<Index>(
      0, ((outputCols - 1) * col_stride + kernelColsEff - inputCols) / 2);
  const TensorIndex padding_top = kernelRowsEff - 1 - forward_pad_top;
  const TensorIndex padding_left = kernelColsEff - 1 - forward_pad_left;

  const TensorIndex padding_bottom = inputRows - (outputRows - 1) * row_stride -
                                     2 - padding_top + kernelRowsEff;
  const TensorIndex padding_right = inputCols - (outputCols - 1) * col_stride -
                                    2 - padding_left + kernelColsEff;
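
  // A worked example with hypothetical sizes (not from the original source):
  // inputRows = 5, kernelRowsEff = 3, row_stride = 1 and outputRows = 5 give
  // forward_pad_top = max(0, (4 + 3 - 5) / 2) = 1, padding_top = 3 - 1 - 1 = 1
  // and padding_bottom = 5 - 4 - 2 - 1 + 3 = 1, i.e. the output gradient is
  // padded symmetrically before patches are extracted from it.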

  eigen_assert(padding_top >= 0);
  eigen_assert(padding_left >= 0);
  eigen_assert(padding_bottom >= 0);
  eigen_assert(padding_right >= 0);

  // The kernel has dimensions filters X channels X patch_rows X patch_cols
  // We need to reverse the kernel along dimensions corresponding to rows and
  // cols.
  // TODO(yangke): we can make things slightly faster by collapsing the
  // dimensions where we don't reverse. Try that once we have a faster
  // compiler.
  typedef typename internal::conditional<isColMajor, ReverseColMajor,
                                         ReverseRowMajor>::type Reverse;
  Reverse kernel_reverse;

#ifndef EIGEN_HAS_INDEX_LIST
  if (isColMajor) {
    kernel_reverse[0] = false;
    kernel_reverse[1] = false;
    kernel_reverse[2] = true;
    kernel_reverse[3] = true;
  } else {
    kernel_reverse[0] = true;
    kernel_reverse[1] = true;
    kernel_reverse[2] = false;
    kernel_reverse[3] = false;
  }
#endif

  // Reorder the dimensions to:
  //   filters x patch_rows x patch_cols x channels
  array<TensorIndex, 4> kernel_shuffle;
  if (isColMajor) {
    //  From: filters x channels x rows x cols
    //  To:   filters x rows x cols x channels
    kernel_shuffle[0] = 0;
    kernel_shuffle[1] = 2;
    kernel_shuffle[2] = 3;
    kernel_shuffle[3] = 1;
  } else {
    //  From: cols x rows x channels x filters
    //  To:   channels x cols x rows x filters
    kernel_shuffle[0] = 2;
    kernel_shuffle[1] = 0;
    kernel_shuffle[2] = 1;
    kernel_shuffle[3] = 3;
  }

  // Collapse the dims
  DSizes<TensorIndex, 2> kernel_dims;
  if (isColMajor) {
    kernel_dims[0] = kernelFilters * kernelRows * kernelCols;
    kernel_dims[1] = kernelChannels;
  } else {
    kernel_dims[1] = kernelFilters * kernelRows * kernelCols;
    kernel_dims[0] = kernelChannels;
  }

  // The output_backward has dimensions out_depth X out_rows X out_cols X OTHERS
  // When we extract the image patches from output_backward, it will have
  // dimensions
  //   out_depth X (patch_rows * patch_cols) X (input_rows * input_cols *
  //   OTHERS)
  DSizes<TensorIndex, 2> pre_contract_dims;
  if (isColMajor) {
    pre_contract_dims[0] = kernelFilters * kernelRows * kernelCols;
    pre_contract_dims[1] = inputRows * inputCols;
    for (int i = 3; i < NumDims; ++i) {
      pre_contract_dims[1] *= out.dimension(i);
    }
  } else {
    pre_contract_dims[1] = kernelFilters * kernelRows * kernelCols;
    pre_contract_dims[0] = inputRows * inputCols;
    for (int i = 0; i < NumDims - 3; ++i) {
      pre_contract_dims[0] *= out.dimension(i);
    }
  }

  // We will contract along the collapsed dimension that contains the
  // kernelFilters, the kernelRows and the kernelCols.
  array<IndexPair<TensorIndex>, 1> contract_dims;
  if (isColMajor) {
    // col-major: kernel.contract(output.patches)
    contract_dims[0] = IndexPair<TensorIndex>(0, 0);
  } else {
    // row-major: output.patches.contract(kernel)
    contract_dims[0] = IndexPair<TensorIndex>(1, 1);
  }

  // Post contraction, the dimensions of the input_backprop are
  //  channels X input_rows X input_cols X OTHERS
  DSizes<TensorIndex, NumDims> post_contract_dims;
  if (isColMajor) {
    post_contract_dims[0] = kernelChannels;
    post_contract_dims[1] = inputRows;
    post_contract_dims[2] = inputCols;
    for (int i = 3; i < NumDims; ++i) {
      post_contract_dims[i] = out.dimension(i);
    }
  } else {
    post_contract_dims[NumDims - 1] = kernelChannels;
    post_contract_dims[NumDims - 2] = inputRows;
    post_contract_dims[NumDims - 3] = inputCols;
    for (int i = 0; i < NumDims - 3; ++i) {
      post_contract_dims[i] = out.dimension(i);
    }
  }

  // NOTE(ezhulenev): We do eval after reverse and shuffle, because tiled
  // evaluation of these ops does not compose. Doing explicit eval is ~8x
  // faster in micro benchmarks.

  return choose(
      Cond<internal::traits<OutputBackward>::Layout == ColMajor>(),
      kernel.reverse(kernel_reverse)
          .eval()
          .shuffle(kernel_shuffle)
          .eval()
          .reshape(kernel_dims)
          .contract(
              output_backward
                  .extract_image_patches(
                      kernelRows, kernelCols, 1, 1, row_in_stride,
                      col_in_stride, row_stride, col_stride, padding_top,
                      padding_bottom, padding_left, padding_right, OutScalar(0))
                  .reshape(pre_contract_dims),
              contract_dims)
          .reshape(post_contract_dims),
      output_backward
          .extract_image_patches(kernelRows, kernelCols, 1, 1, row_in_stride,
                                 col_in_stride, row_stride, col_stride,
                                 padding_top, padding_bottom, padding_left,
                                 padding_right, OutScalar(0))
          .reshape(pre_contract_dims)
          .contract(kernel.reverse(kernel_reverse)
                        .eval()
                        .shuffle(kernel_shuffle)
                        .eval()
                        .reshape(kernel_dims),
                    contract_dims)
          .reshape(post_contract_dims));
}
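
// Illustrative usage sketch for SpatialConvolutionBackwardInput (not part of
// the original header; the tensor sizes below are hypothetical):
//
//   // Col-major: kernel is (filters, channels, rows, cols) and
//   // output_backward is (filters, out_rows, out_cols, batch).
//   Eigen::Tensor<float, 4> kernel(7, 3, 5, 5);
//   Eigen::Tensor<float, 4> output_backward(7, 28, 28, 16);
//   kernel.setRandom();
//   output_backward.setRandom();
//
//   // Gradient w.r.t. a 32x32 input of a stride-1, non-dilated convolution.
//   Eigen::Tensor<float, 4> input_backward =
//       Eigen::SpatialConvolutionBackwardInput(kernel, output_backward,
//                                              /*inputRows=*/32,
//                                              /*inputCols=*/32);
//   // input_backward has dimensions (3, 32, 32, 16):
//   // channels x rows x cols x batch.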

/** SpatialConvolutionBackwardKernel
 * \ingroup CXX11_NeuralNetworks_Module
 *
 * \brief Computes the backprop for the filter of a 2D convolution.
 *
 * The input parameter is expected to be a 4D tensor (channels, height, width,
 * batch).
 * The output_backward parameter is expected to be a 4D tensor (output depth,
 * height, width, batch).
 * The input and the output_backward must both be in col-major layout. The
 * result will also be in col-major layout.
 *
 * If row_in_stride or col_in_stride is greater than 1, the convolution is
 * applied with holes (aka atrous convolution), sampling every row_in_stride-th
 * row and col_in_stride-th column of the input.
 *
 * The result is a 4D tensor with the dimensions of the kernel: filters,
 * channels, kernel_height, kernel_width. A usage sketch follows the function
 * definition below.
 *
 * It is possible to swap the order of the width and height dimensions provided
 * that the same order is used in the input, the kernel, and the output.
 */

template <typename OutputBackward, typename Input>
EIGEN_ALWAYS_INLINE static const typename internal::conditional<
    internal::traits<Input>::Layout == ColMajor,
    const TensorReverseOp<
        const Eigen::array<typename internal::traits<Input>::Index,
                           internal::traits<Input>::NumDimensions>,
        const Eigen::TensorForcedEvalOp<const Eigen::TensorShufflingOp<
            const Eigen::array<typename internal::traits<Input>::Index,
                               internal::traits<Input>::NumDimensions>,
            const Eigen::TensorReshapingOp<
                const Eigen::DSizes<typename internal::traits<Input>::Index,
                                    internal::traits<Input>::NumDimensions>,
                const TensorContractionOp<
                    const array<
                        IndexPair<typename internal::traits<Input>::Index>, 1>,
                    const TensorReshapingOp<
                        const DSizes<typename internal::traits<Input>::Index,
                                     2>,
                        const Eigen::TensorForcedEvalOp<
                            const Eigen::TensorShufflingOp<
                                const Eigen::array<
                                    typename internal::traits<Input>::Index,
                                    internal::traits<Input>::NumDimensions>,
                                const Input>>>,
                    const TensorReshapingOp<
                        const DSizes<typename internal::traits<Input>::Index,
                                     2>,
                        const TensorImagePatchOp<
                            Dynamic, Dynamic,
                            const Eigen::TensorForcedEvalOp<
                                const Eigen::TensorShufflingOp<
                                    const Eigen::array<
                                        typename internal::traits<Input>::Index,
                                        internal::traits<Input>::NumDimensions>,
                                    const OutputBackward>>>>>>>>>,
    const TensorReverseOp<
        const Eigen::array<typename internal::traits<Input>::Index,
                           internal::traits<Input>::NumDimensions>,
        const Eigen::TensorForcedEvalOp<const Eigen::TensorShufflingOp<
            const Eigen::array<typename internal::traits<Input>::Index,
                               internal::traits<Input>::NumDimensions>,
            const Eigen::TensorReshapingOp<
                const Eigen::DSizes<typename internal::traits<Input>::Index,
                                    internal::traits<Input>::NumDimensions>,
                const TensorContractionOp<
                    const array<
                        IndexPair<typename internal::traits<Input>::Index>, 1>,
                    const TensorReshapingOp<
                        const DSizes<typename internal::traits<Input>::Index,
                                     2>,
                        const TensorImagePatchOp<
                            Dynamic, Dynamic,
                            const Eigen::TensorForcedEvalOp<
                                const Eigen::TensorShufflingOp<
                                    const Eigen::array<
                                        typename internal::traits<Input>::Index,
                                        internal::traits<Input>::NumDimensions>,
                                    const OutputBackward>>>>,
                    const TensorReshapingOp<
                        const DSizes<typename internal::traits<Input>::Index,
                                     2>,
                        const Eigen::TensorForcedEvalOp<
                            const Eigen::TensorShufflingOp<
                                const Eigen::array<
                                    typename internal::traits<Input>::Index,
                                    internal::traits<Input>::NumDimensions>,
                                const Input>>>>>>>>>::type
SpatialConvolutionBackwardKernel(
    const Input& input, const OutputBackward& output_backward,
    typename internal::traits<Input>::Index kernelRows,
    typename internal::traits<Input>::Index kernelCols,
    const DenseIndex row_stride = 1, const DenseIndex col_stride = 1,
    const DenseIndex row_in_stride = 1, const DenseIndex col_in_stride = 1) {
  typedef typename internal::traits<Input>::Index TensorIndex;
  typedef typename internal::traits<OutputBackward>::Scalar OutScalar;
  TensorRef<Tensor<typename internal::traits<Input>::Scalar,
                   internal::traits<Input>::NumDimensions,
                   internal::traits<Input>::Layout, TensorIndex>>
      in(input);
  TensorRef<Tensor<OutScalar, internal::traits<OutputBackward>::NumDimensions,
                   internal::traits<OutputBackward>::Layout, TensorIndex>>
      out(output_backward);

  EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout ==
                          internal::traits<OutputBackward>::Layout,
                      YOU_MADE_A_PROGRAMMING_MISTAKE);

  // stride and in_stride cannot both be larger than 1
  eigen_assert(!(row_stride > 1 && row_in_stride > 1));
  eigen_assert(!(col_stride > 1 && col_in_stride > 1));

  static const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);

  static const int NumDims = internal::traits<Input>::NumDimensions;
  EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions ==
                          internal::traits<OutputBackward>::NumDimensions,
                      YOU_MADE_A_PROGRAMMING_MISTAKE);
  EIGEN_STATIC_ASSERT(NumDims == 4, YOU_MADE_A_PROGRAMMING_MISTAKE);

  const TensorIndex inputRows =
      isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
  const TensorIndex inputCols =
      isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);

  const TensorIndex outputRows = isColMajor
                                     ? output_backward.dimension(1)
                                     : output_backward.dimension(NumDims - 2);
  const TensorIndex outputCols = isColMajor
                                     ? output_backward.dimension(2)
                                     : output_backward.dimension(NumDims - 3);

  // Number of filters to apply. This is the same as the output depth of the
  // result
  const TensorIndex kernelFilters =
      isColMajor ? out.dimensions()[0] : out.dimensions()[NumDims - 1];

  // Number of channels. This is the same as the input depth.
  const TensorIndex kernelChannels =
      isColMajor ? in.dimensions()[0] : in.dimensions()[NumDims - 1];

  // This is the effective kernel size, taking into account the
  // (*_in_stride - 1) zero-values inserted between consecutive kernel
  // elements in atrous convolution
  const TensorIndex kernelRowsEff =
      kernelRows + (kernelRows - 1) * (row_in_stride - 1);
  const TensorIndex kernelColsEff =
      kernelCols + (kernelCols - 1) * (col_in_stride - 1);

  // Number of batches (and other dimensions) in the input tensor.
  TensorIndex batch = 1;
  for (int d = 3; d < NumDims; ++d) {
    batch *= isColMajor ? in.dimension(d) : in.dimension(NumDims - d - 1);
  }

  // Computing the forward padding
  const TensorIndex padRows = numext::maxi<Index>(
      0, (outputRows - 1) * row_stride + kernelRowsEff - inputRows);
  const TensorIndex padCols = numext::maxi<Index>(
      0, (outputCols - 1) * col_stride + kernelColsEff - inputCols);

  TensorIndex padding_top = padRows / 2;
  TensorIndex padding_left = padCols / 2;

  // Compute paddings for output_backward before extracting patches.
  const TensorIndex expanded_out_rows = (outputRows - 1) * row_stride + 1;
  const TensorIndex expanded_out_cols = (outputCols - 1) * col_stride + 1;

  const TensorIndex padded_out_rows = inputRows + kernelRowsEff - 1;
  const TensorIndex padded_out_cols = inputCols + kernelColsEff - 1;

  const TensorIndex top_pad_rows = kernelRowsEff - 1 - padding_top;
  const TensorIndex left_pad_cols = kernelColsEff - 1 - padding_left;

  const TensorIndex bottom_pad_rows =
      padded_out_rows - expanded_out_rows - top_pad_rows;
  const TensorIndex right_pad_cols =
      padded_out_cols - expanded_out_cols - left_pad_cols;
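
  // A worked example with hypothetical sizes (not from the original source):
  // inputRows = 5, kernelRows = 3, row_stride = 1, row_in_stride = 1 and
  // outputRows = 5 give padRows = max(0, 4 + 3 - 5) = 2, padding_top = 1,
  // expanded_out_rows = 5, padded_out_rows = 7, top_pad_rows = 3 - 1 - 1 = 1
  // and bottom_pad_rows = 7 - 5 - 1 = 1, so the output gradient is padded to
  // the 7 rows needed to slide a 5-row "input patch" over it 3 times, once
  // per kernel row.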

  // Reorder output_backward dimensions.
  array<TensorIndex, 4> output_backward_shuffle;
  if (isColMajor) {
    // From: [out_depth, out_rows, out_cols, batch]
    // To:   [batch, out_rows, out_cols, out_depth]
    output_backward_shuffle = {3, 1, 2, 0};
  } else {
    // From: [batch, out_cols, out_rows, out_depth]
    // To:   [out_depth, out_cols, out_rows, batch]
    output_backward_shuffle = {3, 1, 2, 0};
  }

  // Reorder input dimensions.
  array<TensorIndex, 4> input_shuffle;
  if (isColMajor) {
    // From: [in_depth, in_rows, in_cols, batch]
    // To:   [in_depth, batch, in_rows, in_cols]
    input_shuffle = {0, 3, 1, 2};
  } else {
    // From: [batch, in_cols, in_rows, in_depth]
    // To:   [in_cols, in_rows, batch, in_depth]
    input_shuffle = {1, 2, 0, 3};
  }

  // Input is playing the role of a "kernel" in this convolution.
  DSizes<TensorIndex, 2> input_dims;
  if (isColMajor) {
    input_dims[0] = kernelChannels;
    input_dims[1] = batch * inputRows * inputCols;
  } else {
    input_dims[1] = kernelChannels;
    input_dims[0] = inputCols * inputRows * batch;
  }

  // Mold the output of the patch extraction into a 2D tensor:
  // - the first dimension (dims[0]): the patch values to be multiplied with
  //   the kernels
  // - the second dimension (dims[1]): everything else
  DSizes<TensorIndex, 2> pre_contract_dims;
  if (isColMajor) {
    pre_contract_dims[0] = batch * inputRows * inputCols;
    pre_contract_dims[1] = kernelRows * kernelCols * kernelFilters;
  } else {
    pre_contract_dims[1] = inputCols * inputRows * batch;
    pre_contract_dims[0] = kernelFilters * kernelCols * kernelRows;
  }

  // We will contract along the collapsed dimension that contains the
  // batch, inputRows and inputCols.
  array<IndexPair<TensorIndex>, 1> contract_dims;
  contract_dims[0] = IndexPair<TensorIndex>(1, 0);

  // Dimensions after contraction.
  DSizes<TensorIndex, NumDims> post_contract_dims;
  if (isColMajor) {
    post_contract_dims[0] = kernelChannels;
    post_contract_dims[1] = kernelRows;
    post_contract_dims[2] = kernelCols;
    post_contract_dims[3] = kernelFilters;
  } else {
    post_contract_dims[0] = kernelFilters;
    post_contract_dims[1] = kernelCols;
    post_contract_dims[2] = kernelRows;
    post_contract_dims[3] = kernelChannels;
  }

  // Reorder output of contraction to a valid filter shape.
  array<TensorIndex, 4> kernel_shuffle;
  if (isColMajor) {
    // From: [in_depth, kernel_rows, kernel_cols, out_depth]
    // To:   [out_depth, in_depth, kernel_rows, kernel_cols]
    kernel_shuffle = {3, 0, 1, 2};
  } else {
    // From: [out_depth, kernel_cols, kernel_rows, in_depth]
    // To:   [kernel_cols, kernel_rows, in_depth, out_depth]
    kernel_shuffle = {1, 2, 3, 0};
  }

  // Reverse kernel backprop dimensions.
  array<TensorIndex, 4> kernel_reverse;
  if (isColMajor) {
    kernel_reverse = {false, false, true, true};
  } else {
    kernel_reverse = {true, true, false, false};
  }

  // Create convolution input (aka source of patches) from output backward
  // tensor by shuffling dimensions.
  const auto output_backward_shuffled =
      output_backward.shuffle(output_backward_shuffle).eval();

  // Create convolution kernel (aka filter) from input by shuffling and
  // reshaping.
  const auto input_shuffled =
      input.shuffle(input_shuffle).eval().reshape(input_dims);

  return choose(
             Cond<internal::traits<OutputBackward>::Layout == ColMajor>(),
             input_shuffled.contract(
                 output_backward_shuffled
                     .extract_image_patches(inputRows, inputCols, row_in_stride,
                                            col_in_stride, 1, 1, row_stride,
                                            col_stride, top_pad_rows,
                                            bottom_pad_rows, left_pad_cols,
                                            right_pad_cols, OutScalar(0))
                     .reshape(pre_contract_dims),
                 contract_dims),
             output_backward_shuffled
                 .extract_image_patches(
                     inputRows, inputCols, row_in_stride, col_in_stride, 1, 1,
                     row_stride, col_stride, top_pad_rows, bottom_pad_rows,
                     left_pad_cols, right_pad_cols, OutScalar(0))
                 .reshape(pre_contract_dims)
                 .contract(input_shuffled, contract_dims))
      .reshape(post_contract_dims)
      .shuffle(kernel_shuffle)
      .eval()
      .reverse(kernel_reverse);
}
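
// Illustrative usage sketch for SpatialConvolutionBackwardKernel (not part of
// the original header; the tensor sizes below are hypothetical):
//
//   // Col-major: input is (channels, rows, cols, batch) and output_backward
//   // is (filters, out_rows, out_cols, batch).
//   Eigen::Tensor<float, 4> input(3, 32, 32, 16);
//   Eigen::Tensor<float, 4> output_backward(7, 28, 28, 16);
//   input.setRandom();
//   output_backward.setRandom();
//
//   // Gradient w.r.t. a 5x5 kernel of a stride-1, non-dilated convolution.
//   Eigen::Tensor<float, 4> kernel_backward =
//       Eigen::SpatialConvolutionBackwardKernel(input, output_backward,
//                                               /*kernelRows=*/5,
//                                               /*kernelCols=*/5);
//   // kernel_backward has dimensions (7, 3, 5, 5):
//   // filters x channels x kernel_rows x kernel_cols.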

}  // end namespace Eigen

#endif  // TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_SPATIAL_CONVOLUTIONS_H_