/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_SPATIAL_CONVOLUTIONS_H_
#define TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_SPATIAL_CONVOLUTIONS_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/eigen_spatial_convolutions.h"

namespace Eigen {

/** SpatialConvolutionBackwardInput
 * \ingroup CXX11_NeuralNetworks_Module
 *
 * \brief Computes the backprop for the input of a 2D convolution.
 *
 * The output_backward parameter is expected to be a tensor with a rank of 3 or
 * more (channels, height, width, and optionally others)
 * The kernel parameter is expected to be a 4D tensor (filters, channels,
 * kernel_height, kernel_width)
 * The output_backward and the kernel must both be in col-major layout. The
 * result will also be in col-major layout.
 *
 * If row_in_stride, col_in_stride > 1, then applies convolution with holes
 * (aka atrous convolution), sampling every row_in_stride, col_in_stride input
 * pixels.
 *
 * The result can be assigned to a tensor of rank equal to the rank of the
 * output_backward. The dimensions of the result will be channels,
 * input_height, input_width (and others if applicable).
 *
 * It is possible to swap the order of the width and height dimensions provided
 * that the same order is used in the input, the kernel, and the output.
 *
 */
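// A minimal usage sketch (illustrative only; the tensor sizes below are
// assumptions for the example, the forward convolution is assumed to have used
// 'SAME'-style padding with unit strides, and the Eigen Tensor module is
// assumed to be included). With a col-major 4D input of shape
// (channels, height, width, batch), a kernel of shape
// (filters, channels, kernel_height, kernel_width) and the gradient of the
// loss with respect to the forward output, the input gradient has the same
// shape as the input:
//
//   Eigen::Tensor<float, 4> kernel(16, 3, 5, 5);             // F x C x Kh x Kw
//   Eigen::Tensor<float, 4> output_backward(16, 32, 32, 8);  // F x H x W x N
//   Eigen::Tensor<float, 4> input_backward =                 // C x H x W x N
//       Eigen::SpatialConvolutionBackwardInput(kernel, output_backward,
//                                              /*inputRows=*/32,
//                                              /*inputCols=*/32);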
#ifdef EIGEN_HAS_INDEX_LIST
typedef IndexList<type2index<0>, type2index<0>, type2index<1>, type2index<1>>
    ReverseColMajor;
typedef IndexList<type2index<1>, type2index<1>, type2index<0>, type2index<0>>
    ReverseRowMajor;
#else
typedef array<bool, 4> ReverseColMajor;
typedef array<bool, 4> ReverseRowMajor;
#endif

template <typename OutputBackward, typename Kernel>
EIGEN_ALWAYS_INLINE static const typename internal::conditional<
    internal::traits<OutputBackward>::Layout == ColMajor,
    TensorReshapingOp<
        const DSizes<typename internal::traits<OutputBackward>::Index,
                     internal::traits<OutputBackward>::NumDimensions>,
        const TensorContractionOp<
            const array<
                IndexPair<typename internal::traits<OutputBackward>::Index>, 1>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<OutputBackward>::Index,
                             2>,
                const Eigen::TensorForcedEvalOp<const TensorShufflingOp<
                    const array<
                        typename internal::traits<OutputBackward>::Index, 4>,
                    const Eigen::TensorForcedEvalOp<const TensorReverseOp<
                        const ReverseColMajor, const Kernel>>>>>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<OutputBackward>::Index,
                             2>,
                const TensorImagePatchOp<Dynamic, Dynamic,
                                         const OutputBackward>>>>,
    TensorReshapingOp<
        const DSizes<typename internal::traits<OutputBackward>::Index,
                     internal::traits<OutputBackward>::NumDimensions>,
        const TensorContractionOp<
            const array<
                IndexPair<typename internal::traits<OutputBackward>::Index>, 1>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<OutputBackward>::Index,
                             2>,
                const TensorImagePatchOp<Dynamic, Dynamic,
                                         const OutputBackward>>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<OutputBackward>::Index,
                             2>,
                const Eigen::TensorForcedEvalOp<const TensorShufflingOp<
                    const array<
                        typename internal::traits<OutputBackward>::Index, 4>,
                    const Eigen::TensorForcedEvalOp<const TensorReverseOp<
                        const ReverseRowMajor, const Kernel>>>>>>>>::type
SpatialConvolutionBackwardInput(
    const Kernel& kernel, const OutputBackward& output_backward,
    typename internal::traits<OutputBackward>::Index inputRows,
    typename internal::traits<OutputBackward>::Index inputCols,
    const DenseIndex row_stride = 1, const DenseIndex col_stride = 1,
    const DenseIndex row_in_stride = 1, const DenseIndex col_in_stride = 1) {
  typedef typename internal::traits<OutputBackward>::Index TensorIndex;
  typedef typename internal::traits<OutputBackward>::Scalar OutScalar;
  TensorRef<Tensor<typename internal::traits<Kernel>::Scalar,
                   internal::traits<Kernel>::NumDimensions,
                   internal::traits<Kernel>::Layout, TensorIndex>>
      kern(kernel);
  TensorRef<Tensor<OutScalar, internal::traits<OutputBackward>::NumDimensions,
                   internal::traits<OutputBackward>::Layout, TensorIndex>>
      out(output_backward);

  EIGEN_STATIC_ASSERT(internal::traits<Kernel>::Layout ==
                          internal::traits<OutputBackward>::Layout,
                      YOU_MADE_A_PROGRAMMING_MISTAKE);

  static const bool isColMajor =
      (internal::traits<OutputBackward>::Layout == ColMajor);

  static const int NumDims = internal::traits<OutputBackward>::NumDimensions;

  // Number of filters in the kernel. This is the same as the depth of the
  // output_backward tensor.
  const TensorIndex kernelFilters =
      isColMajor ? kern.dimensions()[0] : kern.dimensions()[3];
  // Number of channels. This is the same as the input depth.
  const TensorIndex kernelChannels =
      isColMajor ? kern.dimensions()[1] : kern.dimensions()[2];
  const TensorIndex kernelRows =
      isColMajor ? kern.dimensions()[2] : kern.dimensions()[1];
  const TensorIndex kernelCols =
      isColMajor ? kern.dimensions()[3] : kern.dimensions()[0];

  // This is the effective kernel size, taking into account the
  // (*_in_stride - 1) zero-values inserted between consecutive kernel
  // elements in atrous convolution
  const TensorIndex kernelRowsEff =
      kernelRows + (kernelRows - 1) * (row_in_stride - 1);
  const TensorIndex kernelColsEff =
      kernelCols + (kernelCols - 1) * (col_in_stride - 1);

  const TensorIndex outputRows = isColMajor
                                     ? output_backward.dimension(1)
                                     : output_backward.dimension(NumDims - 2);
  const TensorIndex outputCols = isColMajor
                                     ? output_backward.dimension(2)
                                     : output_backward.dimension(NumDims - 3);

  // Computing the forward padding
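  // The input gradient is computed as a convolution of the output gradient,
  // inflated by the forward strides (the inflate-stride arguments of
  // extract_image_patches below), with the spatially reversed kernel.
  // forward_pad_* reconstructs the padding used by the forward convolution;
  // padding_* is the complementary amount, kernel_eff - 1 - forward_pad,
  // needed for the result to have the requested inputRows x inputCols size.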
  const TensorIndex forward_pad_top = numext::maxi<Index>(
      0, ((outputRows - 1) * row_stride + kernelRowsEff - inputRows) / 2);
  const TensorIndex forward_pad_left = numext::maxi<Index>(
      0, ((outputCols - 1) * col_stride + kernelColsEff - inputCols) / 2);
  const TensorIndex padding_top = kernelRowsEff - 1 - forward_pad_top;
  const TensorIndex padding_left = kernelColsEff - 1 - forward_pad_left;

  const TensorIndex padding_bottom = inputRows - (outputRows - 1) * row_stride -
                                     2 - padding_top + kernelRowsEff;
  const TensorIndex padding_right = inputCols - (outputCols - 1) * col_stride -
                                    2 - padding_left + kernelColsEff;

  eigen_assert(padding_top >= 0);
  eigen_assert(padding_left >= 0);
  eigen_assert(padding_bottom >= 0);
  eigen_assert(padding_right >= 0);

  // The kernel has dimensions filters X channels X patch_rows X patch_cols
  // We need to reverse the kernel along dimensions corresponding to rows and
  // cols.
  // TODO(yangke): we can make things slightly faster by collapsing the
  // dimensions where we don't reverse. Try that once we have a faster
  // compiler.
  typedef typename internal::conditional<isColMajor, ReverseColMajor,
                                         ReverseRowMajor>::type Reverse;
  Reverse kernel_reverse;

#ifndef EIGEN_HAS_INDEX_LIST
  if (isColMajor) {
    kernel_reverse[0] = false;
    kernel_reverse[1] = false;
    kernel_reverse[2] = true;
    kernel_reverse[3] = true;
  } else {
    kernel_reverse[0] = true;
    kernel_reverse[1] = true;
    kernel_reverse[2] = false;
    kernel_reverse[3] = false;
  }
#endif

  // Reorder the dimensions to:
  //   filters x patch_rows x patch_cols x channels
  array<TensorIndex, 4> kernel_shuffle;
  if (isColMajor) {
    //  From: filters x channels x rows x cols
    //  To:   filters x rows x cols x channels
    kernel_shuffle[0] = 0;
    kernel_shuffle[1] = 2;
    kernel_shuffle[2] = 3;
    kernel_shuffle[3] = 1;
  } else {
    //  From: cols x rows x channels x filters
    //  To:   channels x cols x rows x filters
    kernel_shuffle[0] = 2;
    kernel_shuffle[1] = 0;
    kernel_shuffle[2] = 1;
    kernel_shuffle[3] = 3;
  }

  // Collapse the dims
  DSizes<TensorIndex, 2> kernel_dims;
  if (isColMajor) {
    kernel_dims[0] = kernelFilters * kernelRows * kernelCols;
    kernel_dims[1] = kernelChannels;
  } else {
    kernel_dims[1] = kernelFilters * kernelRows * kernelCols;
    kernel_dims[0] = kernelChannels;
  }

  // The output_backward has dimensions out_depth X out_rows X out_cols X OTHERS
  // When we extract the image patches from output_backward, it will have
  // dimensions
  //   out_depth X (patch_rows * patch_cols) X (input_rows * input_cols *
  //   OTHERS)
  DSizes<TensorIndex, 2> pre_contract_dims;
  if (isColMajor) {
    pre_contract_dims[0] = kernelFilters * kernelRows * kernelCols;
    pre_contract_dims[1] = inputRows * inputCols;
    for (int i = 3; i < NumDims; ++i) {
      pre_contract_dims[1] *= out.dimension(i);
    }
  } else {
    pre_contract_dims[1] = kernelFilters * kernelRows * kernelCols;
    pre_contract_dims[0] = inputRows * inputCols;
    for (int i = 0; i < NumDims - 3; ++i) {
      pre_contract_dims[0] *= out.dimension(i);
    }
  }

  // We will contract along the collapsed dimension that contains the
  // kernelFilters, the kernelRows and the kernelCols.
  array<IndexPair<TensorIndex>, 1> contract_dims;
  if (isColMajor) {
    // col-major: kernel.contract(output.patches)
    contract_dims[0] = IndexPair<TensorIndex>(0, 0);
  } else {
    // row-major: output.patches.contract(kernel)
    contract_dims[0] = IndexPair<TensorIndex>(1, 1);
  }

  // Post contraction, the dimensions of the input_backprop are
  //  channels X input_rows X input_cols X OTHERS
  DSizes<TensorIndex, NumDims> post_contract_dims;
  if (isColMajor) {
    post_contract_dims[0] = kernelChannels;
    post_contract_dims[1] = inputRows;
    post_contract_dims[2] = inputCols;
    for (int i = 3; i < NumDims; ++i) {
      post_contract_dims[i] = out.dimension(i);
    }
  } else {
    post_contract_dims[NumDims - 1] = kernelChannels;
    post_contract_dims[NumDims - 2] = inputRows;
    post_contract_dims[NumDims - 3] = inputCols;
    for (int i = 0; i < NumDims - 3; ++i) {
      post_contract_dims[i] = out.dimension(i);
    }
  }

  // NOTE(ezhulenev): We do eval after reverse and shuffle, because tiled
  // evaluation of these ops does not compose. Doing explicit eval is ~8x
  // faster in micro benchmarks.

  return choose(
      Cond<internal::traits<OutputBackward>::Layout == ColMajor>(),
      kernel.reverse(kernel_reverse)
          .eval()
          .shuffle(kernel_shuffle)
          .eval()
          .reshape(kernel_dims)
          .contract(
              output_backward
                  .extract_image_patches(
                      kernelRows, kernelCols, 1, 1, row_in_stride,
                      col_in_stride, row_stride, col_stride, padding_top,
                      padding_bottom, padding_left, padding_right, OutScalar(0))
                  .reshape(pre_contract_dims),
              contract_dims)
          .reshape(post_contract_dims),
      output_backward
          .extract_image_patches(kernelRows, kernelCols, 1, 1, row_in_stride,
                                 col_in_stride, row_stride, col_stride,
                                 padding_top, padding_bottom, padding_left,
                                 padding_right, OutScalar(0))
          .reshape(pre_contract_dims)
          .contract(kernel.reverse(kernel_reverse)
                        .eval()
                        .shuffle(kernel_shuffle)
                        .eval()
                        .reshape(kernel_dims),
                    contract_dims)
          .reshape(post_contract_dims));
}

/** SpatialConvolutionBackwardKernel
 * \ingroup CXX11_NeuralNetworks_Module
 *
 * \brief Computes the backprop for the filter of a 2D convolution.
 *
 * The input parameter is expected to be a 4D tensor (channels, height, width,
 * batch).
 * The output_backward parameter is expected to be a 4D tensor (filters,
 * output_height, output_width, batch).
 * The input and the output_backward must both be in col-major layout. The
 * result will also be in col-major layout.
 *
 * If row_in_stride, col_in_stride > 1, then applies convolution with holes
 * (aka atrous convolution), sampling every row_in_stride, col_in_stride input
 * pixels.
 *
 * The result is the gradient with respect to the filter: a 4D tensor with
 * dimensions (filters, channels, kernel_height, kernel_width).
 *
 * It is possible to swap the order of the width and height dimensions provided
 * that the same order is used in the input, the kernel, and the output.
 *
 */
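// A minimal usage sketch (illustrative only; the sizes match the example above
// and are assumptions, again with 'SAME'-style forward padding and unit
// strides). The result is the gradient of the loss with respect to the 5x5
// kernel:
//
//   Eigen::Tensor<float, 4> input(3, 32, 32, 8);             // C x H x W x N
//   Eigen::Tensor<float, 4> output_backward(16, 32, 32, 8);  // F x H x W x N
//   Eigen::Tensor<float, 4> kernel_backward =                // F x C x Kh x Kw
//       Eigen::SpatialConvolutionBackwardKernel(input, output_backward,
//                                               /*kernelRows=*/5,
//                                               /*kernelCols=*/5);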

template <typename OutputBackward, typename Input>
EIGEN_ALWAYS_INLINE static const typename internal::conditional<
    internal::traits<Input>::Layout == ColMajor,
    const TensorReverseOp<
        const Eigen::array<typename internal::traits<Input>::Index,
                           internal::traits<Input>::NumDimensions>,
        const Eigen::TensorForcedEvalOp<const Eigen::TensorShufflingOp<
            const Eigen::array<typename internal::traits<Input>::Index,
                               internal::traits<Input>::NumDimensions>,
            const Eigen::TensorReshapingOp<
                const Eigen::DSizes<typename internal::traits<Input>::Index,
                                    internal::traits<Input>::NumDimensions>,
                const TensorContractionOp<
                    const array<
                        IndexPair<typename internal::traits<Input>::Index>, 1>,
                    const TensorReshapingOp<
                        const DSizes<typename internal::traits<Input>::Index,
                                     2>,
                        const Eigen::TensorForcedEvalOp<
                            const Eigen::TensorShufflingOp<
                                const Eigen::array<
                                    typename internal::traits<Input>::Index,
                                    internal::traits<Input>::NumDimensions>,
                                const Input>>>,
                    const TensorReshapingOp<
                        const DSizes<typename internal::traits<Input>::Index,
                                     2>,
                        const TensorImagePatchOp<
                            Dynamic, Dynamic,
                            const Eigen::TensorForcedEvalOp<
                                const Eigen::TensorShufflingOp<
                                    const Eigen::array<
                                        typename internal::traits<Input>::Index,
                                        internal::traits<Input>::NumDimensions>,
                                    const OutputBackward>>>>>>>>>,
    const TensorReverseOp<
        const Eigen::array<typename internal::traits<Input>::Index,
                           internal::traits<Input>::NumDimensions>,
        const Eigen::TensorForcedEvalOp<const Eigen::TensorShufflingOp<
            const Eigen::array<typename internal::traits<Input>::Index,
                               internal::traits<Input>::NumDimensions>,
            const Eigen::TensorReshapingOp<
                const Eigen::DSizes<typename internal::traits<Input>::Index,
                                    internal::traits<Input>::NumDimensions>,
                const TensorContractionOp<
                    const array<
                        IndexPair<typename internal::traits<Input>::Index>, 1>,
                    const TensorReshapingOp<
                        const DSizes<typename internal::traits<Input>::Index,
                                     2>,
                        const TensorImagePatchOp<
                            Dynamic, Dynamic,
                            const Eigen::TensorForcedEvalOp<
                                const Eigen::TensorShufflingOp<
                                    const Eigen::array<
                                        typename internal::traits<Input>::Index,
                                        internal::traits<Input>::NumDimensions>,
                                    const OutputBackward>>>>,
                    const TensorReshapingOp<
                        const DSizes<typename internal::traits<Input>::Index,
                                     2>,
                        const Eigen::TensorForcedEvalOp<
                            const Eigen::TensorShufflingOp<
                                const Eigen::array<
                                    typename internal::traits<Input>::Index,
                                    internal::traits<Input>::NumDimensions>,
                                const Input>>>>>>>>>::type
SpatialConvolutionBackwardKernel(
    const Input& input, const OutputBackward& output_backward,
    typename internal::traits<Input>::Index kernelRows,
    typename internal::traits<Input>::Index kernelCols,
    const DenseIndex row_stride = 1, const DenseIndex col_stride = 1,
    const DenseIndex row_in_stride = 1, const DenseIndex col_in_stride = 1) {
  typedef typename internal::traits<Input>::Index TensorIndex;
  typedef typename internal::traits<OutputBackward>::Scalar OutScalar;
  TensorRef<Tensor<typename internal::traits<Input>::Scalar,
                   internal::traits<Input>::NumDimensions,
                   internal::traits<Input>::Layout, TensorIndex>>
      in(input);
  TensorRef<Tensor<OutScalar, internal::traits<OutputBackward>::NumDimensions,
                   internal::traits<OutputBackward>::Layout, TensorIndex>>
      out(output_backward);

  EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout ==
                          internal::traits<OutputBackward>::Layout,
                      YOU_MADE_A_PROGRAMMING_MISTAKE);

  // stride and in_stride cannot both be larger than 1
  eigen_assert(!(row_stride > 1 && row_in_stride > 1));
  eigen_assert(!(col_stride > 1 && col_in_stride > 1));

  static const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);

  static const int NumDims = internal::traits<Input>::NumDimensions;
  EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions ==
                          internal::traits<OutputBackward>::NumDimensions,
                      YOU_MADE_A_PROGRAMMING_MISTAKE);
  EIGEN_STATIC_ASSERT(NumDims == 4, YOU_MADE_A_PROGRAMMING_MISTAKE);

  const TensorIndex inputRows =
      isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
  const TensorIndex inputCols =
      isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);

  const TensorIndex outputRows = isColMajor
                                     ? output_backward.dimension(1)
                                     : output_backward.dimension(NumDims - 2);
  const TensorIndex outputCols = isColMajor
                                     ? output_backward.dimension(2)
                                     : output_backward.dimension(NumDims - 3);

  // Number of filters to apply. This is the same as the output depth of the
  // result
  const TensorIndex kernelFilters =
      isColMajor ? out.dimensions()[0] : out.dimensions()[NumDims - 1];

  // Number of channels. This is the same as the input depth.
  const TensorIndex kernelChannels =
      isColMajor ? in.dimensions()[0] : in.dimensions()[NumDims - 1];

  // This is the effective kernel size, taking into account the
  // (*_in_stride - 1) zero-values inserted between consecutive kernel
  // elements in atrous convolution
  const TensorIndex kernelRowsEff =
      kernelRows + (kernelRows - 1) * (row_in_stride - 1);
  const TensorIndex kernelColsEff =
      kernelCols + (kernelCols - 1) * (col_in_stride - 1);

  // Number of batches (and other dimensions) in the input tensor.
  TensorIndex batch = 1;
  for (int d = 3; d < NumDims; ++d) {
    batch *= isColMajor ? in.dimension(d) : in.dimension(NumDims - d - 1);
  }

  // Computing the forward padding
  const TensorIndex padRows = numext::maxi<Index>(
      0, (outputRows - 1) * row_stride + kernelRowsEff - inputRows);
  const TensorIndex padCols = numext::maxi<Index>(
      0, (outputCols - 1) * col_stride + kernelColsEff - inputCols);

  TensorIndex padding_top = padRows / 2;
  TensorIndex padding_left = padCols / 2;

  // Compute paddings for output_backward before extracting patches.
  const TensorIndex expanded_out_rows = (outputRows - 1) * row_stride + 1;
  const TensorIndex expanded_out_cols = (outputCols - 1) * col_stride + 1;

  const TensorIndex padded_out_rows = inputRows + kernelRowsEff - 1;
  const TensorIndex padded_out_cols = inputCols + kernelColsEff - 1;

  const TensorIndex top_pad_rows = kernelRowsEff - 1 - padding_top;
  const TensorIndex left_pad_cols = kernelColsEff - 1 - padding_left;

  const TensorIndex bottom_pad_rows =
      padded_out_rows - expanded_out_rows - top_pad_rows;
  const TensorIndex right_pad_cols =
      padded_out_cols - expanded_out_cols - left_pad_cols;

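  // The kernel gradient is itself computed as a convolution: the shuffled
  // input below plays the role of the filter, the shuffled output gradient
  // plays the role of the data from which image patches are extracted, and
  // the contraction sums over the batch and the input spatial positions.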
  // Reorder output_backward dimensions.
  array<TensorIndex, 4> output_backward_shuffle;
  if (isColMajor) {
    // From: [out_depth, out_rows, out_cols, batch]
    // To:   [batch, out_rows, out_cols, out_depth]
    output_backward_shuffle = {3, 1, 2, 0};
  } else {
    // From: [batch, out_cols, out_rows, out_depth]
    // To:   [out_depth, out_cols, out_rows, batch]
    output_backward_shuffle = {3, 1, 2, 0};
  }

  // Reorder input dimensions.
  array<TensorIndex, 4> input_shuffle;
  if (isColMajor) {
    // From: [in_depth, in_rows, in_cols, batch]
    // To:   [in_depth, batch, in_rows, in_cols]
    input_shuffle = {0, 3, 1, 2};
  } else {
    // From: [batch, in_cols, in_rows, in_depth]
    // To:   [in_cols, in_rows, batch, in_depth]
    input_shuffle = {1, 2, 0, 3};
  }

  // Input is playing the role of a "kernel" in this convolution.
  DSizes<TensorIndex, 2> input_dims;
  if (isColMajor) {
    input_dims[0] = kernelChannels;
    input_dims[1] = batch * inputRows * inputCols;
  } else {
    input_dims[1] = kernelChannels;
    input_dims[0] = inputCols * inputRows * batch;
  }

  // Molds the output of the patch extraction into a 2D tensor:
  // - the first dimension (dims[0]): the patch values to be multiplied with
  //   the kernels
  // - the second dimension (dims[1]): everything else
  DSizes<TensorIndex, 2> pre_contract_dims;
  if (isColMajor) {
    pre_contract_dims[0] = batch * inputRows * inputCols;
    pre_contract_dims[1] = kernelRows * kernelCols * kernelFilters;
  } else {
    pre_contract_dims[1] = inputCols * inputRows * batch;
    pre_contract_dims[0] = kernelFilters * kernelCols * kernelRows;
  }

  // We will contract along the collapsed dimension that contains the
  // batch, inputRows and inputCols.
  array<IndexPair<TensorIndex>, 1> contract_dims;
  contract_dims[0] = IndexPair<TensorIndex>(1, 0);

  // Dimensions after contraction.
  DSizes<TensorIndex, NumDims> post_contract_dims;
  if (isColMajor) {
    post_contract_dims[0] = kernelChannels;
    post_contract_dims[1] = kernelRows;
    post_contract_dims[2] = kernelCols;
    post_contract_dims[3] = kernelFilters;
  } else {
    post_contract_dims[0] = kernelFilters;
    post_contract_dims[1] = kernelCols;
    post_contract_dims[2] = kernelRows;
    post_contract_dims[3] = kernelChannels;
  }

  // Reorder output of contraction to a valid filter shape.
  array<TensorIndex, 4> kernel_shuffle;
  if (isColMajor) {
    // From: [in_depth, kernel_rows, kernel_cols, out_depth]
    // To:   [out_depth, in_depth, kernel_rows, kernel_cols]
    kernel_shuffle = {3, 0, 1, 2};
  } else {
    // From: [out_depth, kernel_cols, kernel_rows, in_depth]
    // To:   [kernel_cols, kernel_rows, in_depth, out_depth]
    kernel_shuffle = {1, 2, 3, 0};
  }

  // Reverse kernel backprop dimensions.
  array<TensorIndex, 4> kernel_reverse;
  if (isColMajor) {
    kernel_reverse = {false, false, true, true};
  } else {
    kernel_reverse = {true, true, false, false};
  }

  // Create convolution input (aka source of patches) from output backward
  // tensor by shuffling dimensions.
  const auto output_backward_shuffled =
      output_backward.shuffle(output_backward_shuffle).eval();

  // Create convolution kernel (aka filter) from input by shuffling and
  // reshaping.
  const auto input_shuffled =
      input.shuffle(input_shuffle).eval().reshape(input_dims);

  return choose(
             Cond<internal::traits<OutputBackward>::Layout == ColMajor>(),
             input_shuffled.contract(
                 output_backward_shuffled
                     .extract_image_patches(inputRows, inputCols, row_in_stride,
                                            col_in_stride, 1, 1, row_stride,
                                            col_stride, top_pad_rows,
                                            bottom_pad_rows, left_pad_cols,
                                            right_pad_cols, OutScalar(0))
                     .reshape(pre_contract_dims),
                 contract_dims),
             output_backward_shuffled
                 .extract_image_patches(
                     inputRows, inputCols, row_in_stride, col_in_stride, 1, 1,
                     row_stride, col_stride, top_pad_rows, bottom_pad_rows,
                     left_pad_cols, right_pad_cols, OutScalar(0))
                 .reshape(pre_contract_dims)
                 .contract(input_shuffled, contract_dims))
      .reshape(post_contract_dims)
      .shuffle(kernel_shuffle)
      .eval()
      .reverse(kernel_reverse);
}

}  // end namespace Eigen

#endif  // TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_SPATIAL_CONVOLUTIONS_H_