1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_
17 #define TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_
18 
19 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
20 #include "tensorflow/core/kernels/eigen_volume_patch.h"
21 
22 #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
23 #include "tensorflow/core/kernels/eigen_contraction_kernel.h"
24 #endif
25 
26 namespace Eigen {
27 
28 namespace internal {
29 
30 // WARNING: Most of the code here implicitly assumes that the matrix is in
31 // ColMajor layout. This is guaranteed by the tensor contraction (see
32 // TensorContraction.h).
33 //
34 // Inside Eigen a tensor contraction is represented by a matrix multiplication.
35 // We don't want to actually extract volume patches and reshape the result into
36 // a matrix (this involves allocating huge extra memory), so the patch
37 // extraction and reshape operations are implicit.
38 //
39 // TensorContractionInputMapper takes a matrix index and returns the coefficient
40 // (or the packet) of the "virtual tensor" that would be at that index if we
41 // were to actually reshape the result of patch extraction.
42 //
43 // TensorContractionSubMapper provides a similar view into the "virtual matrix"
44 // at the given vertical and horizontal offsets.
45 //
46 // "Virtual matrix" dimensions:
47 //   *0: kernelChannels * kernelPlanes * kernelRows * kernelCols
48 //    1: out_planes * out_height * out_width * OTHERS (e.g. batches, etc...)
49 //
50 // *) extracted patches are contiguous in memory (innermost dimension assuming
51 //    col major layout)
52 //
53 // With these dimensions:
54 //   row - offset within a single patch (in code: patchId)
55 //   col - index of the extracted patch (in code: patchIndex)
56 //         patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions)
57 //
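//
// For illustration only (sizes below are arbitrary, not computed by this file):
// take kernelChannels = 2, kernelPlanes = 3, kernelRows = 4, kernelCols = 5.
// Then m_rowStride = 3, m_colStride = 4 * 3 = 12, and a virtual-matrix row
// index patchId = 50 decomposes as
//   patchOffset = 50 / 2  = 25   (position within the patch, ignoring depth)
//   colOffset   = 25 / 12 = 2
//   rowOffset   = (25 - 2 * 12) / 3 = 0
//   planeOffset = 25 - 2 * 12 - 0 * 3 = 1
//   depth       = 50 - 25 * 2 = 0
// which is exactly the arithmetic performed by loadCoeff(...) below.
//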
58 template <typename NewDimension, Index Planes, Index Rows, Index Cols,
59           typename ArgType, typename Device, typename Scalar_, typename Index,
60           typename nocontract_t, typename contract_t, int Side, int packet_size,
61           bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
62 class TensorContractionInputMapper<
63     Scalar_, Index, Side,
64     TensorEvaluator<const TensorReshapingOp<NewDimension,
65                                             const TensorVolumePatchOp<
66                                                 Planes, Rows, Cols, ArgType> >,
67                     Device>,
68     nocontract_t, contract_t, packet_size, inner_dim_contiguous,
69     inner_dim_reordered, Alignment> {
70  public:
71   typedef Scalar_ Scalar;
72   typedef TensorContractionInputMapper<
73       Scalar, Index, Side,
74       TensorEvaluator<const TensorReshapingOp<
75                           NewDimension, const TensorVolumePatchOp<
76                                             Planes, Rows, Cols, ArgType> >,
77                       Device>,
78       nocontract_t, contract_t, packet_size, inner_dim_contiguous,
79       inner_dim_reordered, Alignment>
80       Self;
81   typedef TensorContractionSubMapper<
82       Scalar, Index, Side,
83       TensorEvaluator<const TensorReshapingOp<
84                           NewDimension, const TensorVolumePatchOp<
85                                             Planes, Rows, Cols, ArgType> >,
86                       Device>,
87       nocontract_t, contract_t, packet_size, inner_dim_contiguous,
88       inner_dim_reordered, Alignment>
89       SubMapper;
90   typedef SubMapper VectorMapper;
91   typedef SubMapper LinearMapper;
92   typedef typename packet_traits<Scalar>::type Packet;
93 
94   EIGEN_DEVICE_FUNC
95   TensorContractionInputMapper(
96       const TensorEvaluator<
97           const TensorReshapingOp<
98               NewDimension,
99               const TensorVolumePatchOp<Planes, Rows, Cols, ArgType> >,
100           Device>& tensor,
101       const nocontract_t&, const nocontract_t&, const contract_t&,
102       const contract_t&)
103       : m_impl(tensor.impl().impl()) {
104     if (internal::traits<ArgType>::Layout == ColMajor) {
105       m_patch_depth = tensor.impl().dimensions()[0];
106       m_patch_planes = tensor.impl().dimensions()[1];
107       m_patch_rows = tensor.impl().dimensions()[2];
108       m_patch_cols = tensor.impl().dimensions()[3];
109       m_num_patches = tensor.impl().dimensions()[4];
110     } else {
111       const int NumDims = tensor.impl().dimensions().size();
112       m_patch_depth = tensor.impl().dimensions()[NumDims - 1];
113       m_patch_planes = tensor.impl().dimensions()[NumDims - 2];
114       m_patch_rows = tensor.impl().dimensions()[NumDims - 3];
115       m_patch_cols = tensor.impl().dimensions()[NumDims - 4];
116       m_num_patches = tensor.impl().dimensions()[NumDims - 5];
117     }
118 
119     // Strides for navigating through the single patch.
120     m_patch_plane_stride = m_patch_depth;
121     m_patch_row_stride = m_patch_planes * m_patch_plane_stride;
122     m_patch_col_stride = m_patch_rows * m_patch_row_stride;
123 
124     // Strides for the output tensor.
125     // IMPORTANT: These strides are used to locate an element in a patch at
126     // depth zero (channel), which is not quite the same as a "traditional"
127     // stride.
128     m_rowStride = m_patch_planes;
129     m_colStride = m_patch_rows * m_rowStride;
130     m_patchStride = m_colStride * m_patch_cols * m_patch_depth;
131     m_otherStride = m_patchStride * m_num_patches;
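    // For illustration only (arbitrary sizes): with m_patch_depth = 2,
    // m_patch_planes = 3, m_patch_rows = 4 and m_patch_cols = 5 this gives
    // m_rowStride = 3, m_colStride = 12 and m_patchStride = 12 * 5 * 2 = 120,
    // i.e. one value per coefficient of the 2 * 3 * 4 * 5 patch.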
132 
133     m_outputPlanes = tensor.impl().outputPlanes();
134     m_outputRows = tensor.impl().outputRows();
135     m_outputCols = tensor.impl().outputCols();
136 
137     m_outputPlanesRows = m_outputPlanes * m_outputRows;
138 
139     m_plane_strides = tensor.impl().userPlaneStride();
140     m_row_strides = tensor.impl().userRowStride();
141     m_col_strides = tensor.impl().userColStride();
142 
143     m_in_plane_strides = tensor.impl().userInPlaneStride();
144     m_in_row_strides = tensor.impl().userInRowStride();
145     m_in_col_strides = tensor.impl().userInColStride();
146 
147     m_patch_plane_inflate_strides = tensor.impl().planeInflateStride();
148     m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
149     m_patch_col_inflate_strides = tensor.impl().colInflateStride();
150 
151     if (internal::traits<ArgType>::Layout == ColMajor) {
152       m_inputDepth = tensor.impl().impl().dimensions()[0];
153       m_inputPlanes = tensor.impl().impl().dimensions()[1];
154       m_inputRows = tensor.impl().impl().dimensions()[2];
155       m_inputCols = tensor.impl().impl().dimensions()[3];
156     } else {
157       const int NumDims = tensor.impl().impl().dimensions().size();
158       m_inputDepth = tensor.impl().impl().dimensions()[NumDims - 1];
159       m_inputPlanes = tensor.impl().impl().dimensions()[NumDims - 2];
160       m_inputRows = tensor.impl().impl().dimensions()[NumDims - 3];
161       m_inputCols = tensor.impl().impl().dimensions()[NumDims - 4];
162     }
163 
164     // Strides for navigating through the input tensor.
165     m_planeInputStride = m_inputDepth;
166     m_rowInputStride = m_inputDepth * m_inputPlanes;
167     m_colInputStride = m_inputDepth * m_inputRows * m_inputPlanes;
168     m_patchInputStride =
169         m_inputDepth * m_inputRows * m_inputCols * m_inputPlanes;
170 
171     m_planePaddingTop = tensor.impl().planePaddingTop();
172     m_rowPaddingTop = tensor.impl().rowPaddingTop();
173     m_colPaddingLeft = tensor.impl().colPaddingLeft();
174 
175     m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches);
176 
177     m_fastPatchPlaneStride =
178         internal::TensorIntDivisor<Index>(m_patch_plane_stride);
179     m_fastPatchRowStride =
180         internal::TensorIntDivisor<Index>(m_patch_row_stride);
181     m_fastPatchColStride =
182         internal::TensorIntDivisor<Index>(m_patch_col_stride);
183 
184     m_fastInputPlaneStride =
185         internal::TensorIntDivisor<Index>(m_patch_plane_inflate_strides);
186     m_fastInputRowStride =
187         internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
188     m_fastInputColStride =
189         internal::TensorIntDivisor<Index>(m_patch_col_inflate_strides);
190 
191     m_fastRowStride = internal::TensorIntDivisor<Index>(m_rowStride);
192     m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);
193 
194     m_fastDimZero = internal::TensorIntDivisor<Index>(m_patch_depth);
195     m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
196     m_fastOutputPlanes = internal::TensorIntDivisor<Index>(m_outputPlanes);
197     m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
198     m_fastOutputCols = internal::TensorIntDivisor<Index>(m_outputCols);
199 
200     m_fastOutputPlanesRows =
201         internal::TensorIntDivisor<Index>(m_outputPlanesRows);
202   }
203 
204   EIGEN_DEVICE_FUNC
205   TensorContractionInputMapper(const TensorContractionInputMapper& base_mapper)
206       : m_impl(base_mapper.m_impl) {
207     m_patch_depth = base_mapper.m_patch_depth;
208     m_patch_planes = base_mapper.m_patch_planes;
209     m_patch_rows = base_mapper.m_patch_rows;
210     m_patch_cols = base_mapper.m_patch_cols;
211     m_num_patches = base_mapper.m_num_patches;
212 
213     m_patch_plane_stride = base_mapper.m_patch_plane_stride;
214     m_patch_row_stride = base_mapper.m_patch_row_stride;
215     m_patch_col_stride = base_mapper.m_patch_col_stride;
216 
217     m_rowStride = base_mapper.m_rowStride;
218     m_colStride = base_mapper.m_colStride;
219     m_patchStride = base_mapper.m_patchStride;
220     m_otherStride = base_mapper.m_otherStride;
221 
222     m_planeInputStride = base_mapper.m_planeInputStride;
223     m_rowInputStride = base_mapper.m_rowInputStride;
224     m_colInputStride = base_mapper.m_colInputStride;
225     m_patchInputStride = base_mapper.m_patchInputStride;
226     m_otherInputStride = base_mapper.m_otherInputStride;
227 
228     m_inputDepth = base_mapper.m_inputDepth;
229     m_inputPlanes = base_mapper.m_inputPlanes;
230     m_inputRows = base_mapper.m_inputRows;
231     m_inputCols = base_mapper.m_inputCols;
232 
233     m_outputPlanes = base_mapper.m_outputPlanes;
234     m_outputRows = base_mapper.m_outputRows;
235     m_outputCols = base_mapper.m_outputCols;
236 
237     m_plane_strides = base_mapper.m_plane_strides;
238     m_row_strides = base_mapper.m_row_strides;
239     m_col_strides = base_mapper.m_col_strides;
240 
241     m_in_plane_strides = base_mapper.m_in_plane_strides;
242     m_in_row_strides = base_mapper.m_in_row_strides;
243     m_in_col_strides = base_mapper.m_in_col_strides;
244 
245     m_patch_plane_inflate_strides = base_mapper.m_patch_plane_inflate_strides;
246     m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
247     m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;
248 
249     m_planePaddingTop = base_mapper.m_planePaddingTop;
250     m_rowPaddingTop = base_mapper.m_rowPaddingTop;
251     m_colPaddingLeft = base_mapper.m_colPaddingLeft;
252 
253     m_outputPlanesRows = base_mapper.m_outputPlanesRows;
254 
255     m_fastNumPatches = base_mapper.m_fastNumPatches;
256     m_fastPatchPlaneStride = base_mapper.m_fastPatchPlaneStride;
257     m_fastPatchRowStride = base_mapper.m_fastPatchRowStride;
258     m_fastPatchColStride = base_mapper.m_fastPatchColStride;
259     m_fastInputPlaneStride = base_mapper.m_fastInputPlaneStride;
260     m_fastInputRowStride = base_mapper.m_fastInputRowStride;
261     m_fastInputColStride = base_mapper.m_fastInputColStride;
262     m_fastRowStride = base_mapper.m_fastRowStride;
263     m_fastColStride = base_mapper.m_fastColStride;
264     m_fastOutputPlanes = base_mapper.m_fastOutputPlanes;
265     m_fastOutputRows = base_mapper.m_fastOutputRows;
266     m_fastOutputCols = base_mapper.m_fastOutputCols;
267     m_fastDimZero = base_mapper.m_fastDimZero;
268     m_fastOutputPlanesRows = base_mapper.m_fastOutputPlanesRows;
269   }
270 
271   // If true, turns off some optimizations for loading packets, since the image
272   // patches are "non-standard": e.g. there are non-trivial strides or
273   // inflations in the input.
274   EIGEN_DEVICE_FUNC
275   EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
276     return m_in_plane_strides != 1 || m_in_row_strides != 1 ||
277            m_in_col_strides != 1 || m_patch_plane_inflate_strides != 1 ||
278            m_patch_row_inflate_strides != 1 || m_patch_col_inflate_strides != 1;
279   }
280 
281   EIGEN_DEVICE_FUNC
282   EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
283     return SubMapper(*this, i, j);
284   }
285 
286   EIGEN_DEVICE_FUNC
287   EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
288     return LinearMapper(*this, i, j);
289   }
290 
291   EIGEN_DEVICE_FUNC
292   EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const {
293     Index planeIndex, rowIndex, colIndex, otherIndex;
294     computeBaseIndices(0, planeIndex, rowIndex, colIndex, otherIndex);
295     return loadCoeff(row, planeIndex, rowIndex, colIndex, otherIndex);
296   }
297 
298   // Load the coefficient at the patchIndex location instead of the usual
299   // m_rowIndex, m_colIndex, m_otherIndex. This is currently only used by the
300   // gpu code.
301   EIGEN_DEVICE_FUNC
302   EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const {
303     Index planeIndex, rowIndex, colIndex, otherIndex;
304     computeBaseIndices(patchIndex, planeIndex, rowIndex, colIndex, otherIndex);
305     return loadCoeff(row, planeIndex, rowIndex, colIndex, otherIndex);
306   }
307 
308   EIGEN_DEVICE_FUNC
309   EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const {
310     Index planeIndex, rowIndex, colIndex, otherIndex;
311     computeBaseIndices(0, planeIndex, rowIndex, colIndex, otherIndex);
312     return loadPacket(row, planeIndex, rowIndex, colIndex, otherIndex);
313   }
314 
315   // Load the packet at the patchIndex location instead of the usual m_rowIndex,
316   // m_colIndex, m_otherIndex. This is currently only used by the gpu code.
317   EIGEN_DEVICE_FUNC
318   EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const {
319     Index planeIndex, rowIndex, colIndex, otherIndex;
320     computeBaseIndices(patchIndex, planeIndex, rowIndex, colIndex, otherIndex);
321     return loadPacket(row, planeIndex, rowIndex, colIndex, otherIndex);
322   }
323 
324   EIGEN_DEVICE_FUNC
325   EIGEN_ALWAYS_INLINE const TensorEvaluator<ArgType, Device>& impl() const {
326     return m_impl;
327   }
328 
329   EIGEN_DEVICE_FUNC
330   EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_planeInputStride; }
331   EIGEN_DEVICE_FUNC
332   EIGEN_ALWAYS_INLINE Index patchPlanes() const { return m_rowStride; }
333   EIGEN_DEVICE_FUNC
334   EIGEN_ALWAYS_INLINE Index patchRows() const { return m_patch_rows; }
335   EIGEN_DEVICE_FUNC
336   EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; }
337 
338  private:
339   friend class TensorContractionSubMapper<
340       Scalar, Index, Side,
341       TensorEvaluator<const TensorReshapingOp<
342                           NewDimension, const TensorVolumePatchOp<
343                                             Planes, Rows, Cols, ArgType> >,
344                       Device>,
345       nocontract_t, contract_t, packet_size, inner_dim_contiguous,
346       inner_dim_reordered, Alignment>;
347 
348   // Load coefficient from a patch specified by the "within patch offset"
349   // (patchId) and the precomputed indices of the first element of the patch.
350   EIGEN_DEVICE_FUNC
351   EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index planeIndex,
352                                        Index rowIndex, Index colIndex,
353                                        Index otherIndex) const {
354     // Find the offset of the element wrt the location of the first element.
355     const Index patchOffset = patchId / m_fastDimZero;
356 
357     const Index colOffset = patchOffset / m_fastColStride;
358     const Index inputCol = colIndex + colOffset * m_in_col_strides;
359     const Index origInputCol =
360         (m_patch_col_inflate_strides == 1)
361             ? inputCol
362             : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
363 
364     const Index rowOffset =
365         (patchOffset - colOffset * m_colStride) / m_fastRowStride;
366     const Index inputRow = rowIndex + rowOffset * m_in_row_strides;
367     const Index origInputRow =
368         (m_patch_row_inflate_strides == 1)
369             ? inputRow
370             : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
371 
372     const Index planeOffset =
373         patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
374     const Index inputPlane = planeIndex + planeOffset * m_in_plane_strides;
375     const Index origInputPlane =
376         (m_patch_plane_inflate_strides == 1)
377             ? inputPlane
378             : ((inputPlane >= 0) ? (inputPlane / m_fastInputPlaneStride) : 0);
379 
380     if (origInputCol < 0 || origInputRow < 0 || origInputPlane < 0 ||
381         origInputCol >= m_inputCols || origInputRow >= m_inputRows ||
382         origInputPlane >= m_inputPlanes ||
383         (inputCol != origInputCol * m_patch_col_inflate_strides) ||
384         (inputRow != origInputRow * m_patch_row_inflate_strides) ||
385         (inputPlane != origInputPlane * m_patch_plane_inflate_strides)) {
386       return Scalar(0);
387     }
388 
389     const Index depth = patchId - patchOffset * patchDepth();
390     const Index inputIndex = depth + origInputPlane * m_planeInputStride +
391                              origInputRow * m_rowInputStride +
392                              origInputCol * m_colInputStride + otherIndex;
393 
394     return m_impl.coeff(inputIndex);
395   }
396 
397   // This is the same as loadCoeff(...), but optimized for all `inflate_strides`
398   // and `in_strides` equal to 1 (template specialization without templates).
399   EIGEN_DEVICE_FUNC
400   EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index planeIndex,
401                                                Index rowIndex, Index colIndex,
402                                                Index otherIndex) const {
403     eigen_assert(!nonStandardPatches());
404 
405     // Find the offset of the element wrt the location of the first element.
406     const Index patchOffset = patchId / m_fastDimZero;
407 
408     const Index colOffset = patchOffset / m_fastColStride;
409     const Index rowOffset =
410         (patchOffset - colOffset * m_colStride) / m_fastRowStride;
411     const Index planeOffset =
412         patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
413 
414     const Index inputCol = colIndex + colOffset;
415     const Index inputRow = rowIndex + rowOffset;
416     const Index inputPlane = planeIndex + planeOffset;
417 
418     if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 ||
419         inputRow >= m_inputRows || inputPlane < 0 ||
420         inputPlane >= m_inputPlanes) {
421       return Scalar(0);
422     }
423 
424     const Index depth = patchId - patchOffset * patchDepth();
425     const Index inputIndex = depth + inputPlane * m_planeInputStride +
426                              inputRow * m_rowInputStride +
427                              inputCol * m_colInputStride + otherIndex;
428 
429     return m_impl.coeff(inputIndex);
430   }
431 
432   // Load packet from a patch specified by the "within patch offset"
433   // (patchId) and the precomputed indices of the first element of the patch.
434   EIGEN_DEVICE_FUNC
435   EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index planeIndex,
436                                         Index rowIndex, Index colIndex,
437                                         Index otherIndex) const {
438     const Index packetSize = internal::unpacket_traits<Packet>::size;
439 
440     EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
441     eigen_assert(patchId <
442                  patchDepth() * patchPlanes() * patchRows() * patchCols());
443 
444     if (nonStandardPatches()) {
445       return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex,
446                                     otherIndex);
447     }
448     return loadPacketStandard(patchId, planeIndex, rowIndex, colIndex,
449                               otherIndex);
450   }
451 
452   EIGEN_DEVICE_FUNC
453   EIGEN_ALWAYS_INLINE Packet loadPacketStandard(Index patchId, Index planeIndex,
454                                                 Index rowIndex, Index colIndex,
455                                                 Index otherIndex) const {
456     const Index packetSize = internal::unpacket_traits<Packet>::size;
457     EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
458     eigen_assert(patchId <
459                  patchDepth() * patchPlanes() * patchRows() * patchCols());
460     eigen_assert(!nonStandardPatches());
461 
462     if ((patchDepth() % packetSize) == 0) {
463       return loadPacketFast(patchId, planeIndex, rowIndex, colIndex,
464                             otherIndex);
465     } else {
466       // Offsets and input calculation here are identical to
467       // loadCoeffStandard(...), but repeated twice.
468 
469       const Index patchOffsets[2] = {
470           patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero};
471 
472       const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
473                                    patchOffsets[1] / m_fastColStride};
474       eigen_assert(colOffsets[0] <= colOffsets[1]);
475 
476       const Index inputCols[2] = {colIndex + colOffsets[0],
477                                   colIndex + colOffsets[1]};
478       if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
479         return internal::pset1<Packet>(Scalar(0));
480       }
481 
482       if (inputCols[0] == inputCols[1]) {
483         const Index rowOffsets[2] = {
484             (patchOffsets[0] - colOffsets[0] * m_colStride) / m_fastRowStride,
485             (patchOffsets[1] - colOffsets[1] * m_colStride) / m_fastRowStride};
486         eigen_assert(rowOffsets[0] <= rowOffsets[1]);
487         const Index inputRows[2] = {rowIndex + rowOffsets[0],
488                                     rowIndex + rowOffsets[1]};
489 
490         if (inputRows[0] >= m_inputRows || inputRows[1] < 0) {
491           return internal::pset1<Packet>(Scalar(0));
492         }
493 
494         if (inputRows[0] == inputRows[1]) {
495           const Index planeOffsets[2] = {
496               patchOffsets[0] - colOffsets[0] * m_colStride -
497                   rowOffsets[0] * m_rowStride,
498               patchOffsets[1] - colOffsets[1] * m_colStride -
499                   rowOffsets[1] * m_rowStride};
500           eigen_assert(planeOffsets[0] <= planeOffsets[1]);
501           const Index inputPlanes[2] = {planeIndex + planeOffsets[0],
502                                         planeIndex + planeOffsets[1]};
503 
504           if (inputPlanes[0] >= m_inputPlanes || inputPlanes[1] < 0) {
505             return internal::pset1<Packet>(Scalar(0));
506           }
507 
508           if (inputPlanes[0] >= 0 && inputPlanes[1] < m_inputPlanes) {
509             const Index depth = patchId - patchOffsets[0] * patchDepth();
510             const Index inputIndex =
511                 depth + inputPlanes[0] * m_planeInputStride +
512                 inputRows[0] * m_rowInputStride +
513                 inputCols[0] * m_colInputStride + otherIndex;
514             return m_impl.template packet<Unaligned>(inputIndex);
515           }
516         }
517       }
518     }
519 
520     return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex,
521                                   otherIndex);
522   }
523 
524   EIGEN_DEVICE_FUNC
525   EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index planeIndex,
526                                             Index rowIndex, Index colIndex,
527                                             Index otherIndex) const {
528     const Index packetSize = internal::unpacket_traits<Packet>::size;
529     EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
530     eigen_assert(patchId <
531                  patchDepth() * patchPlanes() * patchRows() * patchCols());
532 
533     eigen_assert(!nonStandardPatches());
534     eigen_assert((patchDepth() % packetSize) == 0);
535 
536     // Find the offset of the element wrt the location of the first element.
537     const Index patchOffset = patchId / m_fastDimZero;
538     eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);
539 
540     const Index colOffset = patchOffset / m_fastColStride;
541     const Index rowOffset =
542         (patchOffset - colOffset * m_colStride) / m_fastRowStride;
543     const Index planeOffset =
544         patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
545 
546     const Index inputCol = colIndex + colOffset;
547     const Index inputRow = rowIndex + rowOffset;
548     const Index inputPlane = planeIndex + planeOffset;
549 
550     if (inputCol < 0 || inputRow < 0 || inputPlane < 0 ||
551         inputCol >= m_inputCols || inputRow >= m_inputRows ||
552         inputPlane >= m_inputPlanes) {
553       return internal::pset1<Packet>(Scalar(0));
554     }
555 
556     const Index depth = patchId - patchOffset * patchDepth();
557     const Index inputIndex = depth + inputPlane * m_planeInputStride +
558                              inputRow * m_rowInputStride +
559                              inputCol * m_colInputStride + otherIndex;
560     return m_impl.template packet<Unaligned>(inputIndex);
561   }
562 
563   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet
564   packetWithPossibleZero(Index patchId, Index planeIndex, Index rowIndex,
565                          Index colIndex, Index otherIndex) const {
566     const int packetSize = internal::unpacket_traits<Packet>::size;
567     EIGEN_ALIGN_MAX
568     typename internal::remove_const<Scalar>::type values[packetSize];
569     for (int i = 0; i < packetSize; ++i) {
570       values[i] =
571           loadCoeff(patchId + i, planeIndex, rowIndex, colIndex, otherIndex);
572     }
573     Packet rslt = internal::pload<Packet>(values);
574     return rslt;
575   }
576 
577   // Precompute the indices (plane, row, col, other) of the first element of
578   // the given patch index, within the output tensor of the TensorVolumePatchOp.
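  // For illustration only (arbitrary sizes): with m_outputPlanes = 4 and
  // m_outputRows = 5 (so m_outputPlanesRows = 20), patch3DIndex = 53 yields
  // colIndex = 53 / 20 = 2, rowIndex = (53 - 2 * 20) / 4 = 3 and
  // planeIndex = 53 - (2 * 5 + 3) * 4 = 1, before the user strides and
  // padding offsets are applied below.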
579   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices(
580       Index patchIndex, Index& planeIndex, Index& rowIndex, Index& colIndex,
581       Index& otherIndex) const {
582     const size_t NumInputDims = array_size<
583         typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
584 
585     // Check if patchIndex might contain batch and other dimensions.
586     otherIndex = (NumInputDims == 4) ? 0 : patchIndex / m_fastNumPatches;
587 
588     // Compute index of the patch within the batch (and other dimensions).
589     const Index patch3DIndex = (NumInputDims == 4)
590                                    ? patchIndex
591                                    : (patchIndex - otherIndex * m_num_patches);
592 
593     otherIndex *= m_patchInputStride;
594 
595     colIndex = patch3DIndex / m_fastOutputPlanesRows;
596     rowIndex =
597         (patch3DIndex - colIndex * m_outputPlanesRows) / m_fastOutputPlanes;
598     planeIndex =
599         patch3DIndex - (colIndex * m_outputRows + rowIndex) * m_outputPlanes;
600 
601     colIndex = colIndex * m_col_strides - m_colPaddingLeft;
602     rowIndex = rowIndex * m_row_strides - m_rowPaddingTop;
603     planeIndex = planeIndex * m_plane_strides - m_planePaddingTop;
604   }
605 
606   Index m_patch_depth;   // number of channels in the patch
607   Index m_patch_planes;  // number of planes in the patch
608   Index m_patch_rows;    // number of rows in the patch
609   Index m_patch_cols;    // number of columns in the patch
610   Index m_num_patches;   // number of patches to extract
611 
612   // Strides for navigating through the single patch.
613   Index m_patch_plane_stride;
614   Index m_patch_row_stride;
615   Index m_patch_col_stride;
616 
617   // Strides for the output tensor (depth is not the part of the stride).
618   Index m_rowStride;
619   Index m_colStride;
620   Index m_patchStride;
621   Index m_otherStride;
622 
623   Index m_planeInputStride;  // Plane stride in the input tensor
624   Index m_rowInputStride;    // Row stride in the input tensor
625   Index m_colInputStride;    // Col stride in the input tensor
626   Index m_patchInputStride;  // Patch stride in the input tensor
627   Index m_otherInputStride;
628 
629   Index m_inputDepth;   // Depth of the input tensor
630   Index m_inputPlanes;  // Number of planes in the input tensor
631   Index m_inputRows;    // Number of rows in the input tensor
632   Index m_inputCols;    // Number of cols in the input tensor
633 
634   Index m_outputPlanes;      // Number of output planes
635   Index m_outputRows;        // Number of output rows
636   Index m_outputCols;        // Number of output cols
637   Index m_outputPlanesRows;  // Cached outputPlanes * outputRows.
638 
639   Index m_plane_strides;  // User specified plane stride
640   Index m_row_strides;    // User specified row stride
641   Index m_col_strides;    // User specified col stride
642 
643   // User specified plane/row/col atrous convolution strides.
644   Index m_in_plane_strides;
645   Index m_in_row_strides;
646   Index m_in_col_strides;
647 
648   // User specified plane/row/col inflation strides in the image patch.
649   Index m_patch_plane_inflate_strides;
650   Index m_patch_row_inflate_strides;
651   Index m_patch_col_inflate_strides;
652 
653   Index m_planePaddingTop;  // Plane padding
654   Index m_rowPaddingTop;    // Row padding
655   Index m_colPaddingLeft;   // Column padding
656 
657   // Fast representation of various divisors.
658   internal::TensorIntDivisor<Index> m_fastNumPatches;
659 
660   internal::TensorIntDivisor<Index> m_fastPatchPlaneStride;
661   internal::TensorIntDivisor<Index> m_fastPatchRowStride;
662   internal::TensorIntDivisor<Index> m_fastPatchColStride;
663 
664   internal::TensorIntDivisor<Index> m_fastInputPlaneStride;
665   internal::TensorIntDivisor<Index> m_fastInputRowStride;
666   internal::TensorIntDivisor<Index> m_fastInputColStride;
667 
668   internal::TensorIntDivisor<Index> m_fastRowStride;
669   internal::TensorIntDivisor<Index> m_fastColStride;
670 
671   internal::TensorIntDivisor<Index> m_fastDimZero;  // aka output depth
672   internal::TensorIntDivisor<Index> m_fastOutputPlanes;
673   internal::TensorIntDivisor<Index> m_fastOutputRows;
674   internal::TensorIntDivisor<Index> m_fastOutputCols;
675   internal::TensorIntDivisor<Index> m_fastOutputPlanesRows;
676 
677   const TensorEvaluator<ArgType, Device> m_impl;
678 };
679 
680 template <typename NewDimension, Index Planes, Index Rows, Index Cols,
681           typename ArgType, typename Device, typename Scalar, typename Index,
682           typename nocontract_t, typename contract_t, int Side, int packet_size,
683           bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
684 class TensorContractionSubMapper<
685     Scalar, Index, Side,
686     TensorEvaluator<const TensorReshapingOp<NewDimension,
687                                             const TensorVolumePatchOp<
688                                                 Planes, Rows, Cols, ArgType> >,
689                     Device>,
690     nocontract_t, contract_t, packet_size, inner_dim_contiguous,
691     inner_dim_reordered, Alignment> {
692  public:
693   typedef typename packet_traits<Scalar>::type Packet;
694   typedef typename packet_traits<Scalar>::half HalfPacket;
695 
696   typedef TensorContractionInputMapper<
697       Scalar, Index, Side,
698       TensorEvaluator<const TensorReshapingOp<
699                           NewDimension, const TensorVolumePatchOp<
700                                             Planes, Rows, Cols, ArgType> >,
701                       Device>,
702       nocontract_t, contract_t, packet_size, inner_dim_contiguous,
703       inner_dim_reordered, Alignment>
704       ParentMapper;
705   typedef TensorContractionSubMapper<
706       Scalar, Index, Side,
707       TensorEvaluator<const TensorReshapingOp<
708                           NewDimension, const TensorVolumePatchOp<
709                                             Planes, Rows, Cols, ArgType> >,
710                       Device>,
711       nocontract_t, contract_t, packet_size, inner_dim_contiguous,
712       inner_dim_reordered, Alignment>
713       Self;
714   typedef Self LinearMapper;
715 
716   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
717       const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
718       : m_base_mapper(base_mapper),
719         m_depth_offset(vert_offset),
720         m_col_offset(horiz_offset) {
721     m_base_mapper.computeBaseIndices(m_col_offset, m_planeIndex, m_rowIndex,
722                                      m_colIndex, m_otherIndex);
723   }
724   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
725       const Self& base_mapper, Index vert_offset, Index horiz_offset)
726       : m_base_mapper(base_mapper.m_base_mapper),
727         m_depth_offset(vert_offset + base_mapper.m_depth_offset),
728         m_col_offset(horiz_offset + base_mapper.m_col_offset) {
729     m_base_mapper.computeBaseIndices(m_col_offset, m_planeIndex, m_rowIndex,
730                                      m_colIndex, m_otherIndex);
731   }
732   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
733     return m_base_mapper.loadCoeff(i + m_depth_offset, m_planeIndex, m_rowIndex,
734                                    m_colIndex, m_otherIndex);
735   }
736   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i,
737                                                           Index j) const {
738     return m_base_mapper(i + m_depth_offset, j + m_col_offset);
739   }
740 
741   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
742     return m_base_mapper.loadPacket(i + m_depth_offset, m_planeIndex,
743                                     m_rowIndex, m_colIndex, m_otherIndex);
744   }
745   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i,
746                                                           Index j) const {
747     return m_base_mapper.template loadPacket<Alignment>(i + m_depth_offset,
748                                                         j + m_col_offset);
749   }
750 
751   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar
752   loadCoeffStandard(Index i) const {
753     return m_base_mapper.loadCoeffStandard(
754         i + m_depth_offset, m_planeIndex, m_rowIndex, m_colIndex, m_otherIndex);
755   }
756 
757   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const {
758     return m_base_mapper.loadPacketFast(i + m_depth_offset, m_planeIndex,
759                                         m_rowIndex, m_colIndex, m_otherIndex);
760   }
761   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet
762   loadPacketStandard(Index i) const {
763     return m_base_mapper.loadPacketStandard(
764         i + m_depth_offset, m_planeIndex, m_rowIndex, m_colIndex, m_otherIndex);
765   }
766   template <typename Packet>
767   EIGEN_DEVICE_FUNC bool aligned(Index) const {
768     return false;
769   }
770 
771   EIGEN_DEVICE_FUNC
772   EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
773     return m_base_mapper.nonStandardPatches();
774   }
775 
776   // Max(Col|Row|Plane|Depth): compute the upper limit for the column, row,
777   // plane and depth index respectively that fits into the peeled_k elements
778   // starting at m_depth_offset.
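  // For illustration only (arbitrary sizes): with patchDepth() = 8,
  // patchPlanes() = 3 and patchRows() = 4 we get patchColStride() = 96, so
  // for m_depth_offset = 0 and peeled_k = 200, maxCol(200) returns
  // min(1 + 200 / 96, patchCols()) = min(3, patchCols()).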
779 
780   EIGEN_DEVICE_FUNC
781   EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const {
782     const Index max_col =
783         fastPatchColStride().divide(m_depth_offset + peeled_k);
784     return std::min<Index>(1 + max_col, patchCols());
785   }
786 
787   EIGEN_DEVICE_FUNC
788   EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k,
789                                    const Index col) const {
790     const Index max_row = fastPatchRowStride().divide(
791         m_depth_offset + peeled_k - col * patchColStride());
792     return std::min<Index>(1 + max_row, patchRows());
793   }
794 
795   EIGEN_DEVICE_FUNC
796   EIGEN_ALWAYS_INLINE Index maxPlane(const Index peeled_k, const Index col,
797                                      const Index row) const {
798     const Index max_plane = fastPatchPlaneStride().divide(
799         m_depth_offset + peeled_k - col * patchColStride() -
800         row * patchRowStride());
801     return std::min<Index>(1 + max_plane, patchPlanes());
802   }
803 
804   // MaxDepth uses only the remaining number of elements in the peeled_k.
805   EIGEN_DEVICE_FUNC
806   EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements,
807                                      const Index start_depth) const {
808     return std::min<Index>(start_depth + num_elements, patchDepth());
809   }
810 
811   // Every register matters in this code, so sometimes, to prevent register
812   // spilling, instead of the variable that you would expect to see we use
813   // another one that is guaranteed to have the same value. E.g. patch depth is
814   // always the same as input depth, and it's also the same as the input plane
815   // stride. A number of other parameters have similar relations.
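  // For example, patchDepth() below returns m_base_mapper.m_planeInputStride
  // rather than m_base_mapper.m_patch_depth; the eigen_assert in its body
  // documents that the two values are always equal.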
816 
817   typedef internal::TensorIntDivisor<Index> IndexDivisor;
818 
819   EIGEN_DEVICE_FUNC
820   EIGEN_ALWAYS_INLINE Index patchDepth() const {
821     eigen_assert(m_base_mapper.m_patch_depth ==
822                      m_base_mapper.m_planeInputStride &&
823                  "Patch depth must be equal to plane input stride.");
824     return m_base_mapper.m_planeInputStride;
825   }
826 
827   EIGEN_DEVICE_FUNC
828   EIGEN_ALWAYS_INLINE Index patchPlanes() const {
829     eigen_assert(m_base_mapper.m_patch_planes == m_base_mapper.m_rowStride &&
830                  "Patch planes must be equal to row stride.");
831     return m_base_mapper.m_rowStride;
832   }
833   EIGEN_DEVICE_FUNC
834   EIGEN_ALWAYS_INLINE Index patchRows() const {
835     return m_base_mapper.m_patch_rows;
836   }
837   EIGEN_DEVICE_FUNC
838   EIGEN_ALWAYS_INLINE Index patchCols() const {
839     return m_base_mapper.m_patch_cols;
840   }
841 
842   EIGEN_DEVICE_FUNC
843   EIGEN_ALWAYS_INLINE Index patchPlaneStride() const {
844     eigen_assert(patchDepth() == m_base_mapper.m_patch_plane_stride &&
845                  "Patch depth must be equal to patch plane stride.");
846     return patchDepth();
847   }
848   EIGEN_DEVICE_FUNC
849   EIGEN_ALWAYS_INLINE Index patchRowStride() const {
850     return m_base_mapper.m_patch_row_stride;
851   }
852   EIGEN_DEVICE_FUNC
853   EIGEN_ALWAYS_INLINE Index patchColStride() const {
854     return m_base_mapper.m_patch_col_stride;
855   }
856 
857   EIGEN_DEVICE_FUNC
858   EIGEN_ALWAYS_INLINE IndexDivisor fastPatchPlaneStride() const {
859     eigen_assert(patchDepth() == m_base_mapper.m_patch_plane_stride &&
860                  "Patch depth must be equal to patch plane stride.");
861     return m_base_mapper.m_fastDimZero;  // patch_depth
862   }
863   EIGEN_DEVICE_FUNC
864   EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const {
865     return m_base_mapper.m_fastPatchRowStride;
866   }
867   EIGEN_DEVICE_FUNC
868   EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const {
869     return m_base_mapper.m_fastPatchColStride;
870   }
871 
872   EIGEN_DEVICE_FUNC
873   EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
874                                              const Index baseIndex) const {
875     const Index inputIndex = depth + baseIndex;
876     return m_base_mapper.m_impl.template packet<Unaligned>(inputIndex);
877   }
878   EIGEN_DEVICE_FUNC
879   EIGEN_ALWAYS_INLINE Scalar coeffNoPadding(const Index depth,
880                                             const Index baseIndex) const {
881     const Index inputIndex = depth + baseIndex;
882     return m_base_mapper.m_impl.coeff(inputIndex);
883   }
884 
885   EIGEN_DEVICE_FUNC
886   EIGEN_ALWAYS_INLINE bool padPlane(const Index plane) const {
887     const Index p = m_planeIndex + plane;
888     return p < 0 || p >= m_base_mapper.m_inputPlanes;
889   }
890   EIGEN_DEVICE_FUNC
891   EIGEN_ALWAYS_INLINE bool padRow(const Index row) const {
892     const Index r = m_rowIndex + row;
893     return r < 0 || r >= m_base_mapper.m_inputRows;
894   }
895   EIGEN_DEVICE_FUNC
896   EIGEN_ALWAYS_INLINE bool padCol(const Index col) const {
897     const Index c = m_colIndex + col;
898     return c < 0 || c >= m_base_mapper.m_inputCols;
899   }
900   EIGEN_DEVICE_FUNC
901   EIGEN_ALWAYS_INLINE Index baseIndex(const Index plane, const Index row,
902                                       const Index col) const {
903     const Index p = m_planeIndex + plane;
904     const Index r = m_rowIndex + row;
905     const Index c = m_colIndex + col;
906     return p * m_base_mapper.m_planeInputStride +
907            r * m_base_mapper.m_rowInputStride +
908            c * m_base_mapper.m_colInputStride + m_otherIndex;
909   }
910 
911   EIGEN_DEVICE_FUNC
912   EIGEN_ALWAYS_INLINE Index planeOffset() const {
913     const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
914     const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
915     const Index rowOffset =
916         (patchOffset - colOffset * m_base_mapper.m_colStride) /
917         m_base_mapper.m_fastRowStride;
918     const Index planeOffset = patchOffset -
919                               colOffset * m_base_mapper.m_colStride -
920                               rowOffset * m_base_mapper.m_rowStride;
921     return planeOffset;
922   }
923 
924   EIGEN_DEVICE_FUNC
925   EIGEN_ALWAYS_INLINE Index rowOffset() const {
926     const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
927     const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
928     const Index rowOffset =
929         (patchOffset - colOffset * m_base_mapper.m_colStride) /
930         m_base_mapper.m_fastRowStride;
931     return rowOffset;
932   }
933 
934   EIGEN_DEVICE_FUNC
935   EIGEN_ALWAYS_INLINE Index colOffset() const {
936     const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
937     const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
938     return colOffset;
939   }
940 
941   EIGEN_DEVICE_FUNC
942   EIGEN_ALWAYS_INLINE Index depthOffset() const {
943     return m_depth_offset % patchDepth();
944   }
945 
946   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper
947   getLinearMapper(Index i, Index j) const {
948     return LinearMapper(m_base_mapper, i + m_depth_offset, j + m_col_offset);
949   }
950 
951  private:
952   const ParentMapper m_base_mapper;  // Keeping a copy instead of a reference
953                                      // performs better in benchmarks.
954 
955   Index m_depth_offset;  // First row in the input matrix
956   Index m_col_offset;    // First col in the input matrix
957 
958   // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base
959   // indices for the first element in a patch specified by col_offset
960   // (see computeBaseIndices(...) for details).
961   Index m_planeIndex;
962   Index m_rowIndex;
963   Index m_colIndex;
964   Index m_otherIndex;
965 };
966 
967 // Arrange a block of the right input matrix (in our case it's always a "virtual
968 // matrix" constructed from extracted volume patches) in contiguous memory.
969 //
970 // Given column major input (A0 beside A1 in memory):
971 // A0 B0 C0 D0  E0 F0 G0 H0 ... Z0
972 // A1 B1 C1 D1  E1 F1 G1 H1 ... Z1
973 // A2 B2 C2 D2  E2 F2 G2 H2 ... Z2
974 // A3 B3 C3 D3  E3 F3 G3 H3 ... Z3
975 // A4 B4 C4 D4  E4 F4 G4 H4 ... Z4
976 // A5 B5 C5 D5  E5 F5 G5 H5 ... Z5
977 // A6 B6 C6 D6  E6 F6 G6 H6 ... Z6
978 // A7 B7 C7 D7  E7 F7 G7 H7 ... Z7
979 // A8 ...
980 // ...
981 //
982 // *) A, B, C, ... - patches extracted from the original input.
983 // *) A0, A1, A2 ... - values from the same patch at different offsets.
984 //
985 // The traversal (packed rhs memory) order (B0 beside A0 in memory):
986 // A0 B0 C0 D0 A1 B1 C1 D1 ...
987 // E0 F0 G0 H0 E1 F1 G1 H1 ...
988 // ...
989 // Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4)
990 //
991 // This traversal order must be the same as in default gemm_pack_rhs defined in
992 // GeneralBlockPanelKernel.h.
993 //
994 // *) nr - number of registers along the 'n' dimension.
995 //    See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix
996 //    Multiplication" paper.
997 //
998 // TODO(ezhulenev): Add support for squeezing reads along two innermost
999 // dimensions (see eigen_spatial_convolutions).
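//
// A sketch of the packed layout produced below, assuming nr = 4 and
// packet_size = 4: each step of the inner depth loop loads one packet from
// each of the four columns A..D, transposes the 4x4 tile, and appends
//   block[0..3]   = {A_d,   B_d,   C_d,   D_d  }
//   block[4..7]   = {A_d+1, B_d+1, C_d+1, D_d+1}
//   block[8..11]  = {A_d+2, B_d+2, C_d+2, D_d+2}
//   block[12..15] = {A_d+3, B_d+3, C_d+3, D_d+3}
// which reproduces the "A0 B0 C0 D0 A1 B1 C1 D1 ..." order shown above.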
1000 template <typename NewDimension, Index Planes, Index Rows, Index Cols,
1001           typename ArgType, typename Device, typename Scalar, typename Index,
1002           typename nocontract_t, typename contract_t, int packet_size,
1003           bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
1004           int nr>
1005 struct gemm_pack_rhs<
1006     Scalar, Index,
1007     TensorContractionSubMapper<
1008         Scalar, Index, Rhs,
1009         TensorEvaluator<const TensorReshapingOp<
1010                             NewDimension, const TensorVolumePatchOp<
1011                                               Planes, Rows, Cols, ArgType> >,
1012                         Device>,
1013         nocontract_t, contract_t, packet_size, inner_dim_contiguous,
1014         inner_dim_reordered, Alignment>,
1015     nr, ColMajor, false, false> {
1016   typedef TensorContractionSubMapper<
1017       Scalar, Index, Rhs,
1018       TensorEvaluator<const TensorReshapingOp<
1019                           NewDimension, const TensorVolumePatchOp<
1020                                             Planes, Rows, Cols, ArgType> >,
1021                       Device>,
1022       nocontract_t, contract_t, packet_size, inner_dim_contiguous,
1023       inner_dim_reordered, Alignment>
1024       SubMapper;
1025 
1026   typedef SubMapper DataMapper;
1027   typedef typename packet_traits<Scalar>::type Packet;
1028 
1029   EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
1030 
1031   EIGEN_DEVICE_FUNC
1032   EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
1033                                     Index depth, Index cols, Index stride = 0,
1034                                     Index offset = 0) const {
1035     eigen_assert(stride == 0);
1036     eigen_assert(offset == 0);
1037 
1038     const Index packet_cols4 = (cols / 4) * 4;
1039     const Index peeled_k = (depth / packet_size) * packet_size;
1040     const bool non_standard_patches = rhs.nonStandardPatches();
1041 
1042     for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
1043       const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
1044       const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
1045       const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
1046       const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
1047 
1048       Index k = 0;
1049       if ((packet_size % 4) == 0 && !non_standard_patches) {
1050         // FAST PATH:
1051         // Iterate over patch columns, rows and planes if we know that a single
1052         // packet does not span across multiple planes, rows or columns.
1053         if ((rhs.patchDepth() % packet_size) == 0) {
1054           const Index start_col = rhs.colOffset();
1055           const Index max_col = rhs.maxCol(peeled_k);
1056 
1057           for (Index c = start_col; c < max_col; ++c) {
1058             eigen_assert(k <= peeled_k);
1059 
1060             const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
1061             const Index max_row = rhs.maxRow(peeled_k, c);
1062 
1063             const bool pad_col0 = dm0.padCol(c);
1064             const bool pad_col1 = dm1.padCol(c);
1065             const bool pad_col2 = dm2.padCol(c);
1066             const bool pad_col3 = dm3.padCol(c);
1067 
1068             for (Index r = start_row; r < max_row; ++r) {
1069               eigen_assert(k <= peeled_k);
1070 
1071               const Index start_plane = ((c == start_col) && (r == start_row))
1072                                             ? rhs.planeOffset()
1073                                             : 0;
1074               const Index max_plane = rhs.maxPlane(peeled_k, c, r);
1075 
1076               const bool pad_row0 = pad_col0 || dm0.padRow(r);
1077               const bool pad_row1 = pad_col1 || dm1.padRow(r);
1078               const bool pad_row2 = pad_col2 || dm2.padRow(r);
1079               const bool pad_row3 = pad_col3 || dm3.padRow(r);
1080 
1081               for (Index p = start_plane; p < max_plane; ++p) {
1082                 eigen_assert(k <= peeled_k);
1083 
1084                 const bool pad0 = pad_row0 || dm0.padPlane(p);
1085                 const bool pad1 = pad_row1 || dm1.padPlane(p);
1086                 const bool pad2 = pad_row2 || dm2.padPlane(p);
1087                 const bool pad3 = pad_row3 || dm3.padPlane(p);
1088 
1089                 const Index idx0 = dm0.baseIndex(p, r, c);
1090                 const Index idx1 = dm1.baseIndex(p, r, c);
1091                 const Index idx2 = dm2.baseIndex(p, r, c);
1092                 const Index idx3 = dm3.baseIndex(p, r, c);
1093 
1094                 const Index start_depth =
1095                     ((c == start_col) && (r == start_row) && (p == start_plane))
1096                         ? rhs.depthOffset()
1097                         : 0;
1098                 const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
1099                 eigen_assert((max_depth - start_depth) % packet_size == 0);
1100 
1101                 for (Index d = start_depth; d < max_depth; d += packet_size) {
1102                   eigen_assert(k < peeled_k);
1103                   PacketBlock<Packet, 4> kernel;
1104                   kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
1105                                           : rhs.packetNoPadding(d, idx0);
1106                   kernel.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
1107                                           : rhs.packetNoPadding(d, idx1);
1108                   kernel.packet[2] = pad2 ? pset1<Packet>(Scalar(0))
1109                                           : rhs.packetNoPadding(d, idx2);
1110                   kernel.packet[3] = pad3 ? pset1<Packet>(Scalar(0))
1111                                           : rhs.packetNoPadding(d, idx3);
1112                   ptranspose(kernel);
1113                   pstoreu(block + 0 * packet_size, kernel.packet[0]);
1114                   pstoreu(block + 1 * packet_size, kernel.packet[1]);
1115                   pstoreu(block + 2 * packet_size, kernel.packet[2]);
1116                   pstoreu(block + 3 * packet_size, kernel.packet[3]);
1117                   block += 4 * packet_size;
1118                   k += packet_size;
1119                 }
1120               }
1121             }
1122           }
1123 
1124           // The loop above should fill peeled_k elements.
1125           eigen_assert(peeled_k == k);
1126 
1127         } else {
1128           // Packet can span multiple planes, rows or columns, so we have to go
1129           // through the slower "standard" path.
1130           for (; k < peeled_k; k += packet_size) {
1131             PacketBlock<Packet, 4> kernel;
1132             kernel.packet[0] = dm0.loadPacketStandard(k);
1133             kernel.packet[1] = dm1.loadPacketStandard(k);
1134             kernel.packet[2] = dm2.loadPacketStandard(k);
1135             kernel.packet[3] = dm3.loadPacketStandard(k);
1136             ptranspose(kernel);
1137             pstoreu(block + 0 * packet_size, kernel.packet[0]);
1138             pstoreu(block + 1 * packet_size, kernel.packet[1]);
1139             pstoreu(block + 2 * packet_size, kernel.packet[2]);
1140             pstoreu(block + 3 * packet_size, kernel.packet[3]);
1141             block += 4 * packet_size;
1142           }
1143         }
1144       }
1145 
1146       // Copy the remaining coefficients of the column block after the peeled_k.
1147       if (!non_standard_patches) {
1148         for (; k < depth; k++) {
1149           block[0] = dm0.loadCoeffStandard(k);
1150           block[1] = dm1.loadCoeffStandard(k);
1151           block[2] = dm2.loadCoeffStandard(k);
1152           block[3] = dm3.loadCoeffStandard(k);
1153           block += 4;
1154         }
1155       } else {
1156         for (; k < depth; k++) {
1157           block[0] = dm0(k);
1158           block[1] = dm1(k);
1159           block[2] = dm2(k);
1160           block[3] = dm3(k);
1161           block += 4;
1162         }
1163       }
1164     }
1165 
1166     // Copy the remaining columns one at a time (nr==1).
1167     for (Index j2 = packet_cols4; j2 < cols; ++j2) {
1168       const SubMapper dm0 = rhs.getLinearMapper(0, j2);
1169       for (Index k = 0; k < depth; k++) {
1170         *block = dm0(k);
1171         block += 1;
1172       }
1173     }
1174   }
1175 };
1176 
// Template specialization for packet_size = 2. We must special-case packet
// blocks with nr > packet_size, e.g. PacketBlock<Packet2d, 4>.
//
// TODO(ezhulenev): Add support for squeezing reads along two innermost
// dimensions (see eigen_spatial_convolutions).
template <typename NewDimension, Index Planes, Index Rows, Index Cols,
          typename ArgType, typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, bool inner_dim_contiguous,
          bool inner_dim_reordered, int Alignment, int nr>
struct gemm_pack_rhs<
    Scalar, Index,
    TensorContractionSubMapper<
        Scalar, Index, Rhs,
        TensorEvaluator<const TensorReshapingOp<
                            NewDimension, const TensorVolumePatchOp<
                                              Planes, Rows, Cols, ArgType> >,
                        Device>,
        nocontract_t, contract_t, /*packet_size*/ 2, inner_dim_contiguous,
        inner_dim_reordered, Alignment>,
    nr, ColMajor, false, false> {
  typedef TensorContractionSubMapper<
      Scalar, Index, Rhs,
      TensorEvaluator<const TensorReshapingOp<
                          NewDimension, const TensorVolumePatchOp<
                                            Planes, Rows, Cols, ArgType> >,
                      Device>,
      nocontract_t, contract_t, /*packet_size*/ 2, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      SubMapper;
  typedef SubMapper DataMapper;
  typedef typename packet_traits<Scalar>::type Packet;

  EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);

  EIGEN_DEVICE_FUNC
  EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
                                    Index depth, Index cols, Index stride = 0,
                                    Index offset = 0) const {
    eigen_assert(stride == 0);
    eigen_assert(offset == 0);

    const int packet_size = 2;

    const Index packet_cols4 = (cols / 4) * 4;
    const Index peeled_k = (depth / packet_size) * packet_size;
    const bool non_standard_patches = rhs.nonStandardPatches();

    for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
      const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
      const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
      const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);

      Index k = 0;
      if (!non_standard_patches) {
        // FAST PATH:
        // Iterate over patch columns, rows and planes if we know that a single
        // packet does not span across multiple planes, rows or columns.
        if ((rhs.patchDepth() % packet_size) == 0) {
          const Index start_col = rhs.colOffset();
          const Index max_col = rhs.maxCol(peeled_k);

          for (Index c = start_col; c < max_col; ++c) {
            eigen_assert(k <= peeled_k);

            const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
            const Index max_row = rhs.maxRow(peeled_k, c);

            const bool pad_col0 = dm0.padCol(c);
            const bool pad_col1 = dm1.padCol(c);
            const bool pad_col2 = dm2.padCol(c);
            const bool pad_col3 = dm3.padCol(c);

            for (Index r = start_row; r < max_row; ++r) {
              eigen_assert(k <= peeled_k);

              const Index start_plane = ((c == start_col) && (r == start_row))
                                            ? rhs.planeOffset()
                                            : 0;
              const Index max_plane = rhs.maxPlane(peeled_k, c, r);

              const bool pad_row0 = dm0.padRow(r);
              const bool pad_row1 = dm1.padRow(r);
              const bool pad_row2 = dm2.padRow(r);
              const bool pad_row3 = dm3.padRow(r);

              for (Index p = start_plane; p < max_plane; ++p) {
                eigen_assert(k <= peeled_k);

                const bool pad0 = pad_col0 || pad_row0 || dm0.padPlane(p);
                const bool pad1 = pad_col1 || pad_row1 || dm1.padPlane(p);
                const bool pad2 = pad_col2 || pad_row2 || dm2.padPlane(p);
                const bool pad3 = pad_col3 || pad_row3 || dm3.padPlane(p);

                const Index idx0 = dm0.baseIndex(p, r, c);
                const Index idx1 = dm1.baseIndex(p, r, c);
                const Index idx2 = dm2.baseIndex(p, r, c);
                const Index idx3 = dm3.baseIndex(p, r, c);

                const Index start_depth =
                    ((c == start_col) && (r == start_row) && (p == start_plane))
                        ? rhs.depthOffset()
                        : 0;
                const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
                eigen_assert((max_depth - start_depth) % packet_size == 0);

                for (Index d = start_depth; d < max_depth; d += packet_size) {
                  eigen_assert(k < peeled_k);
                  PacketBlock<Packet, 2> kernel0;
                  PacketBlock<Packet, 2> kernel1;
                  kernel0.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
                                           : rhs.packetNoPadding(d, idx0);
                  kernel0.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
                                           : rhs.packetNoPadding(d, idx1);
                  kernel1.packet[0] = pad2 ? pset1<Packet>(Scalar(0))
                                           : rhs.packetNoPadding(d, idx2);
                  kernel1.packet[1] = pad3 ? pset1<Packet>(Scalar(0))
                                           : rhs.packetNoPadding(d, idx3);
                  ptranspose(kernel0);
                  ptranspose(kernel1);
                  pstoreu(block + 0 * packet_size, kernel0.packet[0]);
                  pstoreu(block + 1 * packet_size, kernel1.packet[0]);
                  pstoreu(block + 2 * packet_size, kernel0.packet[1]);
                  pstoreu(block + 3 * packet_size, kernel1.packet[1]);
                  block += 4 * packet_size;
                  k += packet_size;
                }
              }
            }
          }

          // The loop above should fill peeled_k elements.
          eigen_assert(peeled_k == k);

        } else {
          for (; k < peeled_k; k += packet_size) {
            PacketBlock<Packet, 2> kernel0;
            PacketBlock<Packet, 2> kernel1;
            kernel0.packet[0] = dm0.loadPacketStandard(k);
            kernel0.packet[1] = dm1.loadPacketStandard(k);
            kernel1.packet[0] = dm2.loadPacketStandard(k);
            kernel1.packet[1] = dm3.loadPacketStandard(k);
            ptranspose(kernel0);
            ptranspose(kernel1);
            pstoreu(block + 0 * packet_size, kernel0.packet[0]);
            pstoreu(block + 1 * packet_size, kernel1.packet[0]);
            pstoreu(block + 2 * packet_size, kernel0.packet[1]);
            pstoreu(block + 3 * packet_size, kernel1.packet[1]);
            block += 4 * packet_size;
          }
        }
      }

      // Copy the remaining coefficients of the column block after the peeled_k.
      if (!rhs.nonStandardPatches()) {
        for (; k < depth; k++) {
          block[0] = dm0.loadCoeffStandard(k);
          block[1] = dm1.loadCoeffStandard(k);
          block[2] = dm2.loadCoeffStandard(k);
          block[3] = dm3.loadCoeffStandard(k);
          block += 4;
        }
      } else {
        for (; k < depth; k++) {
          block[0] = dm0(k);
          block[1] = dm1(k);
          block[2] = dm2(k);
          block[3] = dm3(k);
          block += 4;
        }
      }
    }

    // Copy the remaining columns one at a time (nr==1).
    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2);
      for (Index k = 0; k < depth; k++) {
        *block = dm0(k);
        block += 1;
      }
    }
  }
};

// Special case for non-vectorized types such as float16 (packet_size = 1).
template <typename NewDimension, Index Planes, Index Rows, Index Cols,
          typename ArgType, typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, bool inner_dim_contiguous,
          bool inner_dim_reordered, int Alignment, int nr>
struct gemm_pack_rhs<
    Scalar, Index,
    TensorContractionSubMapper<
        Scalar, Index, Rhs,
        TensorEvaluator<const TensorReshapingOp<
                            NewDimension, const TensorVolumePatchOp<
                                              Planes, Rows, Cols, ArgType> >,
                        Device>,
        nocontract_t, contract_t, /*packet_size*/ 1, inner_dim_contiguous,
        inner_dim_reordered, Alignment>,
    nr, ColMajor, false, false> {
  typedef TensorContractionSubMapper<
      Scalar, Index, Rhs,
      TensorEvaluator<const TensorReshapingOp<
                          NewDimension, const TensorVolumePatchOp<
                                            Planes, Rows, Cols, ArgType> >,
                      Device>,
      nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered,
      Alignment>
      SubMapper;
  typedef SubMapper DataMapper;

  EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);

  EIGEN_DEVICE_FUNC
  EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
                                    Index depth, Index cols, Index stride = 0,
                                    Index offset = 0) const {
    eigen_assert(stride == 0);
    eigen_assert(offset == 0);

    const Index packet_cols4 = (cols / 4) * 4;

    for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
      const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
      const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
      const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);

      if (!rhs.nonStandardPatches()) {
        for (Index k = 0; k < depth; k++) {
          block[0] = dm0.loadCoeffStandard(k);
          block[1] = dm1.loadCoeffStandard(k);
          block[2] = dm2.loadCoeffStandard(k);
          block[3] = dm3.loadCoeffStandard(k);
          block += 4;
        }
      } else {
        for (Index k = 0; k < depth; k++) {
          block[0] = dm0(k);
          block[1] = dm1(k);
          block[2] = dm2(k);
          block[3] = dm3(k);
          block += 4;
        }
      }
    }

    // Copy the remaining columns one at a time (nr==1).
    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2);
      for (Index k = 0; k < depth; k++) {
        *block = dm0(k);
        block += 1;
      }
    }
  }
};

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
// Pack a block of the right input matrix (in our case it's always a "virtual
// matrix" constructed from extracted image patches) into a contiguous block in
// column-major storage order. Knowing the properties of the original patch op,
// we can do this more efficiently than the default gemm_pack_colmajor_block.
//
// TODO(ezhulenev): gemm_pack_colmajor_block for spatial convolutions supports
// squeezing reads along the 2 innermost dimensions; add it here if needed.
template <typename NewDimension, Index Planes, Index Rows, Index Cols,
          typename ArgType, typename Device, typename Scalar,
          typename StorageIndex, typename nocontract_t, typename contract_t,
          int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered,
          int Alignment>
struct gemm_pack_colmajor_block<
    Scalar, StorageIndex,
    TensorContractionSubMapper<
        Scalar, StorageIndex, Rhs,
        TensorEvaluator<const TensorReshapingOp<
                            NewDimension, const TensorVolumePatchOp<
                                              Planes, Rows, Cols, ArgType> >,
                        Device>,
        nocontract_t, contract_t, packet_size, inner_dim_contiguous,
        inner_dim_reordered, Alignment>,
    ColMajor> {
  typedef TensorContractionSubMapper<
      Scalar, StorageIndex, Rhs,
      TensorEvaluator<const TensorReshapingOp<
                          NewDimension, const TensorVolumePatchOp<
                                            Planes, Rows, Cols, ArgType> >,
                      Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      SubMapper;

  typedef SubMapper DataMapper;
  typedef typename packet_traits<Scalar>::type Packet;

  EIGEN_DONT_INLINE
  void operator()(Scalar* block, const DataMapper& rhs, StorageIndex rows,
                  StorageIndex cols) {
    const bool standard_patches = !rhs.nonStandardPatches();

    if (standard_patches && rhs.patchDepth() % packet_size == 0) {
      packStandardPatches<true>(block, rhs, rows, cols);

    } else if (standard_patches) {
      packStandardPatches<false>(block, rhs, rows, cols);

    } else {
      // With non-standard patches we don't do any vectorized loads.
      // TODO(ezhulenev): It doesn't look like we should completely give up on
      // packets. Make this code path faster!
      for (StorageIndex col = 0; col < cols; ++col) {
        SubMapper lm = rhs.getLinearMapper(0, col);
        for (StorageIndex i = 0; i < rows; ++i) {
          *block = lm(i);
          ++block;
        }
      }
    }
  }

 private:
  // Pack standard volume patches:
  //
  // - patch_depth_is_multiple_of_packet_size=true: the depth dimension size is
  //   guaranteed to be a multiple of the packet size, so we can skip all
  //   non-vectorized loads and checks.
  //
  template <bool patch_depth_is_multiple_of_packet_size>
  EIGEN_ALWAYS_INLINE void packStandardPatches(Scalar* block,
                                               const DataMapper& rhs,
                                               StorageIndex rows,
                                               StorageIndex cols) {
    eigen_assert(!rhs.nonStandardPatches());

    // The vectorized part of the rows; we use the name peeled_k to match the
    // gemm_pack_rhs specializations above.
    const Index peeled_k = (rows / packet_size) * packet_size;

    const Index start_col = rhs.colOffset();
    const Index max_col = rhs.maxCol(peeled_k);

    for (StorageIndex col = 0; col < cols; ++col) {
      SubMapper lm = rhs.getLinearMapper(0, col);

      Index k = 0;
      for (Index c = start_col; c < max_col; ++c) {
        eigen_assert(k <= peeled_k);

        const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
        const Index max_row = rhs.maxRow(peeled_k, c);
        const bool pad_col = lm.padCol(c);

        for (Index r = start_row; r < max_row; ++r) {
          eigen_assert(k <= peeled_k);

          const Index start_plane =
              ((c == start_col) && (r == start_row)) ? rhs.planeOffset() : 0;
          const Index max_plane = rhs.maxPlane(peeled_k, c, r);
          const bool pad_row = pad_col || lm.padRow(r);

          for (Index p = start_plane; p < max_plane; ++p) {
            eigen_assert(k <= peeled_k);

            const Index start_depth =
                ((c == start_col) && (r == start_row) && (p == start_plane))
                    ? rhs.depthOffset()
                    : 0;
            const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);

            const bool pad = pad_col || pad_row || lm.padPlane(p);
            const Index base_idx = lm.baseIndex(p, r, c);

            if (patch_depth_is_multiple_of_packet_size)
              eigen_assert((max_depth - start_depth) % packet_size == 0);

            // If patch depth is a multiple of the packet size, it's guaranteed
            // that we can process all values in the depth dimension with packets.
            const Index max_vectorized_depth =
                patch_depth_is_multiple_of_packet_size
                    ? max_depth
                    : max_depth - packet_size;

            Index d = start_depth;

            // 1. Process depth dimension with vectorized instructions.
            for (; d < max_vectorized_depth; d += packet_size) {
              eigen_assert(k < peeled_k);
              const Packet packet = pad ? pset1<Packet>(Scalar(0))
                                        : rhs.packetNoPadding(d, base_idx);
              internal::pstoreu(block, packet);
              block += packet_size;
              k += packet_size;
            }

            // 2. Finish with coefficients.
            if (!patch_depth_is_multiple_of_packet_size) {
              for (; d < max_depth; d++) {
                eigen_assert(k < peeled_k);
                *block = pad ? Scalar(0) : rhs.coeffNoPadding(d, base_idx);
                ++block;
                ++k;
              }
            }
          }
        }
      }

      // The loop above should fill peeled_k elements.
      eigen_assert(peeled_k == k);

      // Fill remaining elements using loadCoeffStandard.
      for (; k < rows; ++k) {
        *block = lm.loadCoeffStandard(k);
        ++block;
      }
    }
  }
};
#endif  // defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)

}  // namespace internal

/** CuboidConvolution
 * \ingroup CXX11_NeuralNetworks_Module
 *
 * \brief Applies a 3D convolution over a multichannel input voxel block.
 *
 * The input parameter is expected to be a tensor with a rank of 4 or more
 * (channels, depth, height, width, and optionally others).
 * The kernel parameter is expected to be a 5D tensor (filters, channels,
 * kernel_depth, kernel_height, kernel_width).
 * The result can be assigned to a tensor of rank equal to the rank of the
 * input. The dimensions of the result will be filters, depth, height, width
 * (and others if applicable).
 *
 * The input and kernel have to be in the same layout, and both row-major and
 * col-major are supported. The shapes given above are for col-major layout.
 * For row-major, all dimensions should be reversed.
 *
 * It is possible to swap the order of the depth, width, and height dimensions
 * provided that the same order is used in the input, the kernel, and the
 * output.
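 *
 * A minimal usage sketch (the tensor sizes below are purely illustrative and
 * assume a col-major, 5D float input laid out as (channels, depth, height,
 * width, batches) with a matching 5D kernel):
 *
 *   Eigen::Tensor<float, 5> input(4, 8, 8, 8, 2);    // C, D, H, W, N
 *   Eigen::Tensor<float, 5> kernel(16, 4, 3, 3, 3);  // F, C, KD, KH, KW
 *   input.setRandom();
 *   kernel.setRandom();
 *   Eigen::Tensor<float, 5> output =
 *       Eigen::CuboidConvolution(input, kernel);  // unit strides, PADDING_SAME
 *   // With PADDING_SAME and unit strides the output shape is (16, 8, 8, 8, 2).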
 */
template <typename Input, typename Kernel>
EIGEN_ALWAYS_INLINE static const typename internal::conditional<
    internal::traits<Input>::Layout == ColMajor,
    TensorReshapingOp<
        const DSizes<typename internal::traits<Input>::Index,
                     internal::traits<Input>::NumDimensions>,
        const TensorContractionOp<
            const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<Input>::Index, 2>,
                const Kernel>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<Input>::Index, 2>,
                const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
                                          const Input> > > >,
    TensorReshapingOp<
        const DSizes<typename internal::traits<Input>::Index,
                     internal::traits<Input>::NumDimensions>,
        const TensorContractionOp<
            const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<Input>::Index, 2>,
                const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
                                          const Input> >,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<Input>::Index, 2>,
                const Kernel> > > >::type
CuboidConvolution(const Input& input, const Kernel& kernel,
                  const Index stridePlanes = 1, const Index strideRows = 1,
                  const Index strideCols = 1,
                  const PaddingType padding_type = PADDING_SAME) {
  typedef typename internal::traits<Input>::Index TensorIndex;
  TensorRef<Tensor<typename internal::traits<Input>::Scalar,
                   internal::traits<Input>::NumDimensions,
                   internal::traits<Input>::Layout, TensorIndex> >
      in(input);
  TensorRef<Tensor<typename internal::traits<Kernel>::Scalar,
                   internal::traits<Kernel>::NumDimensions,
                   internal::traits<Kernel>::Layout, TensorIndex> >
      kern(kernel);

  EIGEN_STATIC_ASSERT(
      internal::traits<Input>::Layout == internal::traits<Kernel>::Layout,
      YOU_MADE_A_PROGRAMMING_MISTAKE);
  static const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);
  static const int NumDims = internal::traits<Input>::NumDimensions;

  // Number of filters to apply. This is the same as the output depth of the
  // result.
  const TensorIndex kernelFilters =
      isColMajor ? kern.dimensions()[0] : kern.dimensions()[4];
  const TensorIndex kernelChannels =
      isColMajor ? kern.dimensions()[1] : kern.dimensions()[3];

  // Spatial size of the kernel. Note that the planes dimension is the middle
  // kernel dimension, so its index is the same in both layouts.
  const TensorIndex kernelPlanes =
      isColMajor ? kern.dimensions()[2] : kern.dimensions()[2];
  const TensorIndex kernelRows =
      isColMajor ? kern.dimensions()[3] : kern.dimensions()[1];
  const TensorIndex kernelCols =
      isColMajor ? kern.dimensions()[4] : kern.dimensions()[0];

  if (isColMajor) {
    eigen_assert(kernelChannels == in.dimension(0));
  } else {
    eigen_assert(kernelChannels == in.dimension(NumDims - 1));
  }

  const TensorIndex inputPlanes =
      isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
  const TensorIndex inputRows =
      isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);
  const TensorIndex inputCols =
      isColMajor ? in.dimension(3) : in.dimension(NumDims - 4);

  TensorIndex out_planes;
  TensorIndex out_height;
  TensorIndex out_width;
  switch (padding_type) {
    case PADDING_VALID:
      out_planes = Eigen::divup(inputPlanes - kernelPlanes + 1,
                                static_cast<TensorIndex>(stridePlanes));
      out_height = Eigen::divup(inputRows - kernelRows + 1,
                                static_cast<TensorIndex>(strideRows));
      out_width = Eigen::divup(inputCols - kernelCols + 1,
                               static_cast<TensorIndex>(strideCols));
      break;
    case PADDING_SAME:
      out_planes =
          Eigen::divup(inputPlanes, static_cast<TensorIndex>(stridePlanes));
      out_height =
          Eigen::divup(inputRows, static_cast<TensorIndex>(strideRows));
      out_width = Eigen::divup(inputCols, static_cast<TensorIndex>(strideCols));
      break;
    default:
      out_planes = 0;
      out_height = 0;
      out_width = 0;
      eigen_assert(false && "unexpected padding");
  }
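  // For example (illustrative numbers only): with PADDING_VALID, inputRows = 7,
  // kernelRows = 3 and strideRows = 2, the formula above gives
  // out_height = divup(7 - 3 + 1, 2) = divup(5, 2) = 3.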

  DSizes<TensorIndex, 2> kernel_dims;
  if (isColMajor) {
    kernel_dims[0] = kernelFilters;
    kernel_dims[1] = kernelChannels * kernelPlanes * kernelRows * kernelCols;
  } else {
    kernel_dims[0] = kernelChannels * kernelPlanes * kernelRows * kernelCols;
    kernel_dims[1] = kernelFilters;
  }

  // Molds the output of the patch extraction into a 2D tensor:
  // - the first dimension (dims[0]): the patch values to be multiplied with the
  //   kernels
  // - the second dimension (dims[1]): everything else
  DSizes<TensorIndex, 2> pre_contract_dims;
  if (isColMajor) {
    pre_contract_dims[0] =
        kernelChannels * kernelPlanes * kernelRows * kernelCols;
    pre_contract_dims[1] = out_planes * out_height * out_width;
    for (int i = 4; i < NumDims; ++i) {
      pre_contract_dims[1] *= in.dimension(i);
    }
  } else {
    pre_contract_dims[1] =
        kernelChannels * kernelPlanes * kernelRows * kernelCols;
    pre_contract_dims[0] = out_planes * out_height * out_width;
    for (int i = 0; i < NumDims - 4; ++i) {
      pre_contract_dims[0] *= in.dimension(i);
    }
  }
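  // For example (illustrative sizes only): a col-major input with 4 channels,
  // a 3x3x3 kernel and an 8x8x8 output over a batch of 2 yields
  // pre_contract_dims = (4 * 3 * 3 * 3, 8 * 8 * 8 * 2) = (108, 1024).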

  array<IndexPair<TensorIndex>, 1> contract_dims;
  contract_dims[0] = IndexPair<TensorIndex>(1, 0);

  // Molds the output of the contraction into the shape expected by the user
  // (assuming ColMajor):
  // - 1st dim: kernel filters
  // - 2nd dim: output depth (planes)
  // - 3rd dim: output height
  // - 4th dim: output width
  // - 5th dim and beyond: everything else including batch size
  DSizes<TensorIndex, NumDims> post_contract_dims;
  if (isColMajor) {
    post_contract_dims[0] = kernelFilters;
    post_contract_dims[1] = out_planes;
    post_contract_dims[2] = out_height;
    post_contract_dims[3] = out_width;
    for (int i = 4; i < NumDims; ++i) {
      post_contract_dims[i] = in.dimension(i);
    }
  } else {
    post_contract_dims[NumDims - 1] = kernelFilters;
    post_contract_dims[NumDims - 2] = out_planes;
    post_contract_dims[NumDims - 3] = out_height;
    post_contract_dims[NumDims - 4] = out_width;
    for (int i = 0; i < NumDims - 4; ++i) {
      post_contract_dims[i] = in.dimension(i);
    }
  }

  return choose(
      Cond<internal::traits<Input>::Layout == ColMajor>(),
      kernel.reshape(kernel_dims)
          .contract(input
                        .extract_volume_patches(
                            kernelPlanes, kernelRows, kernelCols, stridePlanes,
                            strideRows, strideCols, padding_type)
                        .reshape(pre_contract_dims),
                    contract_dims)
          .reshape(post_contract_dims),
      input
          .extract_volume_patches(kernelPlanes, kernelRows, kernelCols,
                                  stridePlanes, strideRows, strideCols,
                                  padding_type)
          .reshape(pre_contract_dims)
          .contract(kernel.reshape(kernel_dims), contract_dims)
          .reshape(post_contract_dims));
}

}  // end namespace Eigen

#endif  // TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_