/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_INL_H_
#define TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_INL_H_

// Note this header is used in both TF and TFLite.
namespace Eigen {

namespace internal {

// WARNING: Most of the code here implicitly assumes that the matrix is in
// ColMajor layout. This is guaranteed by the tensor contraction (see
// TensorContraction.h).
//
// Inside Eigen a tensor contraction is represented by a matrix multiplication.
// We don't want to actually extract image patches and reshape the result into
// a matrix (this involves allocating huge extra memory), so the patch
// extraction and reshape operations are implicit.
//
// TensorContractionInputMapper takes a matrix index and returns the coefficient
// (or the packet) of the "virtual tensor", that would be at that index if we
// were to actually reshape the result of patch extraction.
//
// TensorContractionSubMapper provides a similar view into the "virtual matrix"
// at the given vertical and horizontal offsets.
//
// "Virtual matrix" dimensions:
//   *0: kernelChannels * kernelRows * kernelCols;
//    1: out_height * out_width; * OTHERS (e.g. batches, etc...)
//
// *) extracted patches are contiguous in memory (innermost dimension assuming
//    col major layout)
//
// With these dimensions:
//   row - offset within a single patch (in code: patchId)
//   col - index of the extracted patch (in code: patchIndex)
//         patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions)
//
// TODO(ezhulenev): Consolidate this part of the code with the image patch
// extraction code since they are both very similar.
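//
// For example (illustrative numbers only): with patch_depth = 3 and
// patch_rows = 2, the "virtual matrix" row patchId = 7 decomposes as
//   patchOffset = patchId / patch_depth          = 2
//   colOffset   = patchOffset / patch_rows       = 1
//   rowOffset   = patchOffset - colOffset * 2    = 0
//   depth       = patchId - patchOffset * 3      = 1
// i.e. the element at depth 1, row 0, col 1 of the patch; this is exactly the
// arithmetic performed below by loadCoeff*/loadPacket* (with TensorIntDivisor
// used for the divisions).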

template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
          typename Device, typename Scalar_, typename Index,
          typename nocontract_t, typename contract_t, int Side, int packet_size,
          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionInputMapper<
    Scalar_, Index, Side,
    TensorEvaluator<
        const TensorReshapingOp<NewDimension,
                                const TensorImagePatchOp<Rows, Cols, ArgType> >,
        Device>,
    nocontract_t, contract_t, packet_size, inner_dim_contiguous,
    inner_dim_reordered, Alignment> {
 public:
  typedef Scalar_ Scalar;

  typedef TensorContractionInputMapper<
      Scalar, Index, Side,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      Self;

  typedef TensorContractionSubMapper<
      Scalar, Index, Side,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      SubMapper;

  typedef SubMapper VectorMapper;
  typedef SubMapper LinearMapper;
  typedef typename packet_traits<Scalar>::type Packet;

  EIGEN_DEVICE_FUNC
  TensorContractionInputMapper(
      const TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>& tensor,
      const nocontract_t&, const nocontract_t&, const contract_t&,
      const contract_t&)
      : m_impl(tensor.impl().impl()) {
    Index patch_rows;
    Index patch_depth;
    if (internal::traits<ArgType>::Layout == ColMajor) {
      patch_depth = tensor.impl().dimensions()[0];
      patch_rows = tensor.impl().dimensions()[1];
      m_patch_cols = tensor.impl().dimensions()[2];
      m_num_patches = tensor.impl().dimensions()[3];
    } else {
      const size_t NumDims = tensor.impl().dimensions().size();
      patch_depth = tensor.impl().dimensions()[NumDims - 1];
      patch_rows = tensor.impl().dimensions()[NumDims - 2];
      m_patch_cols = tensor.impl().dimensions()[NumDims - 3];
      m_num_patches = tensor.impl().dimensions()[NumDims - 4];
    }

    // Strides for navigating through the single patch.
    m_patch_row_stride = patch_depth;
    m_patch_col_stride = patch_rows * m_patch_row_stride;

    m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
    m_patch_col_inflate_strides = tensor.impl().colInflateStride();

    m_colStride = patch_rows;

    m_outputRows = tensor.impl().outputRows();
    m_row_strides = tensor.impl().userRowStride();
    m_col_strides = tensor.impl().userColStride();

    m_in_row_strides = tensor.impl().userInRowStride();
    m_in_col_strides = tensor.impl().userInColStride();

    if (internal::traits<ArgType>::Layout == ColMajor) {
      m_inputRows = tensor.impl().impl().dimensions()[1];
      m_inputCols = tensor.impl().impl().dimensions()[2];
    } else {
      const int NumDims = tensor.impl().impl().dimensions().size();
      m_inputRows = tensor.impl().impl().dimensions()[NumDims - 2];
      m_inputCols = tensor.impl().impl().dimensions()[NumDims - 3];
    }

    m_rowInputStride = patch_depth;
    m_colInputStride = patch_depth * m_inputRows;
    m_patchInputStride = patch_depth * m_inputRows * m_inputCols;

    m_rowPaddingTop = tensor.impl().rowPaddingTop();
    m_colPaddingLeft = tensor.impl().colPaddingLeft();

    m_fastPatchRowStride =
        internal::TensorIntDivisor<Index>(m_patch_row_stride);
    m_fastPatchColStride =
        internal::TensorIntDivisor<Index>(m_patch_col_stride);
    m_fastInputRowStride =
        internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
    m_fastInputColStride =
        internal::TensorIntDivisor<Index>(m_patch_col_inflate_strides);
    m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches);
    m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);
    m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
    m_fastDimZero = internal::TensorIntDivisor<Index>(patch_depth);
  }

  EIGEN_DEVICE_FUNC
  TensorContractionInputMapper(const TensorContractionInputMapper& base_mapper)
      : m_impl(base_mapper.m_impl) {
    m_patch_cols = base_mapper.m_patch_cols;
    m_num_patches = base_mapper.m_num_patches;

    m_patch_row_stride = base_mapper.m_patch_row_stride;
    m_patch_col_stride = base_mapper.m_patch_col_stride;

    m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
    m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;

    m_colStride = base_mapper.m_colStride;

    m_rowInputStride = base_mapper.m_rowInputStride;
    m_colInputStride = base_mapper.m_colInputStride;
    m_patchInputStride = base_mapper.m_patchInputStride;

    m_inputRows = base_mapper.m_inputRows;
    m_inputCols = base_mapper.m_inputCols;

    m_outputRows = base_mapper.m_outputRows;
    m_row_strides = base_mapper.m_row_strides;
    m_col_strides = base_mapper.m_col_strides;

    m_in_row_strides = base_mapper.m_in_row_strides;
    m_in_col_strides = base_mapper.m_in_col_strides;

    m_rowPaddingTop = base_mapper.m_rowPaddingTop;
    m_colPaddingLeft = base_mapper.m_colPaddingLeft;

    m_fastPatchRowStride = base_mapper.m_fastPatchRowStride;
    m_fastPatchColStride = base_mapper.m_fastPatchColStride;
    m_fastInputRowStride = base_mapper.m_fastInputRowStride;
    m_fastInputColStride = base_mapper.m_fastInputColStride;
    m_fastNumPatches = base_mapper.m_fastNumPatches;
    m_fastColStride = base_mapper.m_fastColStride;
    m_fastOutputRows = base_mapper.m_fastOutputRows;
    m_fastDimZero = base_mapper.m_fastDimZero;
  }

  // If true, turns off some optimizations for loading packets since the image
  // patches are "non-standard": e.g. there are non-trivial strides or
  // inflations in the input.
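  // For example (illustrative): non-unit user "in" strides show up for dilated
  // patch extraction, and non-unit inflate strides show up when the input was
  // expanded with zeros; in either case the packet loads below fall back to
  // packetWithPossibleZero().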
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
    return m_in_row_strides != 1 || m_in_col_strides != 1 ||
           m_patch_row_inflate_strides != 1 || m_patch_col_inflate_strides != 1;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
    return SubMapper(*this, i, j);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
    return LinearMapper(*this, i, j);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const {
    Index rowIndex, colIndex, otherIndex;
    computeBaseIndices(0, rowIndex, colIndex, otherIndex);
    return loadCoeff(row, rowIndex, colIndex, otherIndex);
  }

  // Load the coefficient at the patchIndex location instead of the usual
  // m_rowIndex, m_colIndex, m_otherIndex. This is currently only used by the
  // gpu code.
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const {
    Index rowIndex, colIndex, otherIndex;
    computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex);
    return loadCoeff(row, rowIndex, colIndex, otherIndex);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const {
    Index rowIndex, colIndex, otherIndex;
    computeBaseIndices(0, rowIndex, colIndex, otherIndex);
    return loadPacket(row, rowIndex, colIndex, otherIndex);
  }

  // Load the packet at the patchIndex location instead of the usual m_rowIndex,
  // m_colIndex, m_otherIndex. This is currently only used by the gpu code.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const {
    Index rowIndex, colIndex, otherIndex;
    computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex);
    return loadPacket(row, rowIndex, colIndex, otherIndex);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE const TensorEvaluator<ArgType, Device>& impl() const {
    return m_impl;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_rowInputStride; }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchRows() const { return m_colStride; }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; }

 private:
  friend class TensorContractionSubMapper<
      Scalar, Index, Side,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>;

  // Load coefficient from a patch specified by the "within patch offset"
  // (patchId) and the precomputed indices of the first element of the patch.
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index rowIndex,
                                       Index colIndex, Index otherIndex) const {
    // Find the offset of the element wrt the location of the first element.
    const Index patchOffset = patchId / m_fastDimZero;

    const Index colOffset = patchOffset / m_fastColStride;
    const Index inputCol = colIndex + colOffset * m_in_col_strides;
    const Index origInputCol =
        (m_patch_col_inflate_strides == 1)
            ? inputCol
            : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);

    const Index rowOffset = patchOffset - colOffset * m_colStride;
    const Index inputRow = rowIndex + rowOffset * m_in_row_strides;
    const Index origInputRow =
        (m_patch_row_inflate_strides == 1)
            ? inputRow
            : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
    if (origInputCol < 0 || origInputRow < 0 || origInputCol >= m_inputCols ||
        origInputRow >= m_inputRows ||
        (inputCol != origInputCol * m_patch_col_inflate_strides) ||
        (inputRow != origInputRow * m_patch_row_inflate_strides)) {
      return Scalar(0);
    }
    const Index depth = patchId - patchOffset * patchDepth();
    const Index inputIndex = depth + origInputRow * m_rowInputStride +
                             origInputCol * m_colInputStride + otherIndex;
    return m_impl.coeff(inputIndex);
  }

  // This is the same as loadCoeff(...), but optimized for all `inflate_strides`
  // and `in_strides` equal to 1 (template specialization without templates).
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index rowIndex,
                                               Index colIndex,
                                               Index otherIndex) const {
    eigen_assert(!nonStandardPatches());

    // Find the offset of the element wrt the location of the first element.
    const Index patchOffset = patchId / m_fastDimZero;
    const Index colOffset = patchOffset / m_fastColStride;
    const Index rowOffset = patchOffset - colOffset * m_colStride;
    const Index inputCol = colIndex + colOffset;
    const Index inputRow = rowIndex + rowOffset;
    if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 ||
        inputRow >= m_inputRows) {
      return Scalar(0);
    }
    const Index depth = patchId - patchOffset * patchDepth();
    const Index inputIndex = depth + inputRow * m_rowInputStride +
                             inputCol * m_colInputStride + otherIndex;
    return m_impl.coeff(inputIndex);
  }

  // Load packet from a patch specified by the "within patch offset"
  // (patchId) and the precomputed indices of the first element of the patch.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index rowIndex,
                                        Index colIndex,
                                        Index otherIndex) const {
    const Index packetSize = internal::unpacket_traits<Packet>::size;
    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);

    if (nonStandardPatches()) {
      return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
    }
    return loadPacketStandard(patchId, rowIndex, colIndex, otherIndex);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacketStandard(Index patchId, Index rowIndex,
                                                Index colIndex,
                                                Index otherIndex) const {
    const Index packetSize = internal::unpacket_traits<Packet>::size;
    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);

    eigen_assert(!nonStandardPatches());

    if ((patchDepth() % packetSize) == 0) {
      return loadPacketFast(patchId, rowIndex, colIndex, otherIndex);
    } else {
      // Offsets and input calculation here are identical to
      // loadCoeffStandard(...), but repeated twice.

      const Index patchOffsets[2] = {
          patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero};

      const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
                                   patchOffsets[1] / m_fastColStride};
      const Index inputCols[2] = {colIndex + colOffsets[0],
                                  colIndex + colOffsets[1]};
      if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
        // all zeros
        return internal::pset1<Packet>(Scalar(0));
      }

      if (inputCols[0] == inputCols[1]) {
        const Index rowOffsets[2] = {
            patchOffsets[0] - colOffsets[0] * m_colStride,
            patchOffsets[1] - colOffsets[1] * m_colStride};
        eigen_assert(rowOffsets[0] <= rowOffsets[1]);
        const Index inputRows[2] = {rowIndex + rowOffsets[0],
                                    rowIndex + rowOffsets[1]};

        if (inputRows[0] >= m_inputRows || inputRows[1] < 0) {
          // all zeros
          return internal::pset1<Packet>(Scalar(0));
        }

        if (inputRows[0] >= 0 && inputRows[1] < m_inputRows) {
          // no padding
          const Index depth = patchId - patchOffsets[0] * patchDepth();
          const Index inputIndex = depth + inputRows[0] * m_rowInputStride +
                                   inputCols[0] * m_colInputStride + otherIndex;
          return m_impl.template packet<Unaligned>(inputIndex);
        }
      }
    }
    return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
  }

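  // Fast path: when patchDepth() is a multiple of the packet size a packet can
  // never straddle two rows/columns of the patch, so a single bounds check is
  // enough for the whole packet. E.g. (illustrative) patch_depth = 16 and an
  // 8-float packet: every packet stays within one (row, col) of the patch.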
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index rowIndex,
                                            Index colIndex,
                                            Index otherIndex) const {
    const Index packetSize = internal::unpacket_traits<Packet>::size;
    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);

    eigen_assert(!nonStandardPatches());
    eigen_assert((patchDepth() % packetSize) == 0);
    // Find the offset of the element wrt the location of the first element.
    const Index patchOffset = patchId / m_fastDimZero;
    eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);

    const Index colOffset = patchOffset / m_fastColStride;
    const Index rowOffset = patchOffset - colOffset * m_colStride;
    const Index inputCol = colIndex + colOffset;
    const Index inputRow = rowIndex + rowOffset;
    if (inputCol < 0 || inputRow < 0 || inputCol >= m_inputCols ||
        inputRow >= m_inputRows) {
      // all zeros
      return internal::pset1<Packet>(Scalar(0));
    }
    // no padding
    const Index depth = patchId - patchOffset * patchDepth();
    const Index inputIndex = depth + inputRow * m_rowInputStride +
                             inputCol * m_colInputStride + otherIndex;
    return m_impl.template packet<Unaligned>(inputIndex);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet packetWithPossibleZero(
      Index patchId, Index rowIndex, Index colIndex, Index otherIndex) const {
    const int packetSize = internal::unpacket_traits<Packet>::size;
    EIGEN_ALIGN_MAX
    typename internal::remove_const<Scalar>::type values[packetSize];
    for (int i = 0; i < packetSize; ++i) {
      values[i] = loadCoeff(patchId + i, rowIndex, colIndex, otherIndex);
    }
    Packet rslt = internal::pload<Packet>(values);
    return rslt;
  }

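  // Maps a column of the "virtual matrix" (patchIndex) to the top-left corner
  // of the corresponding patch in the input. Roughly (illustrative):
  //   patchIndex   = otherIndex * num_patches + patch2DIndex
  //   patch2DIndex = colIndex * outputRows + rowIndex
  // followed by applying the user strides and subtracting the padding.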
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices(
      Index patchIndex, Index& rowIndex, Index& colIndex,
      Index& otherIndex) const {
    const size_t NumInputDims = array_size<
        typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
    otherIndex = (NumInputDims == 3) ? 0 : patchIndex / m_fastNumPatches;
    const Index patch2DIndex = (NumInputDims == 3)
                                   ? patchIndex
                                   : (patchIndex - otherIndex * m_num_patches);
    otherIndex *= m_patchInputStride;
    colIndex = patch2DIndex / m_fastOutputRows;
    rowIndex = patch2DIndex - colIndex * m_outputRows;
    colIndex = colIndex * m_col_strides - m_colPaddingLeft;
    rowIndex = rowIndex * m_row_strides - m_rowPaddingTop;
  }

  Index m_patch_cols;   // number of columns in the patch
  Index m_num_patches;  // number of patches to extract.

  // Strides for navigating through the single patch.
  Index m_patch_row_stride;
  Index m_patch_col_stride;
  internal::TensorIntDivisor<Index> m_fastPatchRowStride;
  internal::TensorIntDivisor<Index> m_fastPatchColStride;

  Index m_patch_row_inflate_strides;  // the strides for row inflation in the
                                      // image patch
  Index m_patch_col_inflate_strides;  // the strides for col inflation in the
                                      // image patch
  // Fast representation of inflation strides.
  internal::TensorIntDivisor<Index> m_fastInputRowStride;
  internal::TensorIntDivisor<Index> m_fastInputColStride;

  Index m_otherStride;
  Index m_colStride;
  internal::TensorIntDivisor<Index> m_fastNumPatches;
  internal::TensorIntDivisor<Index> m_fastColStride;

  Index m_rowInputStride;    // row stride in the input tensor
  Index m_colInputStride;    // col stride in the input tensor
  Index m_patchInputStride;  // patch stride in the input tensor

  Index m_inputRows;  // Number of rows in the input tensor
  Index m_inputCols;  // Number of cols in the input tensor

  Index m_outputRows;  // Number of patch rows

  Index m_row_strides;  // User specified row stride
  Index m_col_strides;  // User specified col stride

  Index m_in_row_strides;  // User specified input row stride
  Index m_in_col_strides;  // User specified input col stride

  Index m_rowPaddingTop;   // Row padding
  Index m_colPaddingLeft;  // Column padding

  internal::TensorIntDivisor<Index> m_fastOutputRows;
  internal::TensorIntDivisor<Index> m_fastDimZero;

  const TensorEvaluator<ArgType, Device> m_impl;
};

template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
          typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, int Side, int packet_size,
          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionSubMapper<
    Scalar, Index, Side,
    TensorEvaluator<
        const TensorReshapingOp<NewDimension,
                                const TensorImagePatchOp<Rows, Cols, ArgType> >,
        Device>,
    nocontract_t, contract_t, packet_size, inner_dim_contiguous,
    inner_dim_reordered, Alignment> {
 public:
  typedef typename packet_traits<Scalar>::type Packet;
  typedef typename packet_traits<Scalar>::half HalfPacket;

  typedef TensorContractionInputMapper<
      Scalar, Index, Side,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      ParentMapper;

  typedef TensorContractionSubMapper<
      Scalar, Index, Side,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      Self;

  typedef Self LinearMapper;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
      const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
      : m_depth_offset(vert_offset),
        m_col_offset(horiz_offset),
        m_base_mapper(base_mapper) {
    m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex,
                                     m_otherIndex);
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
      const Self& base_mapper, Index vert_offset, Index horiz_offset)
      : m_depth_offset(vert_offset + base_mapper.m_depth_offset),
        m_col_offset(horiz_offset + base_mapper.m_col_offset),
        m_base_mapper(base_mapper.m_base_mapper) {
    m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex,
                                     m_otherIndex);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
    return m_base_mapper.loadCoeff(i + m_depth_offset, m_rowIndex, m_colIndex,
                                   m_otherIndex);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i,
                                                          Index j) const {
    return m_base_mapper(i + m_depth_offset, j + m_col_offset);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
    return m_base_mapper.loadPacket(i + m_depth_offset, m_rowIndex, m_colIndex,
                                    m_otherIndex);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i,
                                                          Index j) const {
    return m_base_mapper.template loadPacket<Alignment>(i + m_depth_offset,
                                                        j + m_col_offset);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar
  loadCoeffStandard(Index i) const {
    return m_base_mapper.loadCoeffStandard(i + m_depth_offset, m_rowIndex,
                                           m_colIndex, m_otherIndex);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const {
    return m_base_mapper.loadPacketFast(i + m_depth_offset, m_rowIndex,
                                        m_colIndex, m_otherIndex);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet
  loadPacketStandard(Index i) const {
    return m_base_mapper.loadPacketStandard(i + m_depth_offset, m_rowIndex,
                                            m_colIndex, m_otherIndex);
  }
  template <typename Packet>
  EIGEN_DEVICE_FUNC bool aligned(Index) const {
    return false;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
    return m_base_mapper.nonStandardPatches();
  }

  // Max(Col|Row|Depth): compute the upper limit for the column, row and depth
  // index respectively that fits into the peeled_k elements starting at
  // m_depth_offset.
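  //
  // For example (illustrative): patch_depth = 4, patch_rows = 3,
  // m_depth_offset = 0, peeled_k = 8. The last peeled element has linear
  // offset 7, so maxCol = 1 + 7 / (4 * 3) = 1 and maxRow(.., col = 0) =
  // 1 + 7 / 4 = 2, both clamped to patchCols() / patchRows().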

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const {
    const Index max_col =
        (m_depth_offset + (peeled_k == 0 ? 0 : peeled_k - 1)) /
        fastPatchColStride();
    return std::min<Index>(1 + max_col, patchCols());
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k,
                                   const Index col) const {
    const Index max_row = (m_depth_offset + (peeled_k == 0 ? 0 : peeled_k - 1) -
                           col * patchColStride()) /
                          fastPatchRowStride();
    return std::min<Index>(1 + max_row, patchRows());
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index maxDepth(const Index peeled_k, const Index col,
                                     Index row) const {
    const Index max_depth = m_depth_offset + peeled_k -  //
                            col * patchColStride() -     //
                            row * patchRowStride();
    return std::min<Index>(max_depth, patchDepth());
  }

  // MaxDepth uses only the remaining number of elements in the peeled_k.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements,
                                     const Index start_depth) const {
    return std::min<Index>(start_depth + num_elements, patchDepth());
  }

  // Every register matters in this code, so sometimes to prevent register
  // spilling, instead of the variable that you would expect to see, we use
  // another one that is guaranteed to have the same value. E.g. patch depth is
  // always the same as input depth, and it's also the same as input row stride.
  // A bunch of other parameters have similar relations.

  typedef internal::TensorIntDivisor<Index> IndexDivisor;

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchDepth() const {
    return m_base_mapper.m_rowInputStride;
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchRows() const {
    return m_base_mapper.m_colStride;
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchCols() const {
    return m_base_mapper.m_patch_cols;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchRowStride() const {
    eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
                 "Patch depth must be equal to patch row stride.");
    return patchDepth();
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchColStride() const {
    return m_base_mapper.m_patch_col_stride;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const {
    eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
                 "Patch depth must be equal to patch row stride.");
    return m_base_mapper.m_fastDimZero;  // patch_depth
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const {
    return m_base_mapper.m_fastPatchColStride;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
                                             const Index baseIndex) const {
    const Index inputIndex = depth + baseIndex;
    return m_base_mapper.m_impl.template packet<Unaligned>(inputIndex);
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Scalar coeffNoPadding(const Index depth,
                                            const Index baseIndex) const {
    const Index inputIndex = depth + baseIndex;
    return m_base_mapper.m_impl.coeff(inputIndex);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool padRow(const Index row) const {
    const Index r = m_rowIndex + row;
    return r < 0 || r >= m_base_mapper.m_inputRows;
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool padAnyRow(const Index first_row,
                                     const Index last_row) const {
    return m_rowIndex + first_row < 0 ||
           m_rowIndex + last_row >= m_base_mapper.m_inputRows;
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool padCol(const Index col) const {
    const Index c = m_colIndex + col;
    return c < 0 || c >= m_base_mapper.m_inputCols;
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index baseIndex(const Index row, const Index col) const {
    const Index r = m_rowIndex + row;
    const Index c = m_colIndex + col;
    return r * m_base_mapper.m_rowInputStride +
           c * m_base_mapper.m_colInputStride + m_otherIndex;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index rowStride() const {
    return m_base_mapper.m_row_strides;
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index colStride() const {
    return m_base_mapper.m_col_strides;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index rowOffset() const {
    const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
    const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
    return patchOffset - colOffset * m_base_mapper.m_colStride;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index colOffset() const {
    const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
    const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
    return colOffset;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index depthOffset() const {
    return m_depth_offset % patchDepth();
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper
  getLinearMapper(Index i, Index j) const {
    return LinearMapper(m_base_mapper, i + m_depth_offset, j + m_col_offset);
  }

 private:
  Index m_depth_offset;  // First row in the input matrix
  Index m_col_offset;    // First col in the input matrix

  // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base
  // indices for the first element in a patch specified by col_offset
  // (see computeBaseIndices(...) for details).
  Index m_rowIndex;
  Index m_colIndex;
  Index m_otherIndex;

  const ParentMapper m_base_mapper;  // Keeping a copy instead of a reference
                                     // performs better in benchmarks.
};

// Arrange a block of the right input matrix (in our case it's always a "virtual
// matrix" constructed from extracted image patches) in contiguous memory.
//
// Given column major input (A0 beside A1 in memory):
// A0 B0 C0 D0  E0 F0 G0 H0 ... Z0
// A1 B1 C1 D1  E1 F1 G1 H1 ... Z1
// A2 B2 C2 D2  E2 F2 G2 H2 ... Z2
// A3 B3 C3 D3  E3 F3 G3 H3 ... Z3
// A4 B4 C4 D4  E4 F4 G4 H4 ... Z4
// A5 B5 C5 D5  E5 F5 G5 H5 ... Z5
// A6 B6 C6 D6  E6 F6 G6 H6 ... Z6
// A7 B7 C7 D7  E7 F7 G7 H7 ... Z7
// A8 ...
// ...
//
// *) A, B, C, ... - patches extracted from the original input.
// *) A0, A1, A2 ... - values from the same patch at different offsets.
//
// The traversal (packed rhs memory) order (B0 beside A0 in memory):
// A0 B0 C0 D0 A1 B1 C1 D1 ...
// E0 F0 G0 H0 E1 F1 G1 H1 ...
// ...
// Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4)
//
// This traversal order must be the same as in default gemm_pack_rhs defined in
// GeneralBlockPanelKernel.h.
//
// *) nr - number of registers along the 'n' dimension.
//    See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix
//    Multiplication" paper.
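//
// As a rough illustration (assuming float with an 8-wide packet), each step of
// the packing loop below loads one packet from each of the four columns of a
// 4-column block, transposes them with ptranspose, and appends the resulting
// 4 * packet_size scalars to `block` in the interleaved order shown above.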
template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
          typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, int packet_size,
          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
          int nr>
struct gemm_pack_rhs<
    Scalar, Index,
    TensorContractionSubMapper<
        Scalar, Index, Rhs,
        TensorEvaluator<
            const TensorReshapingOp<
                NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
            Device>,
        nocontract_t, contract_t, packet_size, inner_dim_contiguous,
        inner_dim_reordered, Alignment>,
    nr, ColMajor, false, false> {
  typedef TensorContractionSubMapper<
      Scalar, Index, Rhs,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      SubMapper;
  typedef SubMapper DataMapper;
  typedef typename packet_traits<Scalar>::type Packet;

  EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE)

  EIGEN_DEVICE_FUNC
  EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
                                    Index depth, Index cols, Index stride = 0,
                                    Index offset = 0) const {
    eigen_assert(stride == 0);
    eigen_assert(offset == 0);

    const Index packet_cols4 = (cols / 4) * 4;
    const Index peeled_k = (depth / packet_size) * packet_size;
    const bool non_standard_patches = rhs.nonStandardPatches();

    for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
      const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
      const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
      const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);

      Index k = 0;
      if ((packet_size % 4) == 0 && !non_standard_patches) {
        // FAST PATH:
        // Iterate over patch columns and rows, if we know that a single
        // packet does not span across multiple rows or columns.
        if ((rhs.patchDepth() % packet_size) == 0) {
          const Index start_col = rhs.colOffset();
          const Index max_col = rhs.maxCol(peeled_k);

          for (Index c = start_col; c < max_col; ++c) {
            eigen_assert(k <= peeled_k);

            const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
            const Index max_row = rhs.maxRow(peeled_k, c);

            const bool pad_col0 = dm0.padCol(c);
            const bool pad_col1 = dm1.padCol(c);
            const bool pad_col2 = dm2.padCol(c);
            const bool pad_col3 = dm3.padCol(c);

            // Check if we can squeeze reads along the `row` and `depth`
            // dimensions (two innermost dimensions).
            if (!pad_col0 && !pad_col1 && !pad_col2 && !pad_col3 &&    //
                !dm0.padRow(start_row) && !dm0.padRow(max_row - 1) &&  //
                !dm1.padRow(start_row) && !dm1.padRow(max_row - 1) &&  //
                !dm2.padRow(start_row) && !dm2.padRow(max_row - 1) &&  //
                !dm3.padRow(start_row) && !dm3.padRow(max_row - 1)) {
              // Compute how many elements we can squeeze read.
              const Index start_depth =
                  (c == start_col) ? rhs.depthOffset() : 0;

              // Upper bound for the number of elements in the depth dimension
              // that we can squeeze read.
              const Index squeeze_length =
                  (max_row - start_row) * rhs.patchDepth() - start_depth;

              // Do not overshoot beyond the block size.
              const Index max_depth =
                  start_depth + std::min<Index>(peeled_k - k, squeeze_length);
              eigen_assert((max_depth - start_depth) % packet_size == 0);

              const Index idx0 = dm0.baseIndex(start_row, c);
              const Index idx1 = dm1.baseIndex(start_row, c);
              const Index idx2 = dm2.baseIndex(start_row, c);
              const Index idx3 = dm3.baseIndex(start_row, c);

              for (Index d = start_depth; d < max_depth; d += packet_size) {
                eigen_assert(k < peeled_k);
                PacketBlock<Packet, 4> kernel;
                kernel.packet[0] = rhs.packetNoPadding(d, idx0);
                kernel.packet[1] = rhs.packetNoPadding(d, idx1);
                kernel.packet[2] = rhs.packetNoPadding(d, idx2);
                kernel.packet[3] = rhs.packetNoPadding(d, idx3);
                ptranspose(kernel);
                pstoreu(block + 0 * packet_size, kernel.packet[0]);
                pstoreu(block + 1 * packet_size, kernel.packet[1]);
                pstoreu(block + 2 * packet_size, kernel.packet[2]);
                pstoreu(block + 3 * packet_size, kernel.packet[3]);
                block += 4 * packet_size;
                k += packet_size;
              }

              // Go to the next column.
              continue;
            }

            // If we can't squeeze reads, process rows one by one.
            for (Index r = start_row; r < max_row; ++r) {
              eigen_assert(k <= peeled_k);

              const bool pad0 = pad_col0 || dm0.padRow(r);
              const bool pad1 = pad_col1 || dm1.padRow(r);
              const bool pad2 = pad_col2 || dm2.padRow(r);
              const bool pad3 = pad_col3 || dm3.padRow(r);

              const Index idx0 = dm0.baseIndex(r, c);
              const Index idx1 = dm1.baseIndex(r, c);
              const Index idx2 = dm2.baseIndex(r, c);
              const Index idx3 = dm3.baseIndex(r, c);

              const Index start_depth = ((c == start_col) && (r == start_row))
                                            ? rhs.depthOffset()
                                            : 0;
              const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
              eigen_assert((max_depth - start_depth) % packet_size == 0);

              for (Index d = start_depth; d < max_depth; d += packet_size) {
                eigen_assert(k < peeled_k);
                PacketBlock<Packet, 4> kernel;
                kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
                                        : rhs.packetNoPadding(d, idx0);
                kernel.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
                                        : rhs.packetNoPadding(d, idx1);
                kernel.packet[2] = pad2 ? pset1<Packet>(Scalar(0))
                                        : rhs.packetNoPadding(d, idx2);
                kernel.packet[3] = pad3 ? pset1<Packet>(Scalar(0))
                                        : rhs.packetNoPadding(d, idx3);
                ptranspose(kernel);
                pstoreu(block + 0 * packet_size, kernel.packet[0]);
                pstoreu(block + 1 * packet_size, kernel.packet[1]);
                pstoreu(block + 2 * packet_size, kernel.packet[2]);
                pstoreu(block + 3 * packet_size, kernel.packet[3]);
                block += 4 * packet_size;
                k += packet_size;
              }
            }
          }

          // The loop above should fill peeled_k elements.
          eigen_assert(peeled_k == k);

        } else {
          for (; k < peeled_k; k += packet_size) {
            PacketBlock<Packet, 4> kernel;
            kernel.packet[0] = dm0.loadPacketStandard(k);
            kernel.packet[1] = dm1.loadPacketStandard(k);
            kernel.packet[2] = dm2.loadPacketStandard(k);
            kernel.packet[3] = dm3.loadPacketStandard(k);
            ptranspose(kernel);
            pstoreu(block + 0 * packet_size, kernel.packet[0]);
            pstoreu(block + 1 * packet_size, kernel.packet[1]);
            pstoreu(block + 2 * packet_size, kernel.packet[2]);
            pstoreu(block + 3 * packet_size, kernel.packet[3]);
            block += 4 * packet_size;
          }
        }
      }

      // Copy the remaining coefficients of the column block after the peeled_k.
      if (!rhs.nonStandardPatches()) {
        for (; k < depth; k++) {
          block[0] = dm0.loadCoeffStandard(k);
          block[1] = dm1.loadCoeffStandard(k);
          block[2] = dm2.loadCoeffStandard(k);
          block[3] = dm3.loadCoeffStandard(k);
          block += 4;
        }
      } else {
        for (; k < depth; k++) {
          block[0] = dm0(k);
          block[1] = dm1(k);
          block[2] = dm2(k);
          block[3] = dm3(k);
          block += 4;
        }
      }
    }

    // copy the remaining columns one at a time (nr==1)
    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2);
      for (Index k = 0; k < depth; k++) {
        *block = dm0(k);
        block += 1;
      }
    }
  }
};

// Template specialization for packet_size = 2. We must special-case packet
// blocks with nr > packet_size, e.g. PacketBlock<Packet2d, 4>.
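// For example (illustrative), with Packet2d (2 doubles) the 4-wide micro-panel
// is transposed as two independent PacketBlock<Packet, 2> halves (kernel0 and
// kernel1 below) instead of a single PacketBlock<Packet, 4>.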
template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
          typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, bool inner_dim_contiguous,
          bool inner_dim_reordered, int Alignment, int nr>
struct gemm_pack_rhs<
    Scalar, Index,
    TensorContractionSubMapper<
        Scalar, Index, Rhs,
        TensorEvaluator<
            const TensorReshapingOp<
                NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
            Device>,
        nocontract_t, contract_t, 2, inner_dim_contiguous, inner_dim_reordered,
        Alignment>,
    nr, ColMajor, false, false> {
  typedef TensorContractionSubMapper<
      Scalar, Index, Rhs,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, 2, inner_dim_contiguous, inner_dim_reordered,
      Alignment>
      SubMapper;
  typedef SubMapper DataMapper;
  typedef typename packet_traits<Scalar>::type Packet;

  EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE)

  EIGEN_DEVICE_FUNC
  EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
                                    Index depth, Index cols, Index stride = 0,
                                    Index offset = 0) const {
    eigen_assert(stride == 0);
    eigen_assert(offset == 0);

    const int packet_size = 2;
    const Index packet_cols4 = (cols / 4) * 4;
    const Index peeled_k = (depth / packet_size) * packet_size;
    const bool non_standard_patches = rhs.nonStandardPatches();

    for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
      const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
      const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
      const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);

      Index k = 0;
      if (!non_standard_patches) {
        // FAST PATH:
        // Iterate over patch columns and rows if we know that a single
        // packet does not span across multiple rows or columns.
        if ((rhs.patchDepth() % packet_size) == 0) {
          const Index start_col = rhs.colOffset();
          const Index max_col = rhs.maxCol(peeled_k);

          for (Index c = start_col; c < max_col; ++c) {
            eigen_assert(k <= peeled_k);

            const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
            const Index max_row = rhs.maxRow(peeled_k, c);

            const bool pad_col0 = dm0.padCol(c);
            const bool pad_col1 = dm1.padCol(c);
            const bool pad_col2 = dm2.padCol(c);
            const bool pad_col3 = dm3.padCol(c);

            // We can squeeze reads along the `row` and `depth` dimensions if
            // the row stride is `1`, which means that `row` and `depth`
            // dimensions are contiguous (two innermost dimensions).
            if (rhs.rowStride() == 1 &&                                //
                !pad_col0 && !pad_col1 && !pad_col2 && !pad_col3 &&    //
                !dm0.padRow(start_row) && !dm0.padRow(max_row - 1) &&  //
                !dm1.padRow(start_row) && !dm1.padRow(max_row - 1) &&  //
                !dm2.padRow(start_row) && !dm2.padRow(max_row - 1) &&  //
                !dm3.padRow(start_row) && !dm3.padRow(max_row - 1)) {
              // Compute how many elements we can squeeze read.
              const Index start_depth =
                  (c == start_col) ? rhs.depthOffset() : 0;

              // Upper bound for the number of elements in the depth dimension
              // that we can squeeze read.
              const Index squeeze_length =
                  (max_row - start_row) * rhs.patchDepth() - start_depth;

              // Do not overshoot beyond the block size.
              const Index max_depth =
                  start_depth + std::min<Index>(peeled_k - k, squeeze_length);
              eigen_assert((max_depth - start_depth) % packet_size == 0);

              const Index idx0 = dm0.baseIndex(start_row, c);
              const Index idx1 = dm1.baseIndex(start_row, c);
              const Index idx2 = dm2.baseIndex(start_row, c);
              const Index idx3 = dm3.baseIndex(start_row, c);

              for (Index d = start_depth; d < max_depth; d += packet_size) {
                PacketBlock<Packet, 2> kernel0;
                PacketBlock<Packet, 2> kernel1;
                kernel0.packet[0] = rhs.packetNoPadding(d, idx0);
                kernel0.packet[1] = rhs.packetNoPadding(d, idx1);
                kernel1.packet[0] = rhs.packetNoPadding(d, idx2);
                kernel1.packet[1] = rhs.packetNoPadding(d, idx3);
                ptranspose(kernel0);
                ptranspose(kernel1);
                pstoreu(block + 0 * packet_size, kernel0.packet[0]);
                pstoreu(block + 1 * packet_size, kernel1.packet[0]);
                pstoreu(block + 2 * packet_size, kernel0.packet[1]);
                pstoreu(block + 3 * packet_size, kernel1.packet[1]);
                block += 4 * packet_size;
                k += packet_size;
              }

              // Go to the next column.
              continue;
            }

            // If we can't squeeze reads, process rows one by one.
            for (Index r = start_row; r < max_row; ++r) {
              eigen_assert(k <= peeled_k);

              const bool pad0 = pad_col0 || dm0.padRow(r);
              const bool pad1 = pad_col1 || dm1.padRow(r);
              const bool pad2 = pad_col2 || dm2.padRow(r);
              const bool pad3 = pad_col3 || dm3.padRow(r);

              const Index idx0 = dm0.baseIndex(r, c);
              const Index idx1 = dm1.baseIndex(r, c);
              const Index idx2 = dm2.baseIndex(r, c);
              const Index idx3 = dm3.baseIndex(r, c);

              const Index start_depth = ((c == start_col) && (r == start_row))
                                            ? rhs.depthOffset()
                                            : 0;
              const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
              eigen_assert((max_depth - start_depth) % packet_size == 0);

              for (Index d = start_depth; d < max_depth; d += packet_size) {
                eigen_assert(k < peeled_k);
                PacketBlock<Packet, 2> kernel0;
                PacketBlock<Packet, 2> kernel1;
                kernel0.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
                                         : rhs.packetNoPadding(d, idx0);
                kernel0.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
                                         : rhs.packetNoPadding(d, idx1);
                kernel1.packet[0] = pad2 ? pset1<Packet>(Scalar(0))
                                         : rhs.packetNoPadding(d, idx2);
                kernel1.packet[1] = pad3 ? pset1<Packet>(Scalar(0))
                                         : rhs.packetNoPadding(d, idx3);
                ptranspose(kernel0);
                ptranspose(kernel1);
                pstoreu(block + 0 * packet_size, kernel0.packet[0]);
                pstoreu(block + 1 * packet_size, kernel1.packet[0]);
                pstoreu(block + 2 * packet_size, kernel0.packet[1]);
                pstoreu(block + 3 * packet_size, kernel1.packet[1]);
                block += 4 * packet_size;
                k += packet_size;
              }
            }
          }

          // The loop above should fill peeled_k elements.
          eigen_assert(peeled_k == k);

        } else {
          // Packet can span multiple rows or columns, so we have to go
          // through the slower "standard" path.
          for (; k < peeled_k; k += packet_size) {
            PacketBlock<Packet, 2> kernel0;
            PacketBlock<Packet, 2> kernel1;
            kernel0.packet[0] = dm0.loadPacketStandard(k);
            kernel0.packet[1] = dm1.loadPacketStandard(k);
            kernel1.packet[0] = dm2.loadPacketStandard(k);
            kernel1.packet[1] = dm3.loadPacketStandard(k);
            ptranspose(kernel0);
            ptranspose(kernel1);
            pstoreu(block + 0 * packet_size, kernel0.packet[0]);
            pstoreu(block + 1 * packet_size, kernel1.packet[0]);
            pstoreu(block + 2 * packet_size, kernel0.packet[1]);
            pstoreu(block + 3 * packet_size, kernel1.packet[1]);
            block += 4 * packet_size;
          }
        }
      }

      // Copy the remaining coefficients of the column block after the peeled_k.
      if (!non_standard_patches) {
        for (; k < depth; k++) {
          block[0] = dm0.loadCoeffStandard(k);
          block[1] = dm1.loadCoeffStandard(k);
          block[2] = dm2.loadCoeffStandard(k);
          block[3] = dm3.loadCoeffStandard(k);
          block += 4;
        }
      } else {
        for (; k < depth; k++) {
          block[0] = dm0(k);
          block[1] = dm1(k);
          block[2] = dm2(k);
          block[3] = dm3(k);
          block += 4;
        }
      }
    }

    // Copy the remaining columns one at a time (nr==1).
    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2);
      for (Index k = 0; k < depth; k++) {
        *block = dm0(k);
        block += 1;
      }
    }
  }
};

// Special case for non-vectorized types such as float16.
template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
          typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, bool inner_dim_contiguous,
          bool inner_dim_reordered, int Alignment, int nr>
struct gemm_pack_rhs<
    Scalar, Index,
    TensorContractionSubMapper<
        Scalar, Index, Rhs,
        TensorEvaluator<
            const TensorReshapingOp<
                NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
            Device>,
        nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered,
        Alignment>,
    nr, ColMajor, false, false> {
  typedef TensorContractionSubMapper<
      Scalar, Index, Rhs,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered,
      Alignment>
      SubMapper;
  typedef SubMapper DataMapper;

  EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE)

  EIGEN_DEVICE_FUNC
  EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
                                    Index depth, Index cols, Index stride = 0,
                                    Index offset = 0) const {
    eigen_assert(stride == 0);
    eigen_assert(offset == 0);

    const Index packet_cols4 = (cols / 4) * 4;

    for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
      const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
      const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
      const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);

      if (!rhs.nonStandardPatches()) {
        for (Index k = 0; k < depth; k++) {
          block[0] = dm0.loadCoeffStandard(k);
          block[1] = dm1.loadCoeffStandard(k);
          block[2] = dm2.loadCoeffStandard(k);
          block[3] = dm3.loadCoeffStandard(k);
          block += 4;
        }
      } else {
        for (Index k = 0; k < depth; k++) {
          block[0] = dm0(k);
          block[1] = dm1(k);
          block[2] = dm2(k);
          block[3] = dm3(k);
          block += 4;
        }
      }
    }

    // Copy the remaining columns one at a time (nr==1).
    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2);
      for (Index k = 0; k < depth; k++) {
        *block = dm0(k);
        block += 1;
      }
    }
  }
};
}  // end namespace internal

/** SpatialConvolution
 * \ingroup CXX11_NeuralNetworks_Module
 *
 * \brief Applies a 2D convolution over a multichannel input image.
 *
 * The input parameter is expected to be a tensor with a rank of 3 or more
 * (channels, height, width, and optionally others).
 * The kernel parameter is expected to be a 4D tensor (filters, channels,
 * kernel_height, kernel_width).
 * The input and the kernel must both be in col-major layout. The result will
 * also be in col-major layout.
 *
 * If col_in_stride, row_in_stride > 1, then applies convolution with holes
 * (aka atrous convolution), sampling every col_in_stride, row_in_stride input
 * pixels.
 *
 * The result can be assigned to a tensor of rank equal to the rank of the
 * input. The dimensions of the result will be filters, height, width (and
 * others if applicable).
 *
 * It is possible to swap the order of the width and height dimensions provided
 * that the same order is used in the input, the kernel, and the output.
 *
 * It is also possible to add an output kernel to the contraction; the output
 * kernel is called by Eigen when it "finalizes" the block of an output tensor.
 *
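 * A minimal usage sketch (the dimension values below are illustrative only,
 * assuming col-major 4D tensors laid out as described above):
 *
 *   Eigen::Tensor<float, 4> input(3, 7, 11, 2);   // channels, height, width, batch
 *   Eigen::Tensor<float, 4> kernel(8, 3, 2, 2);   // filters, channels, rows, cols
 *   Eigen::Tensor<float, 4> output(8, 7, 11, 2);  // filters, out_h, out_w, batch
 *   output = Eigen::SpatialConvolution(input, kernel);  // stride 1, PADDING_SAME
 *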
 */
template <typename Input, typename Kernel,
          typename OutputKernel = const NoOpOutputKernel>
EIGEN_DEVICE_FUNC
    EIGEN_ALWAYS_INLINE static const typename internal::conditional<
        internal::traits<Input>::Layout == ColMajor,
        TensorReshapingOp<
            const DSizes<typename internal::traits<Input>::Index,
                         internal::traits<Input>::NumDimensions>,
            const TensorContractionOp<
                const array<IndexPair<typename internal::traits<Input>::Index>,
                            1>,
                const TensorReshapingOp<
                    const DSizes<typename internal::traits<Input>::Index, 2>,
                    const Kernel>,
                const TensorReshapingOp<
                    const DSizes<typename internal::traits<Input>::Index, 2>,
                    const TensorImagePatchOp<Dynamic, Dynamic, const Input> >,
                const OutputKernel> >,
        TensorReshapingOp<
            const DSizes<typename internal::traits<Input>::Index,
                         internal::traits<Input>::NumDimensions>,
            const TensorContractionOp<
                const array<IndexPair<typename internal::traits<Input>::Index>,
                            1>,
                const TensorReshapingOp<
                    const DSizes<typename internal::traits<Input>::Index, 2>,
                    const TensorImagePatchOp<Dynamic, Dynamic, const Input> >,
                const TensorReshapingOp<
                    const DSizes<typename internal::traits<Input>::Index, 2>,
                    const Kernel>,
                const OutputKernel> > >::type
    SpatialConvolution(const Input& input, const Kernel& kernel,
                       const Index row_stride = 1, const Index col_stride = 1,
                       const PaddingType padding_type = PADDING_SAME,
                       const Index row_in_stride = 1,
                       const Index col_in_stride = 1,
                       const OutputKernel& output_kernel = OutputKernel()) {
  typedef typename internal::traits<Input>::Index TensorIndex;
  TensorRef<Tensor<typename internal::traits<Input>::Scalar,
                   internal::traits<Input>::NumDimensions,
                   internal::traits<Input>::Layout, TensorIndex> >
      in(input);
  TensorRef<Tensor<typename internal::traits<Kernel>::Scalar,
                   internal::traits<Kernel>::NumDimensions,
                   internal::traits<Kernel>::Layout, TensorIndex> >
      kern(kernel);

  EIGEN_STATIC_ASSERT(
      internal::traits<Input>::Layout == internal::traits<Kernel>::Layout,
      YOU_MADE_A_PROGRAMMING_MISTAKE)
  const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);

  const int NumDims = internal::traits<Input>::NumDimensions;

  // Number of filters to apply. This is the same as the output depth of the
  // result.
  const TensorIndex kernelFilters =
      isColMajor ? kern.dimensions()[0] : kern.dimensions()[3];
  // Number of channels. This is the same as the input depth.
  const TensorIndex kernelChannels =
      isColMajor ? kern.dimensions()[1] : kern.dimensions()[2];
  const TensorIndex kernelRows =
      isColMajor ? kern.dimensions()[2] : kern.dimensions()[1];
  const TensorIndex kernelCols =
      isColMajor ? kern.dimensions()[3] : kern.dimensions()[0];

  const Index kernelRowsEff =
      kernelRows + (kernelRows - 1) * (row_in_stride - 1);
  const Index kernelColsEff =
      kernelCols + (kernelCols - 1) * (col_in_stride - 1);
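  // For instance (illustrative numbers only): a 3x3 kernel dilated with
  // row_in_stride == col_in_stride == 2 covers an effective window of
  // 3 + (3 - 1) * (2 - 1) = 5 pixels along each spatial dimension.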

  array<IndexPair<TensorIndex>, 1> contract_dims;
  contract_dims[0] = IndexPair<TensorIndex>(1, 0);

  const TensorIndex InputRows =
      isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
  const TensorIndex InputCols =
      isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);

  TensorIndex out_height;
  TensorIndex out_width;
  switch (padding_type) {
    case PADDING_VALID:
      out_height = numext::ceil((InputRows - kernelRowsEff + 1.f) /
                                static_cast<float>(row_stride));
      out_width = numext::ceil((InputCols - kernelColsEff + 1.f) /
                               static_cast<float>(col_stride));
      break;
    case PADDING_SAME:
      out_height = numext::ceil(InputRows / static_cast<float>(row_stride));
      out_width = numext::ceil(InputCols / static_cast<float>(col_stride));
      break;
    default:
      // Initialize unused variables to avoid a compiler warning
      out_height = 0;
      out_width = 0;
      eigen_assert(false && "unexpected padding");
  }
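  // Worked example (illustrative numbers only): with InputRows = 10,
  // kernelRowsEff = 5 and row_stride = 2, PADDING_VALID yields
  // out_height = ceil((10 - 5 + 1) / 2.0) = 3, whereas PADDING_SAME yields
  // out_height = ceil(10 / 2.0) = 5.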

  // Molds the output of the patch extraction code into a 2d tensor:
  // - the first dimension (dims[0]): the patch values to be multiplied with
  //   the kernels
  // - the second dimension (dims[1]): everything else
  DSizes<TensorIndex, 2> pre_contract_dims;
  if (isColMajor) {
    pre_contract_dims[0] = kernelChannels * kernelRows * kernelCols;
    pre_contract_dims[1] = out_height * out_width;
    for (int i = 3; i < NumDims; ++i) {
      pre_contract_dims[1] *= in.dimension(i);
    }
  } else {
    pre_contract_dims[1] = kernelChannels * kernelRows * kernelCols;
    pre_contract_dims[0] = out_height * out_width;
    for (int i = 0; i < NumDims - 3; ++i) {
      pre_contract_dims[0] *= in.dimension(i);
    }
  }
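  // For example (illustrative, col-major): 3 input channels, a 2x2 kernel,
  // a 5x5 spatial output and a batch of 7 give
  // pre_contract_dims = {3 * 2 * 2, 5 * 5 * 7} = {12, 175}.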

  // Molds the output of the contraction into the shape expected by the user
  // (assuming this is ColMajor):
  // - 1st dim: kernel filters
  // - 2nd dim: output height
  // - 3rd dim: output width
  // - 4th dim and beyond: everything else including batch size
  DSizes<TensorIndex, NumDims> post_contract_dims;
  if (isColMajor) {
    post_contract_dims[0] = kernelFilters;
    post_contract_dims[1] = out_height;
    post_contract_dims[2] = out_width;
    for (int i = 3; i < NumDims; ++i) {
      post_contract_dims[i] = in.dimension(i);
    }
  } else {
    post_contract_dims[NumDims - 1] = kernelFilters;
    post_contract_dims[NumDims - 2] = out_height;
    post_contract_dims[NumDims - 3] = out_width;
    for (int i = 0; i < NumDims - 3; ++i) {
      post_contract_dims[i] = in.dimension(i);
    }
  }
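  // Continuing the illustrative example: with 8 filters, a 5x5 spatial output
  // and a batch of 7, post_contract_dims would be {8, 5, 5, 7} (col-major).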

  DSizes<TensorIndex, 2> kernel_dims;
  if (isColMajor) {
    kernel_dims[0] = kernelFilters;
    kernel_dims[1] = kernelChannels * kernelRows * kernelCols;
  } else {
    kernel_dims[0] = kernelChannels * kernelRows * kernelCols;
    kernel_dims[1] = kernelFilters;
  }
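  // In the col-major branch below the contraction is therefore a
  // (kernelFilters x kernelChannels*kernelRows*kernelCols) by
  // (kernelChannels*kernelRows*kernelCols x number-of-patches) matrix product;
  // with the illustrative sizes above that is {8, 12} * {12, 175} = {8, 175},
  // which is then reshaped to post_contract_dims.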
  return choose(
      Cond<internal::traits<Input>::Layout == ColMajor>(),
      kernel.reshape(kernel_dims)
          .contract(input
                        .extract_image_patches(
                            kernelRows, kernelCols, row_stride, col_stride,
                            row_in_stride, col_in_stride, padding_type)
                        .reshape(pre_contract_dims),
                    contract_dims, output_kernel)
          .reshape(post_contract_dims),
      input
          .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride,
                                 row_in_stride, col_in_stride, padding_type)
          .reshape(pre_contract_dims)
          .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel)
          .reshape(post_contract_dims));
}

}  // end namespace Eigen

#endif  // TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_INL_H_