/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_
#define TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/eigen_volume_patch.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#include "tensorflow/core/kernels/eigen_convolution_helpers.h"

namespace Eigen {

namespace internal {

// WARNING: Most of the code here implicitly assumes that the matrix is in
// ColMajor layout. This is guaranteed by the tensor contraction (see
// TensorContraction.h).
//
// Inside Eigen a tensor contraction is represented by a matrix multiplication.
// We don't want to actually extract volume patches and reshape the result into
// a matrix (this involves allocating huge extra memory), so the patch
// extraction and reshape operations are implicit.
//
// TensorContractionInputMapper takes a matrix index and returns the coefficient
// (or the packet) of the "virtual tensor", that would be at that index if we
// were to actually reshape the result of patch extraction.
//
// TensorContractionSubMapper provides a similar view into the "virtual matrix"
// at the given vertical and horizontal offsets.
//
// "Virtual matrix" dimensions:
//   *0: kernelChannels * kernelPlanes * kernelRows * kernelCols
//    1: out_planes * out_height * out_width * OTHERS (e.g. batches, etc...)
//
// *) extracted patches are contiguous in memory (innermost dimension assuming
//    col major layout)
//
// With these dimensions:
//   row - offset within a single patch (in code: patchId)
//   col - index of the extracted patch (in code: patchIndex)
//         patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions)
//
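// A minimal worked example of this decomposition (illustrative sizes, not
// taken from any particular model): assume patch_depth = 4, patch_planes = 2
// and patch_rows = 3, so that rowStride = 2 and colStride = 6. Then row = 50
// decomposes as
//   patchOffset = 50 / 4 = 12          (offset in depth-zero elements)
//   colOffset   = 12 / 6 = 2
//   rowOffset   = (12 - 2 * 6) / 2 = 0
//   planeOffset = 12 - 2 * 6 - 0 * 2 = 0
//   depth       = 50 - 12 * 4 = 2
// i.e. row 50 is channel 2 of the element at (plane 0, row 0, col 2) inside
// the patch selected by col (see loadCoeff below for the exact code).
//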
template <typename NewDimension, Index Planes, Index Rows, Index Cols,
          typename ArgType, typename Device, typename Scalar_, typename Index,
          typename nocontract_t, typename contract_t, int Side, int packet_size,
          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionInputMapper<
    Scalar_, Index, Side,
    TensorEvaluator<const TensorReshapingOp<NewDimension,
                                            const TensorVolumePatchOp<
                                                Planes, Rows, Cols, ArgType> >,
                    Device>,
    nocontract_t, contract_t, packet_size, inner_dim_contiguous,
    inner_dim_reordered, Alignment> {
 public:
  typedef Scalar_ Scalar;
  typedef TensorContractionInputMapper<
      Scalar, Index, Side,
      TensorEvaluator<const TensorReshapingOp<
                          NewDimension, const TensorVolumePatchOp<
                                            Planes, Rows, Cols, ArgType> >,
                      Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      Self;
  typedef TensorContractionSubMapper<
      Scalar, Index, Side,
      TensorEvaluator<const TensorReshapingOp<
                          NewDimension, const TensorVolumePatchOp<
                                            Planes, Rows, Cols, ArgType> >,
                      Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      SubMapper;
  typedef SubMapper VectorMapper;
  typedef SubMapper LinearMapper;
  typedef typename packet_traits<Scalar>::type Packet;

  EIGEN_DEVICE_FUNC
  TensorContractionInputMapper(
      const TensorEvaluator<
          const TensorReshapingOp<
              NewDimension,
              const TensorVolumePatchOp<Planes, Rows, Cols, ArgType> >,
          Device>& tensor,
      const nocontract_t&, const nocontract_t&, const contract_t&,
      const contract_t&)
      : m_impl(tensor.impl().impl()) {
    if (internal::traits<ArgType>::Layout == ColMajor) {
      m_patch_depth = tensor.impl().dimensions()[0];
      m_patch_planes = tensor.impl().dimensions()[1];
      m_patch_rows = tensor.impl().dimensions()[2];
      m_patch_cols = tensor.impl().dimensions()[3];
      m_num_patches = tensor.impl().dimensions()[4];
    } else {
      const int NumDims = tensor.impl().dimensions().size();
      m_patch_depth = tensor.impl().dimensions()[NumDims - 1];
      m_patch_planes = tensor.impl().dimensions()[NumDims - 2];
      m_patch_rows = tensor.impl().dimensions()[NumDims - 3];
      m_patch_cols = tensor.impl().dimensions()[NumDims - 4];
      m_num_patches = tensor.impl().dimensions()[NumDims - 5];
    }

    // Strides for navigating through the single patch.
    m_patch_plane_stride = m_patch_depth;
    m_patch_row_stride = m_patch_planes * m_patch_plane_stride;
    m_patch_col_stride = m_patch_rows * m_patch_row_stride;

    // Strides for the output tensor.
    // IMPORTANT: These strides are used to locate an element in a patch at
    // depth zero (channel), which is not quite the same as the "traditional"
    // stride.
    m_rowStride = m_patch_planes;
    m_colStride = m_patch_rows * m_rowStride;
    m_patchStride = m_colStride * m_patch_cols * m_patch_depth;
    m_otherStride = m_patchStride * m_num_patches;
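    // Illustrative example (assumed sizes, not from any particular model):
    // with patch_depth = 4, patch_planes = 2, patch_rows = 3 and
    // patch_cols = 3, the values above are rowStride = 2, colStride = 6,
    // patchStride = 6 * 3 * 4 = 72 and otherStride = 72 * num_patches. Note
    // that rowStride and colStride count depth-zero elements only, while
    // patchStride counts every coefficient of the patch.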

    m_outputPlanes = tensor.impl().outputPlanes();
    m_outputRows = tensor.impl().outputRows();
    m_outputCols = tensor.impl().outputCols();

    m_outputPlanesRows = m_outputPlanes * m_outputRows;

    m_plane_strides = tensor.impl().userPlaneStride();
    m_row_strides = tensor.impl().userRowStride();
    m_col_strides = tensor.impl().userColStride();

    m_in_plane_strides = tensor.impl().userInPlaneStride();
    m_in_row_strides = tensor.impl().userInRowStride();
    m_in_col_strides = tensor.impl().userInColStride();

    m_patch_plane_inflate_strides = tensor.impl().planeInflateStride();
    m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
    m_patch_col_inflate_strides = tensor.impl().colInflateStride();

    if (internal::traits<ArgType>::Layout == ColMajor) {
      m_inputDepth = tensor.impl().impl().dimensions()[0];
      m_inputPlanes = tensor.impl().impl().dimensions()[1];
      m_inputRows = tensor.impl().impl().dimensions()[2];
      m_inputCols = tensor.impl().impl().dimensions()[3];
    } else {
      const int NumDims = tensor.impl().impl().dimensions().size();
      m_inputDepth = tensor.impl().impl().dimensions()[NumDims - 1];
      m_inputPlanes = tensor.impl().impl().dimensions()[NumDims - 2];
      m_inputRows = tensor.impl().impl().dimensions()[NumDims - 3];
      m_inputCols = tensor.impl().impl().dimensions()[NumDims - 4];
    }

    // Strides for navigating through the input tensor.
    m_planeInputStride = m_inputDepth;
    m_rowInputStride = m_inputDepth * m_inputPlanes;
    m_colInputStride = m_inputDepth * m_inputRows * m_inputPlanes;
    m_patchInputStride =
        m_inputDepth * m_inputRows * m_inputCols * m_inputPlanes;

    m_planePaddingTop = tensor.impl().planePaddingTop();
    m_rowPaddingTop = tensor.impl().rowPaddingTop();
    m_colPaddingLeft = tensor.impl().colPaddingLeft();

    m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches);

    m_fastPatchPlaneStride =
        internal::TensorIntDivisor<Index>(m_patch_plane_stride);
    m_fastPatchRowStride =
        internal::TensorIntDivisor<Index>(m_patch_row_stride);
    m_fastPatchColStride =
        internal::TensorIntDivisor<Index>(m_patch_col_stride);

    m_fastInputPlaneStride =
        internal::TensorIntDivisor<Index>(m_patch_plane_inflate_strides);
    m_fastInputRowStride =
        internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
    m_fastInputColStride =
        internal::TensorIntDivisor<Index>(m_patch_col_inflate_strides);

    m_fastRowStride = internal::TensorIntDivisor<Index>(m_rowStride);
    m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);

    m_fastDimZero = internal::TensorIntDivisor<Index>(m_patch_depth);
    m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
    m_fastOutputPlanes = internal::TensorIntDivisor<Index>(m_outputPlanes);
    m_fastOutputCols = internal::TensorIntDivisor<Index>(m_outputCols);

    m_fastOutputPlanesRows =
        internal::TensorIntDivisor<Index>(m_outputPlanesRows);
  }

  EIGEN_DEVICE_FUNC
  TensorContractionInputMapper(const TensorContractionInputMapper& base_mapper)
      : m_impl(base_mapper.m_impl) {
    m_patch_depth = base_mapper.m_patch_depth;
    m_patch_planes = base_mapper.m_patch_planes;
    m_patch_rows = base_mapper.m_patch_rows;
    m_patch_cols = base_mapper.m_patch_cols;
    m_num_patches = base_mapper.m_num_patches;

    m_patch_plane_stride = base_mapper.m_patch_plane_stride;
    m_patch_row_stride = base_mapper.m_patch_row_stride;
    m_patch_col_stride = base_mapper.m_patch_col_stride;

    m_rowStride = base_mapper.m_rowStride;
    m_colStride = base_mapper.m_colStride;
    m_patchStride = base_mapper.m_patchStride;
    m_otherStride = base_mapper.m_otherStride;

    m_planeInputStride = base_mapper.m_planeInputStride;
    m_rowInputStride = base_mapper.m_rowInputStride;
    m_colInputStride = base_mapper.m_colInputStride;
    m_patchInputStride = base_mapper.m_patchInputStride;
    m_otherInputStride = base_mapper.m_otherInputStride;

    m_inputDepth = base_mapper.m_inputDepth;
    m_inputPlanes = base_mapper.m_inputPlanes;
    m_inputRows = base_mapper.m_inputRows;
    m_inputCols = base_mapper.m_inputCols;

    m_outputPlanes = base_mapper.m_outputPlanes;
    m_outputRows = base_mapper.m_outputRows;
    m_outputCols = base_mapper.m_outputCols;

    m_plane_strides = base_mapper.m_plane_strides;
    m_row_strides = base_mapper.m_row_strides;
    m_col_strides = base_mapper.m_col_strides;

    m_in_plane_strides = base_mapper.m_in_plane_strides;
    m_in_row_strides = base_mapper.m_in_row_strides;
    m_in_col_strides = base_mapper.m_in_col_strides;

    m_patch_plane_inflate_strides = base_mapper.m_patch_plane_inflate_strides;
    m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
    m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;

    m_planePaddingTop = base_mapper.m_planePaddingTop;
    m_rowPaddingTop = base_mapper.m_rowPaddingTop;
    m_colPaddingLeft = base_mapper.m_colPaddingLeft;

    m_outputPlanesRows = base_mapper.m_outputPlanesRows;

    m_fastNumPatches = base_mapper.m_fastNumPatches;
    m_fastPatchPlaneStride = base_mapper.m_fastPatchPlaneStride;
    m_fastPatchRowStride = base_mapper.m_fastPatchRowStride;
    m_fastPatchColStride = base_mapper.m_fastPatchColStride;
    m_fastInputPlaneStride = base_mapper.m_fastInputPlaneStride;
    m_fastInputRowStride = base_mapper.m_fastInputRowStride;
    m_fastInputColStride = base_mapper.m_fastInputColStride;
    m_fastRowStride = base_mapper.m_fastRowStride;
    m_fastColStride = base_mapper.m_fastColStride;
    m_fastOutputPlanes = base_mapper.m_fastOutputPlanes;
    m_fastOutputRows = base_mapper.m_fastOutputRows;
    m_fastOutputCols = base_mapper.m_fastOutputCols;
    m_fastDimZero = base_mapper.m_fastDimZero;
    m_fastOutputPlanesRows = base_mapper.m_fastOutputPlanesRows;
  }

  // If true, turns off some optimizations for loading packets since the image
  // patches are "non-standard", e.g. there are non-trivial strides or
  // inflations in the input.
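  // For example (hypothetical parameters), an atrous/dilated convolution with
  // m_in_row_strides == 2, or an input inflated with
  // m_patch_row_inflate_strides == 2, makes this return true, and packet loads
  // then go through the generic packetWithPossibleZero() path.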
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
    return m_in_plane_strides != 1 || m_in_row_strides != 1 ||
           m_in_col_strides != 1 || m_patch_plane_inflate_strides != 1 ||
           m_patch_row_inflate_strides != 1 || m_patch_col_inflate_strides != 1;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
    return SubMapper(*this, i, j);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
    return LinearMapper(*this, i, j);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const {
    Index planeIndex, rowIndex, colIndex, otherIndex;
    computeBaseIndices(0, planeIndex, rowIndex, colIndex, otherIndex);
    return loadCoeff(row, planeIndex, rowIndex, colIndex, otherIndex);
  }

  // Load the coefficient at the patchIndex location instead of the usual
  // m_rowIndex, m_colIndex, m_otherIndex. This is currently only used by the
  // gpu code.
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const {
    Index planeIndex, rowIndex, colIndex, otherIndex;
    computeBaseIndices(patchIndex, planeIndex, rowIndex, colIndex, otherIndex);
    return loadCoeff(row, planeIndex, rowIndex, colIndex, otherIndex);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const {
    Index planeIndex, rowIndex, colIndex, otherIndex;
    computeBaseIndices(0, planeIndex, rowIndex, colIndex, otherIndex);
    return loadPacket(row, planeIndex, rowIndex, colIndex, otherIndex);
  }

  // Load the packet at the patchIndex location instead of the usual m_rowIndex,
  // m_colIndex, m_otherIndex. This is currently only used by the gpu code.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const {
    Index planeIndex, rowIndex, colIndex, otherIndex;
    computeBaseIndices(patchIndex, planeIndex, rowIndex, colIndex, otherIndex);
    return loadPacket(row, planeIndex, rowIndex, colIndex, otherIndex);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE const TensorEvaluator<ArgType, Device>& impl() const {
    return m_impl;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_planeInputStride; }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchPlanes() const { return m_rowStride; }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchRows() const { return m_patch_rows; }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; }

 private:
  friend class TensorContractionSubMapper<
      Scalar, Index, Side,
      TensorEvaluator<const TensorReshapingOp<
                          NewDimension, const TensorVolumePatchOp<
                                            Planes, Rows, Cols, ArgType> >,
                      Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>;

  // Load coefficient from a patch specified by the "within patch offset"
  // (patchId) and the precomputed indices of the first element of the patch.
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index planeIndex,
                                       Index rowIndex, Index colIndex,
                                       Index otherIndex) const {
    // Find the offset of the element wrt the location of the first element.
    const Index patchOffset = patchId / m_fastDimZero;

    const Index colOffset = patchOffset / m_fastColStride;
    const Index inputCol = colIndex + colOffset * m_in_col_strides;
    const Index origInputCol =
        (m_patch_col_inflate_strides == 1)
            ? inputCol
            : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);

    const Index rowOffset =
        (patchOffset - colOffset * m_colStride) / m_fastRowStride;
    const Index inputRow = rowIndex + rowOffset * m_in_row_strides;
    const Index origInputRow =
        (m_patch_row_inflate_strides == 1)
            ? inputRow
            : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);

    const Index planeOffset =
        patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
    const Index inputPlane = planeIndex + planeOffset * m_in_plane_strides;
    const Index origInputPlane =
        (m_patch_plane_inflate_strides == 1)
            ? inputPlane
            : ((inputPlane >= 0) ? (inputPlane / m_fastInputPlaneStride) : 0);

    if (origInputCol < 0 || origInputRow < 0 || origInputPlane < 0 ||
        origInputCol >= m_inputCols || origInputRow >= m_inputRows ||
        origInputPlane >= m_inputPlanes ||
        (inputCol != origInputCol * m_patch_col_inflate_strides) ||
        (inputRow != origInputRow * m_patch_row_inflate_strides) ||
        (inputPlane != origInputPlane * m_patch_plane_inflate_strides)) {
      return Scalar(0);
    }

    const Index depth = patchId - patchOffset * patchDepth();
    const Index inputIndex = depth + origInputPlane * m_planeInputStride +
                             origInputRow * m_rowInputStride +
                             origInputCol * m_colInputStride + otherIndex;

    return m_impl.coeff(inputIndex);
  }

  // This is the same as loadCoeff(...), but optimized for all `inflate_strides`
  // and `in_strides` equal to 1 (template specialization without templates).
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index planeIndex,
                                               Index rowIndex, Index colIndex,
                                               Index otherIndex) const {
    eigen_assert(!nonStandardPatches());

    // Find the offset of the element wrt the location of the first element.
    const Index patchOffset = patchId / m_fastDimZero;

    const Index colOffset = patchOffset / m_fastColStride;
    const Index rowOffset =
        (patchOffset - colOffset * m_colStride) / m_fastRowStride;
    const Index planeOffset =
        patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;

    const Index inputCol = colIndex + colOffset;
    const Index inputRow = rowIndex + rowOffset;
    const Index inputPlane = planeIndex + planeOffset;

    if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 ||
        inputRow >= m_inputRows || inputPlane < 0 ||
        inputPlane >= m_inputPlanes) {
      return Scalar(0);
    }

    const Index depth = patchId - patchOffset * patchDepth();
    const Index inputIndex = depth + inputPlane * m_planeInputStride +
                             inputRow * m_rowInputStride +
                             inputCol * m_colInputStride + otherIndex;

    return m_impl.coeff(inputIndex);
  }

  // Load packet from a patch specified by the "within patch offset"
  // (patchId) and the precomputed indices of the first element of the patch.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index planeIndex,
                                        Index rowIndex, Index colIndex,
                                        Index otherIndex) const {
    const Index packetSize = internal::unpacket_traits<Packet>::size;

    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(patchId <
                 patchDepth() * patchPlanes() * patchRows() * patchCols());

    if (nonStandardPatches()) {
      return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex,
                                    otherIndex);
    }
    typedef decltype(m_impl) TensorEvaluatorT;
    return loadPacketStandard<Packet, TensorEvaluatorT>(
        patchId, planeIndex, rowIndex, colIndex, otherIndex);
  }

  // Helper function to load a 'partial' packet - this is the single row part of
  // a packet that is split across two rows (but single column). In the
  // 'partial' packet, the elements corresponding to the row (specified through
  // rowOffset) are loaded and the rest of the elements are zero-filled into the
  // 'partial' packet. This function is called from
  // loadPacketStandardFromSingleColumnTwoRows(). This code path is exercised
  // only when the packet type supports masked load and when the partial packet
  // load is available in the TensorEvaluator.
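  //
  // Illustrative example (assumed values): for a packet of size 8 and
  // span = {2, 5}, the masked load starts at inputIndex - span[0] with
  // mask<Packet>(2, 6), so lanes 2..5 receive input values and lanes 0..1 and
  // 6..7 stay zero.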
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPartialPacketStandard(
      Index planeIndex, Index rowIndex, Index colIndex, Index otherIndex,
      Index patchId, const Index span[], const Index patchOffsets[],
      Index colOffset, Index rowOffset) const {
    const Index inputCol = colIndex + colOffset;
    const Index inputRow = rowIndex + rowOffset;
    const Index planeOffsets[2] = {
        patchOffsets[0] - colOffset * m_colStride - rowOffset * m_rowStride,
        patchOffsets[1] - colOffset * m_colStride - rowOffset * m_rowStride};
    const Index inputPlanes[2] = {planeIndex + planeOffsets[0],
                                  planeIndex + planeOffsets[1]};

    if (inputRow >= m_inputRows || inputRow < 0 || inputCol >= m_inputCols ||
        inputCol < 0 || inputPlanes[0] >= m_inputPlanes || inputPlanes[1] < 0) {
      // Partial packet is all zeros
      return internal::pset1<Packet>(Scalar(0));
    } else if (inputPlanes[0] >= 0 && inputPlanes[1] < m_inputPlanes) {
      // From inputIndex-span[0], we need to load elements starting from index
      // span[0] all the way up to (and including) span[1].
      const Index depth = patchId - patchOffsets[0] * patchDepth();
      const Index inputIndex = depth + inputPlanes[0] * m_planeInputStride +
                               inputRow * m_rowInputStride +
                               inputCol * m_colInputStride + otherIndex;
      return m_impl.template partialPacket<Packet>(
          inputIndex - span[0], mask<Packet>(span[0], span[1] + 1));
    } else {
      // Using slow path for this partial packet.
      // We need to load elements starting from index span[0] all the way up to
      // (and including) span[1]. We split this load into 3 parts:
      // 0 : span[0]-1 - Zeros will be loaded for these indices
      // span[0] : span[1] - Elements will be loaded here for these indices
      // span[1]+1 : packetSize-1 - Zeros will be loaded for these indices
      const Index packetSize = internal::unpacket_traits<Packet>::size;
      EIGEN_ALIGN_MAX
      typename internal::remove_const<Scalar>::type values[packetSize];
      for (int i = 0; i < span[0]; ++i) values[i] = Scalar(0);
      for (int i = span[0]; i < span[1] + 1; ++i)
        values[i] = loadCoeff(patchId - span[0] + i, planeIndex, rowIndex,
                              colIndex, otherIndex);
      for (int i = span[1] + 1; i < packetSize; ++i) values[i] = Scalar(0);
      return internal::pload<Packet>(values);
    }
  }

  // Helper function to load a packet that is split across two rows (but single
  // column). If required, this function is called from loadPacketStandard()
  // when the packet type supports masked load and when the partial packet load
  // is available in the TensorEvaluator.
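  //
  // Illustrative example (assumed sizes): with patch_depth = 3 and
  // patch_planes = 2 (so rowStride = 2), patch row 1 covers patchIds [6..11]
  // and row 2 covers [12..17]. A packet of size 8 starting at patchId = 10
  // therefore splits at patchIdSplit = 11: lanes 0..1 come from row 1 and
  // lanes 2..7 from row 2, each loaded as a partial packet and combined with a
  // bitwise OR.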
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromSingleColumnTwoRows(
      Index patchId, Index planeIndex, Index rowIndex, Index colIndex,
      Index otherIndex, const Index patchOffsets[], const Index colOffsets[],
      const Index rowOffsets[]) const {
    eigen_assert(colOffsets[1] == colOffsets[0] &&
                 rowOffsets[1] == rowOffsets[0] + 1);
    const Index packetSize = internal::unpacket_traits<Packet>::size;

    // Packet to load will be split into 2 parts where each part spans a single
    // row and both the parts span the same column.
    // First determine where to split.
    const Index patchIdSplit =
        (((rowOffsets[1] * m_rowStride) + (colOffsets[0] * m_colStride)) *
         m_patch_depth) -
        1;
    const Index patchOffsetSplit = patchIdSplit / m_fastDimZero;

    // patchIds[i]:          patchId corresponding to partial packet i
    // spans[i]:             Start and end indices corresponding to the elements
    //                       to be loaded for partial packet i
    // patchOffsets2Cols[i]: patchOffsets corresponding to partial packet i
    const Index patchIds[2] = {patchId, patchIdSplit + 1};
    const Index spans[2][2] = {{0, patchIdSplit - patchId},
                               {patchIdSplit - patchId + 1, packetSize - 1}};
    const Index patchOffsets2Cols[2][2] = {
        {patchOffsets[0], patchOffsetSplit},
        {patchOffsetSplit + 1, patchOffsets[1]}};

    // Load partial packets and do bit-wise OR to generate the required packet.
    return internal::por<Packet>(
        loadPartialPacketStandard(planeIndex, rowIndex, colIndex, otherIndex,
                                  patchIds[0], spans[0], patchOffsets2Cols[0],
                                  colOffsets[0], rowOffsets[0]),
        loadPartialPacketStandard(planeIndex, rowIndex, colIndex, otherIndex,
                                  patchIds[1], spans[1], patchOffsets2Cols[1],
                                  colOffsets[1], rowOffsets[1]));
  }

  // Helper function to load a packet that is present in a single column and
  // row. If required, this function is called from loadPacketStandard().
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromSingleColumnSingleRow(
      Index patchId, Index planeIndex, Index rowIndex, Index colIndex,
      Index otherIndex, const Index patchOffsets[], const Index colOffsets[],
      const Index rowOffsets[], const Index inputCols[],
      const Index inputRows[]) const {
    eigen_assert(colOffsets[1] == colOffsets[0] &&
                 rowOffsets[1] == rowOffsets[0]);
    const Index planeOffsets[2] = {
        patchOffsets[0] - colOffsets[0] * m_colStride -
            rowOffsets[0] * m_rowStride,
        patchOffsets[1] - colOffsets[1] * m_colStride -
            rowOffsets[1] * m_rowStride};
    eigen_assert(planeOffsets[0] <= planeOffsets[1]);
    const Index inputPlanes[2] = {planeIndex + planeOffsets[0],
                                  planeIndex + planeOffsets[1]};

    if (inputPlanes[0] >= m_inputPlanes || inputPlanes[1] < 0) {
      return internal::pset1<Packet>(Scalar(0));
    }
    if (inputPlanes[0] >= 0 && inputPlanes[1] < m_inputPlanes) {
      const Index depth = patchId - patchOffsets[0] * patchDepth();
      const Index inputIndex = depth + inputPlanes[0] * m_planeInputStride +
                               inputRows[0] * m_rowInputStride +
                               inputCols[0] * m_colInputStride + otherIndex;
      return m_impl.template packet<Unaligned>(inputIndex);
    }
    return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex,
                                  otherIndex);
  }

  // Load standard packet from a patch specified by the "within patch offset"
  // (patchId) and the precomputed indices of the first element of the patch.
  // This function will be called if partial packet loading is not available
  // for the TensorEvaluator or if the packet type does not support masked
  // load.
  template <typename PacketT, typename TensorEvaluatorT>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
      !TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value,
      PacketT>::type
  loadPacketStandard(Index patchId, Index planeIndex, Index rowIndex,
                     Index colIndex, Index otherIndex) const {
    const Index packetSize = internal::unpacket_traits<Packet>::size;
    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(patchId <
                 patchDepth() * patchPlanes() * patchRows() * patchCols());
    eigen_assert(!nonStandardPatches());

    if ((patchDepth() % packetSize) == 0) {
      return loadPacketFast(patchId, planeIndex, rowIndex, colIndex,
                            otherIndex);
    } else {
      // Offsets and input calculation here are identical to
      // loadCoeffStandard(...), but repeated twice.

      const Index patchOffsets[2] = {
          patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero};

      const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
                                   patchOffsets[1] / m_fastColStride};
      eigen_assert(colOffsets[0] <= colOffsets[1]);

      const Index inputCols[2] = {colIndex + colOffsets[0],
                                  colIndex + colOffsets[1]};
      if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
        return internal::pset1<Packet>(Scalar(0));
      }

      if (inputCols[0] == inputCols[1]) {
        const Index rowOffsets[2] = {
            (patchOffsets[0] - colOffsets[0] * m_colStride) / m_fastRowStride,
            (patchOffsets[1] - colOffsets[1] * m_colStride) / m_fastRowStride};
        eigen_assert(rowOffsets[0] <= rowOffsets[1]);
        const Index inputRows[2] = {rowIndex + rowOffsets[0],
                                    rowIndex + rowOffsets[1]};

        if (inputRows[0] >= m_inputRows || inputRows[1] < 0) {
          return internal::pset1<Packet>(Scalar(0));
        }

        if (inputRows[0] == inputRows[1]) {
          return loadPacketStandardFromSingleColumnSingleRow(
              patchId, planeIndex, rowIndex, colIndex, otherIndex, patchOffsets,
              colOffsets, rowOffsets, inputCols, inputRows);
        }
      }
    }

    return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex,
                                  otherIndex);
  }

  // Load standard packet from a patch specified by the "within patch offset"
  // (patchId) and the precomputed indices of the first element of the patch.
  // This function will be called if partial packet loading is available for
  // the TensorEvaluator and if the packet type supports masked load.
  // The only difference between this and the other case is that if the packet
  // to load is split across two rows (but in the same column), then in this
  // case instead of going to the slow (element-by-element) load, we load two
  // packets - each containing elements from one of the rows (the rest of the
  // elements of the packets are zeroes), and then combine these two packets to
  // generate the required packet. The idea is to enable a fast load (if
  // possible) of these 'partial' packets.
  template <typename PacketT, typename TensorEvaluatorT>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
      TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value,
      PacketT>::type
  loadPacketStandard(Index patchId, Index planeIndex, Index rowIndex,
                     Index colIndex, Index otherIndex) const {
    const Index packetSize = internal::unpacket_traits<Packet>::size;
    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(patchId <
                 patchDepth() * patchPlanes() * patchRows() * patchCols());
    eigen_assert(!nonStandardPatches());

    if ((patchDepth() % packetSize) == 0) {
      return loadPacketFast(patchId, planeIndex, rowIndex, colIndex,
                            otherIndex);
    } else {
      // Offsets and input calculation here are identical to
      // loadCoeffStandard(...), but repeated twice.

      const Index patchOffsets[2] = {
          patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero};

      const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
                                   patchOffsets[1] / m_fastColStride};
      eigen_assert(colOffsets[0] <= colOffsets[1]);

      const Index inputCols[2] = {colIndex + colOffsets[0],
                                  colIndex + colOffsets[1]};
      if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
        return internal::pset1<Packet>(Scalar(0));
      }

      if (inputCols[0] == inputCols[1]) {
        const Index rowOffsets[2] = {
            (patchOffsets[0] - colOffsets[0] * m_colStride) / m_fastRowStride,
            (patchOffsets[1] - colOffsets[1] * m_colStride) / m_fastRowStride};
        eigen_assert(rowOffsets[0] <= rowOffsets[1]);
        const Index inputRows[2] = {rowIndex + rowOffsets[0],
                                    rowIndex + rowOffsets[1]};

        if (inputRows[0] >= m_inputRows || inputRows[1] < 0) {
          return internal::pset1<Packet>(Scalar(0));
        }

        if (inputRows[0] == inputRows[1]) {
          return loadPacketStandardFromSingleColumnSingleRow(
              patchId, planeIndex, rowIndex, colIndex, otherIndex, patchOffsets,
              colOffsets, rowOffsets, inputCols, inputRows);
        }
        if (inputRows[0] + 1 == inputRows[1]) {
          return loadPacketStandardFromSingleColumnTwoRows(
              patchId, planeIndex, rowIndex, colIndex, otherIndex, patchOffsets,
              colOffsets, rowOffsets);
        }
      }
    }

    return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex,
                                  otherIndex);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index planeIndex,
                                            Index rowIndex, Index colIndex,
                                            Index otherIndex) const {
    const Index packetSize = internal::unpacket_traits<Packet>::size;
    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(patchId <
                 patchDepth() * patchPlanes() * patchRows() * patchCols());

    eigen_assert(!nonStandardPatches());
    eigen_assert((patchDepth() % packetSize) == 0);

    // Find the offset of the element wrt the location of the first element.
    const Index patchOffset = patchId / m_fastDimZero;
    eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);

    const Index colOffset = patchOffset / m_fastColStride;
    const Index rowOffset =
        (patchOffset - colOffset * m_colStride) / m_fastRowStride;
    const Index planeOffset =
        patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;

    const Index inputCol = colIndex + colOffset;
    const Index inputRow = rowIndex + rowOffset;
    const Index inputPlane = planeIndex + planeOffset;

    if (inputCol < 0 || inputRow < 0 || inputPlane < 0 ||
        inputCol >= m_inputCols || inputRow >= m_inputRows ||
        inputPlane >= m_inputPlanes) {
      return internal::pset1<Packet>(Scalar(0));
    }

    const Index depth = patchId - patchOffset * patchDepth();
    const Index inputIndex = depth + inputPlane * m_planeInputStride +
                             inputRow * m_rowInputStride +
                             inputCol * m_colInputStride + otherIndex;
    return m_impl.template packet<Unaligned>(inputIndex);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet
  packetWithPossibleZero(Index patchId, Index planeIndex, Index rowIndex,
                         Index colIndex, Index otherIndex) const {
    const int packetSize = internal::unpacket_traits<Packet>::size;
    EIGEN_ALIGN_MAX
    typename internal::remove_const<Scalar>::type values[packetSize];
    for (int i = 0; i < packetSize; ++i) {
      values[i] =
          loadCoeff(patchId + i, planeIndex, rowIndex, colIndex, otherIndex);
    }
    Packet rslt = internal::pload<Packet>(values);
    return rslt;
  }

  // Precompute the indices (plane, row, col, other) of the first element of
  // the given patch index, within the output tensor of the TensorVolumePatchOp.
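  // Illustrative example (assumed sizes): with outputPlanes = 4, outputRows = 3
  // and num_patches = 60, patchIndex = 30 in the first batch gives
  //   colIndex   = 30 / (4 * 3) = 2
  //   rowIndex   = (30 - 2 * 12) / 4 = 1
  //   planeIndex = 30 - (2 * 3 + 1) * 4 = 2
  // before the user strides and paddings are applied at the end.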
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices(
      Index patchIndex, Index& planeIndex, Index& rowIndex, Index& colIndex,
      Index& otherIndex) const {
    const size_t NumInputDims = array_size<
        typename TensorEvaluator<ArgType, Device>::Dimensions>::value;

    // Check if patchIndex might contain batch and other dimensions.
    otherIndex = (NumInputDims == 4) ? 0 : patchIndex / m_fastNumPatches;

    // Compute index of the patch within the batch (and other dimensions).
    const Index patch3DIndex = (NumInputDims == 4)
                                   ? patchIndex
                                   : (patchIndex - otherIndex * m_num_patches);

    otherIndex *= m_patchInputStride;

    colIndex = patch3DIndex / m_fastOutputPlanesRows;
    rowIndex =
        (patch3DIndex - colIndex * m_outputPlanesRows) / m_fastOutputPlanes;
    planeIndex =
        patch3DIndex - (colIndex * m_outputRows + rowIndex) * m_outputPlanes;

    colIndex = colIndex * m_col_strides - m_colPaddingLeft;
    rowIndex = rowIndex * m_row_strides - m_rowPaddingTop;
    planeIndex = planeIndex * m_plane_strides - m_planePaddingTop;
  }

  Index m_patch_depth;   // number of channels in the patch
  Index m_patch_planes;  // number of planes in the patch
  Index m_patch_rows;    // number of rows in the patch
  Index m_patch_cols;    // number of columns in the patch
  Index m_num_patches;   // number of patches to extract

  // Strides for navigating through the single patch.
  Index m_patch_plane_stride;
  Index m_patch_row_stride;
  Index m_patch_col_stride;
  // Strides for the output tensor (depth is not part of the stride).
  Index m_rowStride;
  Index m_colStride;
  Index m_patchStride;
  Index m_otherStride;

  Index m_planeInputStride;  // Plane stride in the input tensor
  Index m_rowInputStride;    // Row stride in the input tensor
  Index m_colInputStride;    // Col stride in the input tensor
  Index m_patchInputStride;  // Patch stride in the input tensor
  Index m_otherInputStride;

  Index m_inputDepth;   // Depth of the input tensor
  Index m_inputPlanes;  // Number of planes in the input tensor
  Index m_inputRows;    // Number of rows in the input tensor
  Index m_inputCols;    // Number of cols in the input tensor

  Index m_outputPlanes;      // Number of output planes
  Index m_outputRows;        // Number of output rows
  Index m_outputCols;        // Number of output cols
  Index m_outputPlanesRows;  // Cached outputPlanes * outputRows.

  Index m_plane_strides;  // User specified plane stride
  Index m_row_strides;    // User specified row stride
  Index m_col_strides;    // User specified col stride

  // User specified plane/row/col atrous convolution strides.
  Index m_in_plane_strides;
  Index m_in_row_strides;
  Index m_in_col_strides;

  // User specified plane/row/col inflation strides in the image patch.
  Index m_patch_plane_inflate_strides;
  Index m_patch_row_inflate_strides;
  Index m_patch_col_inflate_strides;

  Index m_planePaddingTop;  // Plane padding
  Index m_rowPaddingTop;    // Row padding
  Index m_colPaddingLeft;   // Column padding

  // Fast representation of various divisors.
  internal::TensorIntDivisor<Index> m_fastNumPatches;

  internal::TensorIntDivisor<Index> m_fastPatchPlaneStride;
  internal::TensorIntDivisor<Index> m_fastPatchRowStride;
  internal::TensorIntDivisor<Index> m_fastPatchColStride;

  internal::TensorIntDivisor<Index> m_fastInputPlaneStride;
  internal::TensorIntDivisor<Index> m_fastInputRowStride;
  internal::TensorIntDivisor<Index> m_fastInputColStride;

  internal::TensorIntDivisor<Index> m_fastRowStride;
  internal::TensorIntDivisor<Index> m_fastColStride;

  internal::TensorIntDivisor<Index> m_fastDimZero;  // aka output depth
  internal::TensorIntDivisor<Index> m_fastOutputPlanes;
  internal::TensorIntDivisor<Index> m_fastOutputRows;
  internal::TensorIntDivisor<Index> m_fastOutputCols;
  internal::TensorIntDivisor<Index> m_fastOutputPlanesRows;

  const TensorEvaluator<ArgType, Device> m_impl;
};

template <typename NewDimension, Index Planes, Index Rows, Index Cols,
          typename ArgType, typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, int Side, int packet_size,
          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionSubMapper<
    Scalar, Index, Side,
    TensorEvaluator<const TensorReshapingOp<NewDimension,
                                            const TensorVolumePatchOp<
                                                Planes, Rows, Cols, ArgType> >,
                    Device>,
    nocontract_t, contract_t, packet_size, inner_dim_contiguous,
    inner_dim_reordered, Alignment> {
 public:
  typedef typename packet_traits<Scalar>::type Packet;
  typedef typename packet_traits<Scalar>::half HalfPacket;

  typedef TensorContractionInputMapper<
      Scalar, Index, Side,
      TensorEvaluator<const TensorReshapingOp<
                          NewDimension, const TensorVolumePatchOp<
                                            Planes, Rows, Cols, ArgType> >,
                      Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      ParentMapper;
  typedef TensorContractionSubMapper<
      Scalar, Index, Side,
      TensorEvaluator<const TensorReshapingOp<
                          NewDimension, const TensorVolumePatchOp<
                                            Planes, Rows, Cols, ArgType> >,
                      Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      Self;
  typedef Self LinearMapper;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
      const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
      : m_base_mapper(base_mapper),
        m_depth_offset(vert_offset),
        m_col_offset(horiz_offset) {
    m_base_mapper.computeBaseIndices(m_col_offset, m_planeIndex, m_rowIndex,
                                     m_colIndex, m_otherIndex);
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
      const Self& base_mapper, Index vert_offset, Index horiz_offset)
      : m_base_mapper(base_mapper.m_base_mapper),
        m_depth_offset(vert_offset + base_mapper.m_depth_offset),
        m_col_offset(horiz_offset + base_mapper.m_col_offset) {
    m_base_mapper.computeBaseIndices(m_col_offset, m_planeIndex, m_rowIndex,
                                     m_colIndex, m_otherIndex);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
    return m_base_mapper.loadCoeff(i + m_depth_offset, m_planeIndex, m_rowIndex,
                                   m_colIndex, m_otherIndex);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i,
                                                          Index j) const {
    return m_base_mapper(i + m_depth_offset, j + m_col_offset);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
    return m_base_mapper.loadPacket(i + m_depth_offset, m_planeIndex,
                                    m_rowIndex, m_colIndex, m_otherIndex);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i,
                                                          Index j) const {
    return m_base_mapper.template loadPacket<Alignment>(i + m_depth_offset,
                                                        j + m_col_offset);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar
  loadCoeffStandard(Index i) const {
    return m_base_mapper.loadCoeffStandard(
        i + m_depth_offset, m_planeIndex, m_rowIndex, m_colIndex, m_otherIndex);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const {
    return m_base_mapper.loadPacketFast(i + m_depth_offset, m_planeIndex,
                                        m_rowIndex, m_colIndex, m_otherIndex);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet
  loadPacketStandard(Index i) const {
    typedef decltype(m_base_mapper.m_impl) TensorEvaluatorT;
    return m_base_mapper.template loadPacketStandard<Packet, TensorEvaluatorT>(
        i + m_depth_offset, m_planeIndex, m_rowIndex, m_colIndex, m_otherIndex);
  }
  template <typename Packet>
  EIGEN_DEVICE_FUNC bool aligned(Index) const {
    return false;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
    return m_base_mapper.nonStandardPatches();
  }

  // Max(Col|Row|Plane|Depth): compute the upper limit for the column, row,
  // plane and depth index respectively that fits into the peeled_k elements
  // starting at m_depth_offset.
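  //
  // Illustrative example (assumed sizes): with patchDepth() = 4, patchPlanes()
  // = 2 and patchRows() = 3, patchColStride() is 4 * 2 * 3 = 24, so for
  // m_depth_offset = 0 and peeled_k = 60 maxCol() returns
  // min(1 + 60 / 24, patchCols()) = min(3, patchCols()).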

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const {
    const Index max_col =
        fastPatchColStride().divide(m_depth_offset + peeled_k);
    return std::min<Index>(1 + max_col, patchCols());
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k,
                                   const Index col) const {
    const Index max_row = fastPatchRowStride().divide(
        m_depth_offset + peeled_k - col * patchColStride());
    return std::min<Index>(1 + max_row, patchRows());
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index maxPlane(const Index peeled_k, const Index col,
                                     const Index row) const {
    const Index max_plane = fastPatchPlaneStride().divide(
        m_depth_offset + peeled_k - col * patchColStride() -
        row * patchRowStride());
    return std::min<Index>(1 + max_plane, patchPlanes());
  }

  // MaxDepth uses only the remaining number of elements in the peeled_k.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements,
                                     const Index start_depth) const {
    return std::min<Index>(start_depth + num_elements, patchDepth());
  }

  // Every register matters in this code, so sometimes to prevent register
  // spilling, instead of the variable that you would expect to see, we use
  // another one that is guaranteed to have the same value. E.g. the patch depth
  // is always the same as the input depth, and it's also the same as the input
  // plane stride. A bunch of other parameters have similar relations.

1011   typedef internal::TensorIntDivisor<Index> IndexDivisor;
1012 
1013   EIGEN_DEVICE_FUNC
patchDepth()1014   EIGEN_ALWAYS_INLINE Index patchDepth() const {
1015     eigen_assert(m_base_mapper.m_patch_depth ==
1016                      m_base_mapper.m_planeInputStride &&
1017                  "Patch depth must be equal to plane input stride.");
1018     return m_base_mapper.m_planeInputStride;
1019   }
1020 
1021   EIGEN_DEVICE_FUNC
patchPlanes()1022   EIGEN_ALWAYS_INLINE Index patchPlanes() const {
1023     eigen_assert(m_base_mapper.m_patch_planes == m_base_mapper.m_rowStride &&
1024                  "Patch planes must be equal to row stride.");
1025     return m_base_mapper.m_rowStride;
1026   }
1027   EIGEN_DEVICE_FUNC
patchRows()1028   EIGEN_ALWAYS_INLINE Index patchRows() const {
1029     return m_base_mapper.m_patch_rows;
1030   }
1031   EIGEN_DEVICE_FUNC
patchCols()1032   EIGEN_ALWAYS_INLINE Index patchCols() const {
1033     return m_base_mapper.m_patch_cols;
1034   }
1035 
1036   EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchPlaneStride() const {
    eigen_assert(patchDepth() == m_base_mapper.m_patch_plane_stride &&
                 "Patch depth must be equal to patch plane stride.");
    return patchDepth();
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchRowStride() const {
    return m_base_mapper.m_patch_row_stride;
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchColStride() const {
    return m_base_mapper.m_patch_col_stride;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE IndexDivisor fastPatchPlaneStride() const {
    eigen_assert(patchDepth() == m_base_mapper.m_patch_plane_stride &&
                 "Patch depth must be equal to patch plane stride.");
    return m_base_mapper.m_fastDimZero;  // patch_depth
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const {
    return m_base_mapper.m_fastPatchRowStride;
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const {
    return m_base_mapper.m_fastPatchColStride;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
                                             const Index baseIndex) const {
    const Index inputIndex = depth + baseIndex;
    return m_base_mapper.m_impl.template packet<Unaligned>(inputIndex);
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Scalar coeffNoPadding(const Index depth,
                                            const Index baseIndex) const {
    const Index inputIndex = depth + baseIndex;
    return m_base_mapper.m_impl.coeff(inputIndex);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool padPlane(const Index plane) const {
    const Index p = m_planeIndex + plane;
    return p < 0 || p >= m_base_mapper.m_inputPlanes;
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool padRow(const Index row) const {
    const Index r = m_rowIndex + row;
    return r < 0 || r >= m_base_mapper.m_inputRows;
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool padCol(const Index col) const {
    const Index c = m_colIndex + col;
    return c < 0 || c >= m_base_mapper.m_inputCols;
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index baseIndex(const Index plane, const Index row,
                                      const Index col) const {
    const Index p = m_planeIndex + plane;
    const Index r = m_rowIndex + row;
    const Index c = m_colIndex + col;
    return p * m_base_mapper.m_planeInputStride +
           r * m_base_mapper.m_rowInputStride +
           c * m_base_mapper.m_colInputStride + m_otherIndex;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index planeOffset() const {
    const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
    const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
    const Index rowOffset =
        (patchOffset - colOffset * m_base_mapper.m_colStride) /
        m_base_mapper.m_fastRowStride;
    const Index planeOffset = patchOffset -
                              colOffset * m_base_mapper.m_colStride -
                              rowOffset * m_base_mapper.m_rowStride;
    return planeOffset;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index rowOffset() const {
    const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
    const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
    const Index rowOffset =
        (patchOffset - colOffset * m_base_mapper.m_colStride) /
        m_base_mapper.m_fastRowStride;
    return rowOffset;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index colOffset() const {
    const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
    const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
    return colOffset;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index depthOffset() const {
    return m_depth_offset % patchDepth();
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper
  getLinearMapper(Index i, Index j) const {
    return LinearMapper(m_base_mapper, i + m_depth_offset, j + m_col_offset);
  }

 private:
  const ParentMapper m_base_mapper;  // Keeping a copy instead of a reference
                                     // performs better in benchmarks.

  Index m_depth_offset;  // First row in the input matrix
  Index m_col_offset;    // First col in the input matrix

  // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base
  // indices for the first element in a patch specified by col_offset
  // (see computeBaseIndices(...) for details).
  Index m_planeIndex;
  Index m_rowIndex;
  Index m_colIndex;
  Index m_otherIndex;
};
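
// Illustration only (a hedged sketch, not used by the kernels in this file):
// how the linear m_depth_offset decomposes into (depth, plane, row, col)
// coordinates, mirroring depthOffset()/planeOffset()/rowOffset()/colOffset()
// above, but with plain integer division instead of the cached fast divisors.
// The function name and parameters are hypothetical.
inline void DecomposePatchOffsetSketch(Index depth_offset, Index patch_depth,
                                       Index patch_planes, Index patch_rows,
                                       Index* depth, Index* plane, Index* row,
                                       Index* col) {
  // Strides within a single patch, measured in "patch offset" units
  // (i.e. after dividing out the innermost depth dimension).
  const Index row_stride = patch_planes;
  const Index col_stride = patch_planes * patch_rows;
  const Index patch_offset = depth_offset / patch_depth;
  *depth = depth_offset % patch_depth;
  *col = patch_offset / col_stride;
  *row = (patch_offset - *col * col_stride) / row_stride;
  *plane = patch_offset - *col * col_stride - *row * row_stride;
}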

// Arrange a block of the right input matrix (in our case it's always a "virtual
// matrix" constructed from extracted volume patches) in contiguous memory.
//
// Given column major input (A0 beside A1 in memory):
// A0 B0 C0 D0  E0 F0 G0 H0 ... Z0
// A1 B1 C1 D1  E1 F1 G1 H1 ... Z1
// A2 B2 C2 D2  E2 F2 G2 H2 ... Z2
// A3 B3 C3 D3  E3 F3 G3 H3 ... Z3
// A4 B4 C4 D4  E4 F4 G4 H4 ... Z4
// A5 B5 C5 D5  E5 F5 G5 H5 ... Z5
// A6 B6 C6 D6  E6 F6 G6 H6 ... Z6
// A7 B7 C7 D7  E7 F7 G7 H7 ... Z7
// A8 ...
// ...
//
// *) A, B, C, ... - patches extracted from the original input.
// *) A0, A1, A2 ... - values from the same patch at different offsets.
//
// The traversal (packed rhs memory) order (B0 beside A0 in memory):
// A0 B0 C0 D0 A1 B1 C1 D1 ...
// E0 F0 G0 H0 E1 F1 G1 H1 ...
// ...
// Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4)
//
// This traversal order must be the same as in default gemm_pack_rhs defined in
// GeneralBlockPanelKernel.h.
//
// *) nr - number of registers along the 'n' dimension.
//    See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix
//    Multiplication" paper.
//
// A standalone sketch of this traversal on a plain column-major matrix is
// given right after this specialization.
//
// TODO(ezhulenev): Add support for squeezing reads along two innermost
// dimensions (see eigen_spatial_convolutions).
template <typename NewDimension, Index Planes, Index Rows, Index Cols,
          typename ArgType, typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, int packet_size,
          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
          int nr>
struct gemm_pack_rhs<
    Scalar, Index,
    TensorContractionSubMapper<
        Scalar, Index, Rhs,
        TensorEvaluator<const TensorReshapingOp<
                            NewDimension, const TensorVolumePatchOp<
                                              Planes, Rows, Cols, ArgType> >,
                        Device>,
        nocontract_t, contract_t, packet_size, inner_dim_contiguous,
        inner_dim_reordered, Alignment>,
    nr, ColMajor, false, false> {
  typedef TensorContractionSubMapper<
      Scalar, Index, Rhs,
      TensorEvaluator<const TensorReshapingOp<
                          NewDimension, const TensorVolumePatchOp<
                                            Planes, Rows, Cols, ArgType> >,
                      Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      SubMapper;

  typedef SubMapper DataMapper;
  typedef typename packet_traits<Scalar>::type Packet;

  EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);

  EIGEN_DEVICE_FUNC
  EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
                                    Index depth, Index cols, Index stride = 0,
                                    Index offset = 0) const {
    eigen_assert(stride == 0);
    eigen_assert(offset == 0);

    const Index packet_cols4 = (cols / 4) * 4;
    const Index peeled_k = (depth / packet_size) * packet_size;
    const bool non_standard_patches = rhs.nonStandardPatches();

    for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
      const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
      const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
      const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);

      Index k = 0;
      if ((packet_size % 4) == 0 && !non_standard_patches) {
        // FAST PATH:
        // Iterate over patch columns, rows and planes if we know that a single
        // packet does not span across multiple planes, rows or columns.
        if ((rhs.patchDepth() % packet_size) == 0) {
          const Index start_col = rhs.colOffset();
          const Index max_col = rhs.maxCol(peeled_k);

          for (Index c = start_col; c < max_col; ++c) {
            eigen_assert(k <= peeled_k);

            const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
            const Index max_row = rhs.maxRow(peeled_k, c);

            const bool pad_col0 = dm0.padCol(c);
            const bool pad_col1 = dm1.padCol(c);
            const bool pad_col2 = dm2.padCol(c);
            const bool pad_col3 = dm3.padCol(c);

            for (Index r = start_row; r < max_row; ++r) {
              eigen_assert(k <= peeled_k);

              const Index start_plane = ((c == start_col) && (r == start_row))
                                            ? rhs.planeOffset()
                                            : 0;
              const Index max_plane = rhs.maxPlane(peeled_k, c, r);

              const bool pad_row0 = pad_col0 || dm0.padRow(r);
              const bool pad_row1 = pad_col1 || dm1.padRow(r);
              const bool pad_row2 = pad_col2 || dm2.padRow(r);
              const bool pad_row3 = pad_col3 || dm3.padRow(r);

              for (Index p = start_plane; p < max_plane; ++p) {
                eigen_assert(k <= peeled_k);

                const bool pad0 = pad_row0 || dm0.padPlane(p);
                const bool pad1 = pad_row1 || dm1.padPlane(p);
                const bool pad2 = pad_row2 || dm2.padPlane(p);
                const bool pad3 = pad_row3 || dm3.padPlane(p);

                const Index idx0 = dm0.baseIndex(p, r, c);
                const Index idx1 = dm1.baseIndex(p, r, c);
                const Index idx2 = dm2.baseIndex(p, r, c);
                const Index idx3 = dm3.baseIndex(p, r, c);

                const Index start_depth =
                    ((c == start_col) && (r == start_row) && (p == start_plane))
                        ? rhs.depthOffset()
                        : 0;
                const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
                eigen_assert((max_depth - start_depth) % packet_size == 0);

                for (Index d = start_depth; d < max_depth; d += packet_size) {
                  eigen_assert(k < peeled_k);
                  PacketBlock<Packet, 4> kernel;
                  kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
                                          : rhs.packetNoPadding(d, idx0);
                  kernel.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
                                          : rhs.packetNoPadding(d, idx1);
                  kernel.packet[2] = pad2 ? pset1<Packet>(Scalar(0))
                                          : rhs.packetNoPadding(d, idx2);
                  kernel.packet[3] = pad3 ? pset1<Packet>(Scalar(0))
                                          : rhs.packetNoPadding(d, idx3);
                  ptranspose(kernel);
                  pstoreu(block + 0 * packet_size, kernel.packet[0]);
                  pstoreu(block + 1 * packet_size, kernel.packet[1]);
                  pstoreu(block + 2 * packet_size, kernel.packet[2]);
                  pstoreu(block + 3 * packet_size, kernel.packet[3]);
                  block += 4 * packet_size;
                  k += packet_size;
                }
              }
            }
          }

          // The loop above should fill peeled_k elements.
          eigen_assert(peeled_k == k);

        } else {
          // Packet can span multiple planes, rows or columns, so we have to go
          // through the slower "standard" path.
          for (; k < peeled_k; k += packet_size) {
            PacketBlock<Packet, 4> kernel;
            kernel.packet[0] = dm0.loadPacketStandard(k);
            kernel.packet[1] = dm1.loadPacketStandard(k);
            kernel.packet[2] = dm2.loadPacketStandard(k);
            kernel.packet[3] = dm3.loadPacketStandard(k);
            ptranspose(kernel);
            pstoreu(block + 0 * packet_size, kernel.packet[0]);
            pstoreu(block + 1 * packet_size, kernel.packet[1]);
            pstoreu(block + 2 * packet_size, kernel.packet[2]);
            pstoreu(block + 3 * packet_size, kernel.packet[3]);
            block += 4 * packet_size;
          }
        }
      }

      // Copy the remaining coefficients of the column block after the peeled_k.
      if (!non_standard_patches) {
        for (; k < depth; k++) {
          block[0] = dm0.loadCoeffStandard(k);
          block[1] = dm1.loadCoeffStandard(k);
          block[2] = dm2.loadCoeffStandard(k);
          block[3] = dm3.loadCoeffStandard(k);
          block += 4;
        }
      } else {
        for (; k < depth; k++) {
          block[0] = dm0(k);
          block[1] = dm1(k);
          block[2] = dm2(k);
          block[3] = dm3(k);
          block += 4;
        }
      }
    }

    // Copy the remaining columns one at a time (nr==1).
    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2);
      for (Index k = 0; k < depth; k++) {
        *block = dm0(k);
        block += 1;
      }
    }
  }
};
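
// Illustration only (a hedged sketch, not used by Eigen or TensorFlow): the
// same nr == 4 traversal as the specialization above, but applied to a plain
// column-major matrix with leading dimension `ld` instead of the "virtual
// matrix". Padding, packets and the patch geometry are deliberately omitted;
// the function name is hypothetical.
template <typename T>
void PackRhsColumnsOf4Sketch(const T* rhs, Index rows, Index cols, Index ld,
                             T* block) {
  const Index packet_cols4 = (cols / 4) * 4;
  // Interleave one coefficient from each of 4 adjacent columns:
  // A0 B0 C0 D0 A1 B1 C1 D1 ...
  for (Index j = 0; j < packet_cols4; j += 4) {
    for (Index i = 0; i < rows; ++i) {
      block[0] = rhs[i + (j + 0) * ld];
      block[1] = rhs[i + (j + 1) * ld];
      block[2] = rhs[i + (j + 2) * ld];
      block[3] = rhs[i + (j + 3) * ld];
      block += 4;
    }
  }
  // Remaining columns are copied one at a time (nr == 1).
  for (Index j = packet_cols4; j < cols; ++j) {
    for (Index i = 0; i < rows; ++i) {
      *block++ = rhs[i + j * ld];
    }
  }
}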

// Template specialization for packet_size = 2. We must special-case packet
// blocks with nr > packet_size, e.g. PacketBlock<Packet2d, 4>.
//
// TODO(ezhulenev): Add support for squeezing reads along two innermost
// dimensions (see eigen_spatial_convolutions).
template <typename NewDimension, Index Planes, Index Rows, Index Cols,
          typename ArgType, typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, bool inner_dim_contiguous,
          bool inner_dim_reordered, int Alignment, int nr>
struct gemm_pack_rhs<
    Scalar, Index,
    TensorContractionSubMapper<
        Scalar, Index, Rhs,
        TensorEvaluator<const TensorReshapingOp<
                            NewDimension, const TensorVolumePatchOp<
                                              Planes, Rows, Cols, ArgType> >,
                        Device>,
        nocontract_t, contract_t, /*packet_size*/ 2, inner_dim_contiguous,
        inner_dim_reordered, Alignment>,
    nr, ColMajor, false, false> {
  typedef TensorContractionSubMapper<
      Scalar, Index, Rhs,
      TensorEvaluator<const TensorReshapingOp<
                          NewDimension, const TensorVolumePatchOp<
                                            Planes, Rows, Cols, ArgType> >,
                      Device>,
      nocontract_t, contract_t, /*packet_size*/ 2, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      SubMapper;
  typedef SubMapper DataMapper;
  typedef typename packet_traits<Scalar>::type Packet;

  EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);

  EIGEN_DEVICE_FUNC
  EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
                                    Index depth, Index cols, Index stride = 0,
                                    Index offset = 0) const {
    eigen_assert(stride == 0);
    eigen_assert(offset == 0);

    const int packet_size = 2;

    const Index packet_cols4 = (cols / 4) * 4;
    const Index peeled_k = (depth / packet_size) * packet_size;
    const bool non_standard_patches = rhs.nonStandardPatches();

    for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
      const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
      const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
      const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);

      Index k = 0;
      if (!non_standard_patches) {
        // FAST PATH:
        // Iterate over patch columns, rows and planes if we know that a single
        // packet does not span across multiple planes, rows or columns.
        if ((rhs.patchDepth() % packet_size) == 0) {
          const Index start_col = rhs.colOffset();
          const Index max_col = rhs.maxCol(peeled_k);

          for (Index c = start_col; c < max_col; ++c) {
            eigen_assert(k <= peeled_k);

            const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
            const Index max_row = rhs.maxRow(peeled_k, c);

            const bool pad_col0 = dm0.padCol(c);
            const bool pad_col1 = dm1.padCol(c);
            const bool pad_col2 = dm2.padCol(c);
            const bool pad_col3 = dm3.padCol(c);

            for (Index r = start_row; r < max_row; ++r) {
              eigen_assert(k <= peeled_k);

              const Index start_plane = ((c == start_col) && (r == start_row))
                                            ? rhs.planeOffset()
                                            : 0;
              const Index max_plane = rhs.maxPlane(peeled_k, c, r);

              const bool pad_row0 = dm0.padRow(r);
              const bool pad_row1 = dm1.padRow(r);
              const bool pad_row2 = dm2.padRow(r);
              const bool pad_row3 = dm3.padRow(r);

              for (Index p = start_plane; p < max_plane; ++p) {
                eigen_assert(k <= peeled_k);

                const bool pad0 = pad_col0 || pad_row0 || dm0.padPlane(p);
                const bool pad1 = pad_col1 || pad_row1 || dm1.padPlane(p);
                const bool pad2 = pad_col2 || pad_row2 || dm2.padPlane(p);
                const bool pad3 = pad_col3 || pad_row3 || dm3.padPlane(p);

                const Index idx0 = dm0.baseIndex(p, r, c);
                const Index idx1 = dm1.baseIndex(p, r, c);
                const Index idx2 = dm2.baseIndex(p, r, c);
                const Index idx3 = dm3.baseIndex(p, r, c);

                const Index start_depth =
                    ((c == start_col) && (r == start_row) && (p == start_plane))
                        ? rhs.depthOffset()
                        : 0;
                const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
                eigen_assert((max_depth - start_depth) % packet_size == 0);

                for (Index d = start_depth; d < max_depth; d += packet_size) {
                  eigen_assert(k < peeled_k);
                  PacketBlock<Packet, 2> kernel0;
                  PacketBlock<Packet, 2> kernel1;
                  kernel0.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
                                           : rhs.packetNoPadding(d, idx0);
                  kernel0.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
                                           : rhs.packetNoPadding(d, idx1);
                  kernel1.packet[0] = pad2 ? pset1<Packet>(Scalar(0))
                                           : rhs.packetNoPadding(d, idx2);
                  kernel1.packet[1] = pad3 ? pset1<Packet>(Scalar(0))
                                           : rhs.packetNoPadding(d, idx3);
                  ptranspose(kernel0);
                  ptranspose(kernel1);
                  pstoreu(block + 0 * packet_size, kernel0.packet[0]);
                  pstoreu(block + 1 * packet_size, kernel1.packet[0]);
                  pstoreu(block + 2 * packet_size, kernel0.packet[1]);
                  pstoreu(block + 3 * packet_size, kernel1.packet[1]);
                  block += 4 * packet_size;
                  k += packet_size;
                }
              }
            }
          }

          // The loop above should fill peeled_k elements.
          eigen_assert(peeled_k == k);

        } else {
          for (; k < peeled_k; k += packet_size) {
            PacketBlock<Packet, 2> kernel0;
            PacketBlock<Packet, 2> kernel1;
            kernel0.packet[0] = dm0.loadPacketStandard(k);
            kernel0.packet[1] = dm1.loadPacketStandard(k);
            kernel1.packet[0] = dm2.loadPacketStandard(k);
            kernel1.packet[1] = dm3.loadPacketStandard(k);
            ptranspose(kernel0);
            ptranspose(kernel1);
            pstoreu(block + 0 * packet_size, kernel0.packet[0]);
            pstoreu(block + 1 * packet_size, kernel1.packet[0]);
            pstoreu(block + 2 * packet_size, kernel0.packet[1]);
            pstoreu(block + 3 * packet_size, kernel1.packet[1]);
            block += 4 * packet_size;
          }
        }
      }

      // Copy the remaining coefficients of the column block after the peeled_k.
      if (!rhs.nonStandardPatches()) {
        for (; k < depth; k++) {
          block[0] = dm0.loadCoeffStandard(k);
          block[1] = dm1.loadCoeffStandard(k);
          block[2] = dm2.loadCoeffStandard(k);
          block[3] = dm3.loadCoeffStandard(k);
          block += 4;
        }
      } else {
        for (; k < depth; k++) {
          block[0] = dm0(k);
          block[1] = dm1(k);
          block[2] = dm2(k);
          block[3] = dm3(k);
          block += 4;
        }
      }
    }

    // Copy the remaining columns one at a time (nr==1).
    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2);
      for (Index k = 0; k < depth; k++) {
        *block = dm0(k);
        block += 1;
      }
    }
  }
};

// Special case for non-vectorized types such as float16 (packet_size = 1).
template <typename NewDimension, Index Planes, Index Rows, Index Cols,
          typename ArgType, typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, bool inner_dim_contiguous,
          bool inner_dim_reordered, int Alignment, int nr>
struct gemm_pack_rhs<
    Scalar, Index,
    TensorContractionSubMapper<
        Scalar, Index, Rhs,
        TensorEvaluator<const TensorReshapingOp<
                            NewDimension, const TensorVolumePatchOp<
                                              Planes, Rows, Cols, ArgType> >,
                        Device>,
        nocontract_t, contract_t, /*packet_size*/ 1, inner_dim_contiguous,
        inner_dim_reordered, Alignment>,
    nr, ColMajor, false, false> {
  typedef TensorContractionSubMapper<
      Scalar, Index, Rhs,
      TensorEvaluator<const TensorReshapingOp<
                          NewDimension, const TensorVolumePatchOp<
                                            Planes, Rows, Cols, ArgType> >,
                      Device>,
      nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered,
      Alignment>
      SubMapper;
  typedef SubMapper DataMapper;

  EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);

  EIGEN_DEVICE_FUNC
  EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
                                    Index depth, Index cols, Index stride = 0,
                                    Index offset = 0) const {
    eigen_assert(stride == 0);
    eigen_assert(offset == 0);

    const Index packet_cols4 = (cols / 4) * 4;

    for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
      const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
      const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
      const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);

      if (!rhs.nonStandardPatches()) {
        for (Index k = 0; k < depth; k++) {
          block[0] = dm0.loadCoeffStandard(k);
          block[1] = dm1.loadCoeffStandard(k);
          block[2] = dm2.loadCoeffStandard(k);
          block[3] = dm3.loadCoeffStandard(k);
          block += 4;
        }
      } else {
        for (Index k = 0; k < depth; k++) {
          block[0] = dm0(k);
          block[1] = dm1(k);
          block[2] = dm2(k);
          block[3] = dm3(k);
          block += 4;
        }
      }
    }

    // Copy the remaining columns one at a time (nr==1).
    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2);
      for (Index k = 0; k < depth; k++) {
        *block = dm0(k);
        block += 1;
      }
    }
  }
};

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
// Pack a block of the right input matrix (in our case it's always a "virtual
// matrix" constructed from extracted image patches) into a contiguous block in
// column-major storage order. Knowing the properties of the original patch op,
// we can do this more efficiently than the default gemm_pack_colmajor_block.
//
// TODO(ezhulenev): gemm_pack_colmajor_block for spatial convolutions supports
// squeezing reads along the 2 innermost dimensions, add it here if needed.
template <typename NewDimension, Index Planes, Index Rows, Index Cols,
          typename ArgType, typename Device, typename Scalar,
          typename StorageIndex, typename nocontract_t, typename contract_t,
          int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered,
          int Alignment>
struct gemm_pack_colmajor_block<
    Scalar, StorageIndex,
    TensorContractionSubMapper<
        Scalar, StorageIndex, Rhs,
        TensorEvaluator<const TensorReshapingOp<
                            NewDimension, const TensorVolumePatchOp<
                                              Planes, Rows, Cols, ArgType> >,
                        Device>,
        nocontract_t, contract_t, packet_size, inner_dim_contiguous,
        inner_dim_reordered, Alignment>,
    ColMajor> {
  typedef TensorContractionSubMapper<
      Scalar, StorageIndex, Rhs,
      TensorEvaluator<const TensorReshapingOp<
                          NewDimension, const TensorVolumePatchOp<
                                            Planes, Rows, Cols, ArgType> >,
                      Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      SubMapper;

  typedef SubMapper DataMapper;
  typedef typename packet_traits<Scalar>::type Packet;

  EIGEN_DONT_INLINE
  void operator()(Scalar* block, const DataMapper& rhs, StorageIndex rows,
                  StorageIndex cols) {
    const bool standard_patches = !rhs.nonStandardPatches();

    if (standard_patches && rhs.patchDepth() % packet_size == 0) {
      packStandardPatches<true>(block, rhs, rows, cols);

    } else if (standard_patches) {
      packStandardPatches<false>(block, rhs, rows, cols);

    } else {
      // With non-standard patches we don't do any vectorized loads.
      // TODO(ezhulenev): It doesn't look like we should completely give up
      // on packets. Make this code path faster!
      for (StorageIndex col = 0; col < cols; ++col) {
        SubMapper lm = rhs.getLinearMapper(0, col);
        for (StorageIndex i = 0; i < rows; ++i) {
          *block = lm(i);
          ++block;
        }
      }
    }
  }

 private:
  // Pack standard volume patches:
  //
  // - patch_depth_is_multiple_of_packet_size=true: the depth dimension size is
  //   guaranteed to be a multiple of the packet size, so we can skip all
  //   non-vectorized loads and checks.
  //
  template <bool patch_depth_is_multiple_of_packet_size>
  EIGEN_ALWAYS_INLINE void packStandardPatches(Scalar* block,
                                               const DataMapper& rhs,
                                               StorageIndex rows,
                                               StorageIndex cols) {
    eigen_assert(!rhs.nonStandardPatches());

    // Use the same name (peeled_k) for the number of vectorized rows as in
    // the gemm_pack_rhs specializations above.
    const Index peeled_k = (rows / packet_size) * packet_size;

    const Index start_col = rhs.colOffset();
    const Index max_col = rhs.maxCol(peeled_k);

    for (StorageIndex col = 0; col < cols; ++col) {
      SubMapper lm = rhs.getLinearMapper(0, col);

      Index k = 0;
      for (Index c = start_col; c < max_col; ++c) {
        eigen_assert(k <= peeled_k);

        const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
        const Index max_row = rhs.maxRow(peeled_k, c);
        const bool pad_col = lm.padCol(c);

        for (Index r = start_row; r < max_row; ++r) {
          eigen_assert(k <= peeled_k);

          const Index start_plane =
              ((c == start_col) && (r == start_row)) ? rhs.planeOffset() : 0;
          const Index max_plane = rhs.maxPlane(peeled_k, c, r);
          const bool pad_row = pad_col || lm.padRow(r);

          for (Index p = start_plane; p < max_plane; ++p) {
            eigen_assert(k <= peeled_k);

            const Index start_depth =
                ((c == start_col) && (r == start_row) && (p == start_plane))
                    ? rhs.depthOffset()
                    : 0;
            const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);

            const bool pad = pad_col || pad_row || lm.padPlane(p);
            const Index base_idx = lm.baseIndex(p, r, c);

            if (patch_depth_is_multiple_of_packet_size)
              eigen_assert((max_depth - start_depth) % packet_size == 0);

            // If patch depth is a multiple of packet size, it's guaranteed that
            // we can process all values in depth dimension with packets.
            const Index max_vectorized_depth =
                patch_depth_is_multiple_of_packet_size
                    ? max_depth
                    : max_depth - packet_size;

            Index d = start_depth;

            // 1. Process depth dimension with vectorized instructions.
            for (; d < max_vectorized_depth; d += packet_size) {
              eigen_assert(k < peeled_k);
              const Packet packet = pad ? pset1<Packet>(Scalar(0))
                                        : rhs.packetNoPadding(d, base_idx);
              internal::pstoreu(block, packet);
              block += packet_size;
              k += packet_size;
            }

            // 2. Finish with coefficients.
            if (!patch_depth_is_multiple_of_packet_size) {
              for (; d < max_depth; d++) {
                eigen_assert(k < peeled_k);
                *block = pad ? Scalar(0) : rhs.coeffNoPadding(d, base_idx);
                ++block;
                ++k;
              }
            }
          }
        }
      }

      // The loop above should fill peeled_k elements.
      eigen_assert(peeled_k == k);

      // Fill remaining elements using loadCoeffStandard.
      for (; k < rows; ++k) {
        *block = lm.loadCoeffStandard(k);
        ++block;
      }
    }
  }
};
#endif  // defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)

}  // namespace internal

/** CuboidConvolution
 * \ingroup CXX11_NeuralNetworks_Module
 *
 * \brief Applies a 3D convolution over a multichannel input voxel block.
 *
 * The input parameter is expected to be a tensor with a rank of 4 or more
 * (channels, depth, height, width, and optionally others).
 * The kernel parameter is expected to be a 5D tensor (filters, channels,
 * kernel_depth, kernel_height, kernel_width).
 * The result can be assigned to a tensor of rank equal to the rank of the
 * input. The dimensions of the result will be filters, depth, height, width
 * (and others if applicable).
 *
 * The input and kernel have to be in the same layout, and both row-major and
 * col-major are supported. The shapes given above are for col-major layout.
 * For row-major, all dimensions should be reversed.
 *
 * It is possible to swap the order of the depth, width, and height dimensions
 * provided that the same order is used in the input, the kernel, and the
 * output.
 */
template <typename Input, typename Kernel>
EIGEN_ALWAYS_INLINE static const typename internal::conditional<
    internal::traits<Input>::Layout == ColMajor,
    TensorReshapingOp<
        const DSizes<typename internal::traits<Input>::Index,
                     internal::traits<Input>::NumDimensions>,
        const TensorContractionOp<
            const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<Input>::Index, 2>,
                const Kernel>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<Input>::Index, 2>,
                const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
                                          const Input> > > >,
    TensorReshapingOp<
        const DSizes<typename internal::traits<Input>::Index,
                     internal::traits<Input>::NumDimensions>,
        const TensorContractionOp<
            const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<Input>::Index, 2>,
                const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
                                          const Input> >,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<Input>::Index, 2>,
                const Kernel> > > >::type
CuboidConvolution(const Input& input, const Kernel& kernel,
                  const Index stridePlanes = 1, const Index strideRows = 1,
                  const Index strideCols = 1,
                  const PaddingType padding_type = PADDING_SAME) {
  typedef typename internal::traits<Input>::Index TensorIndex;
  TensorRef<Tensor<typename internal::traits<Input>::Scalar,
                   internal::traits<Input>::NumDimensions,
                   internal::traits<Input>::Layout, TensorIndex> >
      in(input);
  TensorRef<Tensor<typename internal::traits<Kernel>::Scalar,
                   internal::traits<Kernel>::NumDimensions,
                   internal::traits<Kernel>::Layout, TensorIndex> >
      kern(kernel);

  EIGEN_STATIC_ASSERT(
      internal::traits<Input>::Layout == internal::traits<Kernel>::Layout,
      YOU_MADE_A_PROGRAMMING_MISTAKE);
  static const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);
  static const int NumDims = internal::traits<Input>::NumDimensions;

  // Number of filters to apply. This is the same as the output depth of the
  // result.
  const TensorIndex kernelFilters =
      isColMajor ? kern.dimensions()[0] : kern.dimensions()[4];
  const TensorIndex kernelChannels =
      isColMajor ? kern.dimensions()[1] : kern.dimensions()[3];

  // Spatial size of the kernel.
  const TensorIndex kernelPlanes =
      isColMajor ? kern.dimensions()[2] : kern.dimensions()[2];
  const TensorIndex kernelRows =
      isColMajor ? kern.dimensions()[3] : kern.dimensions()[1];
  const TensorIndex kernelCols =
      isColMajor ? kern.dimensions()[4] : kern.dimensions()[0];

  if (isColMajor) {
    eigen_assert(kernelChannels == in.dimension(0));
  } else {
    eigen_assert(kernelChannels == in.dimension(NumDims - 1));
  }

  const TensorIndex inputPlanes =
      isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
  const TensorIndex inputRows =
      isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);
  const TensorIndex inputCols =
      isColMajor ? in.dimension(3) : in.dimension(NumDims - 4);

  TensorIndex out_planes;
  TensorIndex out_height;
  TensorIndex out_width;
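  // Worked example (illustrative values only): with inputRows = 16,
  // kernelRows = 3 and strideRows = 1, PADDING_VALID yields
  // divup(16 - 3 + 1, 1) = 14 output rows, while PADDING_SAME yields
  // divup(16, 1) = 16 output rows.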
  switch (padding_type) {
    case PADDING_VALID:
      out_planes = Eigen::divup(inputPlanes - kernelPlanes + 1,
                                static_cast<TensorIndex>(stridePlanes));
      out_height = Eigen::divup(inputRows - kernelRows + 1,
                                static_cast<TensorIndex>(strideRows));
      out_width = Eigen::divup(inputCols - kernelCols + 1,
                               static_cast<TensorIndex>(strideCols));
      break;
    case PADDING_SAME:
      out_planes =
          Eigen::divup(inputPlanes, static_cast<TensorIndex>(stridePlanes));
      out_height =
          Eigen::divup(inputRows, static_cast<TensorIndex>(strideRows));
      out_width = Eigen::divup(inputCols, static_cast<TensorIndex>(strideCols));
      break;
    default:
      out_planes = 0;
      out_height = 0;
      out_width = 0;
      eigen_assert(false && "unexpected padding");
  }

  DSizes<TensorIndex, 2> kernel_dims;
  if (isColMajor) {
    kernel_dims[0] = kernelFilters;
    kernel_dims[1] = kernelChannels * kernelPlanes * kernelRows * kernelCols;
  } else {
    kernel_dims[0] = kernelChannels * kernelPlanes * kernelRows * kernelCols;
    kernel_dims[1] = kernelFilters;
  }

  // Molds the output of the patch extraction result into a 2D tensor:
  // - the first dimension (dims[0]): the patch values to be multiplied with the
  //   kernels
  // - the second dimension (dims[1]): everything else
  DSizes<TensorIndex, 2> pre_contract_dims;
  if (isColMajor) {
    pre_contract_dims[0] =
        kernelChannels * kernelPlanes * kernelRows * kernelCols;
    pre_contract_dims[1] = out_planes * out_height * out_width;
    for (int i = 4; i < NumDims; ++i) {
      pre_contract_dims[1] *= in.dimension(i);
    }
  } else {
    pre_contract_dims[1] =
        kernelChannels * kernelPlanes * kernelRows * kernelCols;
    pre_contract_dims[0] = out_planes * out_height * out_width;
    for (int i = 0; i < NumDims - 4; ++i) {
      pre_contract_dims[0] *= in.dimension(i);
    }
  }

  array<IndexPair<TensorIndex>, 1> contract_dims;
  contract_dims[0] = IndexPair<TensorIndex>(1, 0);

  // Molds the output of the contraction into the shape expected by the user
  // (assuming ColMajor):
  // - 1st dim: kernel filters
  // - 2nd dim: output depth
  // - 3rd dim: output height
  // - 4th dim: output width
  // - 5th dim and beyond: everything else including batch size
  DSizes<TensorIndex, NumDims> post_contract_dims;
  if (isColMajor) {
    post_contract_dims[0] = kernelFilters;
    post_contract_dims[1] = out_planes;
    post_contract_dims[2] = out_height;
    post_contract_dims[3] = out_width;
    for (int i = 4; i < NumDims; ++i) {
      post_contract_dims[i] = in.dimension(i);
    }
  } else {
    post_contract_dims[NumDims - 1] = kernelFilters;
    post_contract_dims[NumDims - 2] = out_planes;
    post_contract_dims[NumDims - 3] = out_height;
    post_contract_dims[NumDims - 4] = out_width;
    for (int i = 0; i < NumDims - 4; ++i) {
      post_contract_dims[i] = in.dimension(i);
    }
  }

  return choose(
      Cond<internal::traits<Input>::Layout == ColMajor>(),
      kernel.reshape(kernel_dims)
          .contract(input
                        .extract_volume_patches(
                            kernelPlanes, kernelRows, kernelCols, stridePlanes,
                            strideRows, strideCols, padding_type)
                        .reshape(pre_contract_dims),
                    contract_dims)
          .reshape(post_contract_dims),
      input
          .extract_volume_patches(kernelPlanes, kernelRows, kernelCols,
                                  stridePlanes, strideRows, strideCols,
                                  padding_type)
          .reshape(pre_contract_dims)
          .contract(kernel.reshape(kernel_dims), contract_dims)
          .reshape(post_contract_dims));
}
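
// Illustration only (a hedged sketch, not part of the public TensorFlow API):
// calling CuboidConvolution on col-major tensors. The function name and the
// dimension values are hypothetical, and the template is never instantiated
// unless this function is actually called.
template <typename Scalar>
Tensor<Scalar, 5> CuboidConvolutionExample() {
  // Input: (channels, planes, rows, cols, batch) in col-major layout.
  Tensor<Scalar, 5> input(4, 8, 16, 16, 2);
  // Kernel: (filters, channels, kernel_planes, kernel_rows, kernel_cols).
  Tensor<Scalar, 5> kernel(8, 4, 3, 3, 3);
  input.setRandom();
  kernel.setRandom();
  // With PADDING_SAME and unit strides the spatial output dimensions match the
  // input, so the result is (filters, planes, rows, cols, batch) =
  // (8, 8, 16, 16, 2). Assigning the returned expression to a Tensor
  // evaluates it on the default device.
  Tensor<Scalar, 5> output =
      CuboidConvolution(input, kernel, /*stridePlanes=*/1, /*strideRows=*/1,
                        /*strideCols=*/1, PADDING_SAME);
  return output;
}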

}  // end namespace Eigen

#endif  // TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_