• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 
4 #ifndef EIGEN_CXX11_TENSOR_TENSOR_VOLUME_PATCH_H
5 #define EIGEN_CXX11_TENSOR_TENSOR_VOLUME_PATCH_H
6 
7 namespace Eigen {
8 
9 /** \class TensorVolumePatch
10   * \ingroup CXX11_Tensor_Module
11   *
12   * \brief Patch extraction specialized for processing of volumetric data.
13   * This assumes that the input has at least 4 dimensions ordered as follows:
14   *  - channels
15   *  - planes
16   *  - rows
17   *  - columns
18   *  - (optional) additional dimensions such as time or batch size.
19   * Calling the volume patch code with patch_planes, patch_rows, and patch_cols
20   * is equivalent to calling the regular patch extraction code with parameters
21   * d, patch_planes, patch_rows, patch_cols, and 1 for all the additional
22   * dimensions.
23   */
24 namespace internal {
25 template<DenseIndex Planes, DenseIndex Rows, DenseIndex Cols, typename XprType>
26 struct traits<TensorVolumePatchOp<Planes, Rows, Cols, XprType> > : public traits<XprType>
27 {
28   typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
29   typedef traits<XprType> XprTraits;
30   typedef typename XprTraits::StorageKind StorageKind;
31   typedef typename XprTraits::Index Index;
32   typedef typename XprType::Nested Nested;
33   typedef typename remove_reference<Nested>::type _Nested;
34   static const int NumDimensions = XprTraits::NumDimensions + 1;
35   static const int Layout = XprTraits::Layout;
36 };
37 
38 template<DenseIndex Planes, DenseIndex Rows, DenseIndex Cols, typename XprType>
39 struct eval<TensorVolumePatchOp<Planes, Rows, Cols, XprType>, Eigen::Dense>
40 {
41   typedef const TensorVolumePatchOp<Planes, Rows, Cols, XprType>& type;
42 };
43 
44 template<DenseIndex Planes, DenseIndex Rows, DenseIndex Cols, typename XprType>
45 struct nested<TensorVolumePatchOp<Planes, Rows, Cols, XprType>, 1, typename eval<TensorVolumePatchOp<Planes, Rows, Cols, XprType> >::type>
46 {
47   typedef TensorVolumePatchOp<Planes, Rows, Cols, XprType> type;
48 };
49 
50 }  // end namespace internal
51 
52 template<DenseIndex Planes, DenseIndex Rows, DenseIndex Cols, typename XprType>
53 class TensorVolumePatchOp : public TensorBase<TensorVolumePatchOp<Planes, Rows, Cols, XprType>, ReadOnlyAccessors>
54 {
55   public:
56   typedef typename Eigen::internal::traits<TensorVolumePatchOp>::Scalar Scalar;
57   typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
58   typedef typename XprType::CoeffReturnType CoeffReturnType;
59   typedef typename Eigen::internal::nested<TensorVolumePatchOp>::type Nested;
60   typedef typename Eigen::internal::traits<TensorVolumePatchOp>::StorageKind StorageKind;
61   typedef typename Eigen::internal::traits<TensorVolumePatchOp>::Index Index;
62 
63   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorVolumePatchOp(const XprType& expr, DenseIndex patch_planes, DenseIndex patch_rows, DenseIndex patch_cols,
64                                                             DenseIndex plane_strides, DenseIndex row_strides, DenseIndex col_strides,
65                                                             DenseIndex in_plane_strides, DenseIndex in_row_strides, DenseIndex in_col_strides,
66                                                             DenseIndex plane_inflate_strides, DenseIndex row_inflate_strides, DenseIndex col_inflate_strides,
67                                                             PaddingType padding_type, Scalar padding_value)
68       : m_xpr(expr), m_patch_planes(patch_planes), m_patch_rows(patch_rows), m_patch_cols(patch_cols),
69         m_plane_strides(plane_strides), m_row_strides(row_strides), m_col_strides(col_strides),
70         m_in_plane_strides(in_plane_strides), m_in_row_strides(in_row_strides), m_in_col_strides(in_col_strides),
71         m_plane_inflate_strides(plane_inflate_strides), m_row_inflate_strides(row_inflate_strides), m_col_inflate_strides(col_inflate_strides),
72         m_padding_explicit(false), m_padding_top_z(0), m_padding_bottom_z(0), m_padding_top(0), m_padding_bottom(0), m_padding_left(0), m_padding_right(0),
73         m_padding_type(padding_type), m_padding_value(padding_value) {}
74 
75   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorVolumePatchOp(const XprType& expr, DenseIndex patch_planes, DenseIndex patch_rows, DenseIndex patch_cols,
76                                                            DenseIndex plane_strides, DenseIndex row_strides, DenseIndex col_strides,
77                                                            DenseIndex in_plane_strides, DenseIndex in_row_strides, DenseIndex in_col_strides,
78                                                            DenseIndex plane_inflate_strides, DenseIndex row_inflate_strides, DenseIndex col_inflate_strides,
79                                                            DenseIndex padding_top_z, DenseIndex padding_bottom_z,
80                                                            DenseIndex padding_top, DenseIndex padding_bottom,
81                                                            DenseIndex padding_left, DenseIndex padding_right,
82                                                            Scalar padding_value)
83       : m_xpr(expr), m_patch_planes(patch_planes), m_patch_rows(patch_rows), m_patch_cols(patch_cols),
84         m_plane_strides(plane_strides), m_row_strides(row_strides), m_col_strides(col_strides),
85         m_in_plane_strides(in_plane_strides), m_in_row_strides(in_row_strides), m_in_col_strides(in_col_strides),
86         m_plane_inflate_strides(plane_inflate_strides), m_row_inflate_strides(row_inflate_strides), m_col_inflate_strides(col_inflate_strides),
87         m_padding_explicit(true), m_padding_top_z(padding_top_z), m_padding_bottom_z(padding_bottom_z), m_padding_top(padding_top), m_padding_bottom(padding_bottom),
88         m_padding_left(padding_left), m_padding_right(padding_right),
89         m_padding_type(PADDING_VALID), m_padding_value(padding_value) {}
90 
91     EIGEN_DEVICE_FUNC
92     DenseIndex patch_planes() const { return m_patch_planes; }
93     EIGEN_DEVICE_FUNC
94     DenseIndex patch_rows() const { return m_patch_rows; }
95     EIGEN_DEVICE_FUNC
96     DenseIndex patch_cols() const { return m_patch_cols; }
97     EIGEN_DEVICE_FUNC
98     DenseIndex plane_strides() const { return m_plane_strides; }
99     EIGEN_DEVICE_FUNC
100     DenseIndex row_strides() const { return m_row_strides; }
101     EIGEN_DEVICE_FUNC
102     DenseIndex col_strides() const { return m_col_strides; }
103     EIGEN_DEVICE_FUNC
104     DenseIndex in_plane_strides() const { return m_in_plane_strides; }
105     EIGEN_DEVICE_FUNC
106     DenseIndex in_row_strides() const { return m_in_row_strides; }
107     EIGEN_DEVICE_FUNC
108     DenseIndex in_col_strides() const { return m_in_col_strides; }
109     EIGEN_DEVICE_FUNC
110     DenseIndex plane_inflate_strides() const { return m_plane_inflate_strides; }
111     EIGEN_DEVICE_FUNC
112     DenseIndex row_inflate_strides() const { return m_row_inflate_strides; }
113     EIGEN_DEVICE_FUNC
114     DenseIndex col_inflate_strides() const { return m_col_inflate_strides; }
115     EIGEN_DEVICE_FUNC
116     bool padding_explicit() const { return m_padding_explicit; }
117     EIGEN_DEVICE_FUNC
118     DenseIndex padding_top_z() const { return m_padding_top_z; }
119     EIGEN_DEVICE_FUNC
120     DenseIndex padding_bottom_z() const { return m_padding_bottom_z; }
121     EIGEN_DEVICE_FUNC
122     DenseIndex padding_top() const { return m_padding_top; }
123     EIGEN_DEVICE_FUNC
124     DenseIndex padding_bottom() const { return m_padding_bottom; }
125     EIGEN_DEVICE_FUNC
126     DenseIndex padding_left() const { return m_padding_left; }
127     EIGEN_DEVICE_FUNC
128     DenseIndex padding_right() const { return m_padding_right; }
129     EIGEN_DEVICE_FUNC
130     PaddingType padding_type() const { return m_padding_type; }
131     EIGEN_DEVICE_FUNC
132     Scalar padding_value() const { return m_padding_value; }
133 
134     EIGEN_DEVICE_FUNC
135     const typename internal::remove_all<typename XprType::Nested>::type&
136     expression() const { return m_xpr; }
137 
138   protected:
139     typename XprType::Nested m_xpr;
140     const DenseIndex m_patch_planes;
141     const DenseIndex m_patch_rows;
142     const DenseIndex m_patch_cols;
143     const DenseIndex m_plane_strides;
144     const DenseIndex m_row_strides;
145     const DenseIndex m_col_strides;
146     const DenseIndex m_in_plane_strides;
147     const DenseIndex m_in_row_strides;
148     const DenseIndex m_in_col_strides;
149     const DenseIndex m_plane_inflate_strides;
150     const DenseIndex m_row_inflate_strides;
151     const DenseIndex m_col_inflate_strides;
152     const bool m_padding_explicit;
153     const DenseIndex m_padding_top_z;
154     const DenseIndex m_padding_bottom_z;
155     const DenseIndex m_padding_top;
156     const DenseIndex m_padding_bottom;
157     const DenseIndex m_padding_left;
158     const DenseIndex m_padding_right;
159     const PaddingType m_padding_type;
160     const Scalar m_padding_value;
161 };
162 
163 
164 // Eval as rvalue
165 template<DenseIndex Planes, DenseIndex Rows, DenseIndex Cols, typename ArgType, typename Device>
166 struct TensorEvaluator<const TensorVolumePatchOp<Planes, Rows, Cols, ArgType>, Device>
167 {
168   typedef TensorVolumePatchOp<Planes, Rows, Cols, ArgType> XprType;
169   typedef typename XprType::Index Index;
170   static const int NumInputDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
171   static const int NumDims = NumInputDims + 1;
172   typedef DSizes<Index, NumDims> Dimensions;
173   typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
174   typedef typename XprType::CoeffReturnType CoeffReturnType;
175   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
176   static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
177 
178   enum {
179     IsAligned = false,
180     PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
181     BlockAccess = false,
182     Layout = TensorEvaluator<ArgType, Device>::Layout,
183     CoordAccess = false,
184     RawAccess = false
185   };
186 
187   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
188       : m_impl(op.expression(), device)
189   {
190     EIGEN_STATIC_ASSERT((NumDims >= 5), YOU_MADE_A_PROGRAMMING_MISTAKE);
191 
192     m_paddingValue = op.padding_value();
193 
194     const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
195 
196     // Cache a few variables.
197     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
198       m_inputDepth = input_dims[0];
199       m_inputPlanes = input_dims[1];
200       m_inputRows = input_dims[2];
201       m_inputCols = input_dims[3];
202     } else {
203       m_inputDepth = input_dims[NumInputDims-1];
204       m_inputPlanes = input_dims[NumInputDims-2];
205       m_inputRows = input_dims[NumInputDims-3];
206       m_inputCols = input_dims[NumInputDims-4];
207     }
208 
209     m_plane_strides = op.plane_strides();
210     m_row_strides = op.row_strides();
211     m_col_strides = op.col_strides();
212 
213     // Input strides and effective input/patch size
214     m_in_plane_strides = op.in_plane_strides();
215     m_in_row_strides = op.in_row_strides();
216     m_in_col_strides = op.in_col_strides();
217     m_plane_inflate_strides = op.plane_inflate_strides();
218     m_row_inflate_strides = op.row_inflate_strides();
219     m_col_inflate_strides = op.col_inflate_strides();
220 
221     // The "effective" spatial size after inflating data with zeros.
222     m_input_planes_eff = (m_inputPlanes - 1) * m_plane_inflate_strides + 1;
223     m_input_rows_eff = (m_inputRows - 1) * m_row_inflate_strides + 1;
224     m_input_cols_eff = (m_inputCols - 1) * m_col_inflate_strides + 1;
225     m_patch_planes_eff = op.patch_planes() + (op.patch_planes() - 1) * (m_in_plane_strides - 1);
226     m_patch_rows_eff = op.patch_rows() + (op.patch_rows() - 1) * (m_in_row_strides - 1);
227     m_patch_cols_eff = op.patch_cols() + (op.patch_cols() - 1) * (m_in_col_strides - 1);
228 
229     if (op.padding_explicit()) {
230       m_outputPlanes = numext::ceil((m_input_planes_eff + op.padding_top_z() + op.padding_bottom_z() - m_patch_planes_eff + 1.f) / static_cast<float>(m_plane_strides));
231       m_outputRows = numext::ceil((m_input_rows_eff + op.padding_top() + op.padding_bottom() - m_patch_rows_eff + 1.f) / static_cast<float>(m_row_strides));
232       m_outputCols = numext::ceil((m_input_cols_eff + op.padding_left() + op.padding_right() - m_patch_cols_eff + 1.f) / static_cast<float>(m_col_strides));
233       m_planePaddingTop = op.padding_top_z();
234       m_rowPaddingTop = op.padding_top();
235       m_colPaddingLeft = op.padding_left();
236     } else {
237       // Computing padding from the type
238       switch (op.padding_type()) {
239         case PADDING_VALID:
240           m_outputPlanes = numext::ceil((m_input_planes_eff - m_patch_planes_eff + 1.f) / static_cast<float>(m_plane_strides));
241           m_outputRows = numext::ceil((m_input_rows_eff - m_patch_rows_eff + 1.f) / static_cast<float>(m_row_strides));
242           m_outputCols = numext::ceil((m_input_cols_eff - m_patch_cols_eff + 1.f) / static_cast<float>(m_col_strides));
243           m_planePaddingTop = 0;
244           m_rowPaddingTop = 0;
245           m_colPaddingLeft = 0;
246           break;
247         case PADDING_SAME: {
248           m_outputPlanes = numext::ceil(m_input_planes_eff / static_cast<float>(m_plane_strides));
249           m_outputRows = numext::ceil(m_input_rows_eff / static_cast<float>(m_row_strides));
250           m_outputCols = numext::ceil(m_input_cols_eff / static_cast<float>(m_col_strides));
251           const Index dz = m_outputPlanes * m_plane_strides + m_patch_planes_eff - 1 - m_input_planes_eff;
252           const Index dy = m_outputRows * m_row_strides + m_patch_rows_eff - 1 - m_input_rows_eff;
253           const Index dx = m_outputCols * m_col_strides + m_patch_cols_eff - 1 - m_input_cols_eff;
254           m_planePaddingTop = dz - dz / 2;
255           m_rowPaddingTop = dy - dy / 2;
256           m_colPaddingLeft = dx - dx / 2;
257           break;
258         }
259         default:
260           eigen_assert(false && "unexpected padding");
261       }
262     }
263     eigen_assert(m_outputRows > 0);
264     eigen_assert(m_outputCols > 0);
265     eigen_assert(m_outputPlanes > 0);
266 
267     // Dimensions for result of extraction.
268     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
269       // ColMajor
270       // 0: depth
271       // 1: patch_planes
272       // 2: patch_rows
273       // 3: patch_cols
274       // 4: number of patches
275       // 5 and beyond: anything else (such as batch).
276       m_dimensions[0] = input_dims[0];
277       m_dimensions[1] = op.patch_planes();
278       m_dimensions[2] = op.patch_rows();
279       m_dimensions[3] = op.patch_cols();
280       m_dimensions[4] = m_outputPlanes * m_outputRows * m_outputCols;
281       for (int i = 5; i < NumDims; ++i) {
282         m_dimensions[i] = input_dims[i-1];
283       }
284     } else {
285       // RowMajor
286       // NumDims-1: depth
287       // NumDims-2: patch_planes
288       // NumDims-3: patch_rows
289       // NumDims-4: patch_cols
290       // NumDims-5: number of patches
291       // NumDims-6 and beyond: anything else (such as batch).
292       m_dimensions[NumDims-1] = input_dims[NumInputDims-1];
293       m_dimensions[NumDims-2] = op.patch_planes();
294       m_dimensions[NumDims-3] = op.patch_rows();
295       m_dimensions[NumDims-4] = op.patch_cols();
296       m_dimensions[NumDims-5] = m_outputPlanes * m_outputRows * m_outputCols;
297       for (int i = NumDims-6; i >= 0; --i) {
298         m_dimensions[i] = input_dims[i];
299       }
300     }
301 
302     // Strides for the output tensor.
303     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
304       m_rowStride = m_dimensions[1];
305       m_colStride = m_dimensions[2] * m_rowStride;
306       m_patchStride = m_colStride * m_dimensions[3] * m_dimensions[0];
307       m_otherStride = m_patchStride * m_dimensions[4];
308     } else {
309       m_rowStride = m_dimensions[NumDims-2];
310       m_colStride = m_dimensions[NumDims-3] * m_rowStride;
311       m_patchStride = m_colStride * m_dimensions[NumDims-4] * m_dimensions[NumDims-1];
312       m_otherStride = m_patchStride * m_dimensions[NumDims-5];
313     }
314 
315     // Strides for navigating through the input tensor.
316     m_planeInputStride = m_inputDepth;
317     m_rowInputStride = m_inputDepth * m_inputPlanes;
318     m_colInputStride = m_inputDepth * m_inputRows * m_inputPlanes;
319     m_otherInputStride = m_inputDepth * m_inputRows * m_inputCols * m_inputPlanes;
320 
321     m_outputPlanesRows = m_outputPlanes * m_outputRows;
322 
323     // Fast representations of different variables.
324     m_fastOtherStride = internal::TensorIntDivisor<Index>(m_otherStride);
325     m_fastPatchStride = internal::TensorIntDivisor<Index>(m_patchStride);
326     m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);
327     m_fastRowStride = internal::TensorIntDivisor<Index>(m_rowStride);
328     m_fastInputRowStride = internal::TensorIntDivisor<Index>(m_row_inflate_strides);
329     m_fastInputColStride = internal::TensorIntDivisor<Index>(m_col_inflate_strides);
330     m_fastInputPlaneStride = internal::TensorIntDivisor<Index>(m_plane_inflate_strides);
331     m_fastInputColsEff = internal::TensorIntDivisor<Index>(m_input_cols_eff);
332     m_fastOutputPlanes = internal::TensorIntDivisor<Index>(m_outputPlanes);
333     m_fastOutputPlanesRows = internal::TensorIntDivisor<Index>(m_outputPlanesRows);
334 
335     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
336       m_fastOutputDepth = internal::TensorIntDivisor<Index>(m_dimensions[0]);
337     } else {
338       m_fastOutputDepth = internal::TensorIntDivisor<Index>(m_dimensions[NumDims-1]);
339     }
340   }
341 
342   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
343 
344   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
345     m_impl.evalSubExprsIfNeeded(NULL);
346     return true;
347   }
348 
349   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
350     m_impl.cleanup();
351   }
352 
353   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
354   {
355     // Patch index corresponding to the passed in index.
356     const Index patchIndex = index / m_fastPatchStride;
357 
358     // Spatial offset within the patch. This has to be translated into 3D
359     // coordinates within the patch.
360     const Index patchOffset = (index - patchIndex * m_patchStride) / m_fastOutputDepth;
361 
362     // Batch, etc.
363     const Index otherIndex = (NumDims == 5) ? 0 : index / m_fastOtherStride;
364     const Index patch3DIndex = (NumDims == 5) ? patchIndex : (index - otherIndex * m_otherStride) / m_fastPatchStride;
365 
366     // Calculate column index in the input original tensor.
367     const Index colIndex = patch3DIndex / m_fastOutputPlanesRows;
368     const Index colOffset = patchOffset / m_fastColStride;
369     const Index inputCol = colIndex * m_col_strides + colOffset * m_in_col_strides - m_colPaddingLeft;
370     const Index origInputCol = (m_col_inflate_strides == 1) ? inputCol : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
371     if (inputCol < 0 || inputCol >= m_input_cols_eff ||
372         ((m_col_inflate_strides != 1) && (inputCol != origInputCol * m_col_inflate_strides))) {
373       return Scalar(m_paddingValue);
374     }
375 
376     // Calculate row index in the original input tensor.
377     const Index rowIndex = (patch3DIndex - colIndex * m_outputPlanesRows) / m_fastOutputPlanes;
378     const Index rowOffset = (patchOffset - colOffset * m_colStride) / m_fastRowStride;
379     const Index inputRow = rowIndex * m_row_strides + rowOffset * m_in_row_strides - m_rowPaddingTop;
380     const Index origInputRow = (m_row_inflate_strides == 1) ? inputRow : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
381     if (inputRow < 0 || inputRow >= m_input_rows_eff ||
382         ((m_row_inflate_strides != 1) && (inputRow != origInputRow * m_row_inflate_strides))) {
383       return Scalar(m_paddingValue);
384     }
385 
386     // Calculate plane index in the original input tensor.
387     const Index planeIndex = (patch3DIndex - m_outputPlanes * (colIndex * m_outputRows + rowIndex));
388     const Index planeOffset = patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
389     const Index inputPlane = planeIndex * m_plane_strides + planeOffset * m_in_plane_strides - m_planePaddingTop;
390     const Index origInputPlane = (m_plane_inflate_strides == 1) ? inputPlane : ((inputPlane >= 0) ? (inputPlane / m_fastInputPlaneStride) : 0);
391     if (inputPlane < 0 || inputPlane >= m_input_planes_eff ||
392         ((m_plane_inflate_strides != 1) && (inputPlane != origInputPlane * m_plane_inflate_strides))) {
393       return Scalar(m_paddingValue);
394     }
395 
396     const int depth_index = static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 0 : NumDims - 1;
397     const Index depth = index - (index / m_fastOutputDepth) * m_dimensions[depth_index];
398 
399     const Index inputIndex = depth +
400         origInputRow * m_rowInputStride +
401         origInputCol * m_colInputStride +
402         origInputPlane * m_planeInputStride +
403         otherIndex * m_otherInputStride;
404 
405     return m_impl.coeff(inputIndex);
406   }
407 
408   template<int LoadMode>
409   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
410   {
411     EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
412     eigen_assert(index+PacketSize-1 < dimensions().TotalSize());
413 
414     if (m_in_row_strides != 1 || m_in_col_strides != 1 || m_row_inflate_strides != 1 || m_col_inflate_strides != 1 ||
415         m_in_plane_strides != 1 || m_plane_inflate_strides != 1) {
416       return packetWithPossibleZero(index);
417     }
418 
419     const Index indices[2] = {index, index + PacketSize - 1};
420     const Index patchIndex = indices[0] / m_fastPatchStride;
421     if (patchIndex != indices[1] / m_fastPatchStride) {
422       return packetWithPossibleZero(index);
423     }
424     const Index otherIndex = (NumDims == 5) ? 0 : indices[0] / m_fastOtherStride;
425     eigen_assert(otherIndex == indices[1] / m_fastOtherStride);
426 
427     // Find the offset of the element wrt the location of the first element.
428     const Index patchOffsets[2] = {(indices[0] - patchIndex * m_patchStride) / m_fastOutputDepth,
429                                    (indices[1] - patchIndex * m_patchStride) / m_fastOutputDepth};
430 
431     const Index patch3DIndex = (NumDims == 5) ? patchIndex : (indices[0] - otherIndex * m_otherStride) / m_fastPatchStride;
432     eigen_assert(patch3DIndex == (indices[1] - otherIndex * m_otherStride) / m_fastPatchStride);
433 
434     const Index colIndex = patch3DIndex / m_fastOutputPlanesRows;
435     const Index colOffsets[2] = {
436       patchOffsets[0] / m_fastColStride,
437       patchOffsets[1] / m_fastColStride};
438 
439     // Calculate col indices in the original input tensor.
440     const Index inputCols[2] = {
441       colIndex * m_col_strides + colOffsets[0] - m_colPaddingLeft,
442       colIndex * m_col_strides + colOffsets[1] - m_colPaddingLeft};
443     if (inputCols[1] < 0 || inputCols[0] >= m_inputCols) {
444       return internal::pset1<PacketReturnType>(Scalar(m_paddingValue));
445     }
446 
447     if (inputCols[0] != inputCols[1]) {
448       return packetWithPossibleZero(index);
449     }
450 
451     const Index rowIndex = (patch3DIndex - colIndex * m_outputPlanesRows) / m_fastOutputPlanes;
452     const Index rowOffsets[2] = {
453       (patchOffsets[0] - colOffsets[0] * m_colStride) / m_fastRowStride,
454       (patchOffsets[1] - colOffsets[1] * m_colStride) / m_fastRowStride};
455     eigen_assert(rowOffsets[0] <= rowOffsets[1]);
456     // Calculate col indices in the original input tensor.
457     const Index inputRows[2] = {
458       rowIndex * m_row_strides + rowOffsets[0] - m_rowPaddingTop,
459       rowIndex * m_row_strides + rowOffsets[1] - m_rowPaddingTop};
460 
461     if (inputRows[1] < 0 || inputRows[0] >= m_inputRows) {
462       return internal::pset1<PacketReturnType>(Scalar(m_paddingValue));
463     }
464 
465     if (inputRows[0] != inputRows[1]) {
466       return packetWithPossibleZero(index);
467     }
468 
469     const Index planeIndex = (patch3DIndex - m_outputPlanes * (colIndex * m_outputRows + rowIndex));
470     const Index planeOffsets[2] = {
471       patchOffsets[0] - colOffsets[0] * m_colStride - rowOffsets[0] * m_rowStride,
472       patchOffsets[1] - colOffsets[1] * m_colStride - rowOffsets[1] * m_rowStride};
473     eigen_assert(planeOffsets[0] <= planeOffsets[1]);
474     const Index inputPlanes[2] = {
475       planeIndex * m_plane_strides + planeOffsets[0] - m_planePaddingTop,
476       planeIndex * m_plane_strides + planeOffsets[1] - m_planePaddingTop};
477 
478     if (inputPlanes[1] < 0 || inputPlanes[0] >= m_inputPlanes) {
479       return internal::pset1<PacketReturnType>(Scalar(m_paddingValue));
480     }
481 
482     if (inputPlanes[0] >= 0 && inputPlanes[1] < m_inputPlanes) {
483       // no padding
484       const int depth_index = static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 0 : NumDims - 1;
485       const Index depth = index - (index / m_fastOutputDepth) * m_dimensions[depth_index];
486       const Index inputIndex = depth +
487           inputRows[0] * m_rowInputStride +
488           inputCols[0] * m_colInputStride +
489           m_planeInputStride * inputPlanes[0] +
490           otherIndex * m_otherInputStride;
491       return m_impl.template packet<Unaligned>(inputIndex);
492     }
493 
494     return packetWithPossibleZero(index);
495   }
496 
497   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
498   costPerCoeff(bool vectorized) const {
499     const double compute_cost =
500         10 * TensorOpCost::DivCost<Index>() + 21 * TensorOpCost::MulCost<Index>() +
501         8 * TensorOpCost::AddCost<Index>();
502     return TensorOpCost(0, 0, compute_cost, vectorized, PacketSize);
503   }
504 
505   EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }
506 
507   const TensorEvaluator<ArgType, Device>& impl() const { return m_impl; }
508 
509   Index planePaddingTop() const { return m_planePaddingTop; }
510   Index rowPaddingTop() const { return m_rowPaddingTop; }
511   Index colPaddingLeft() const { return m_colPaddingLeft; }
512   Index outputPlanes() const { return m_outputPlanes; }
513   Index outputRows() const { return m_outputRows; }
514   Index outputCols() const { return m_outputCols; }
515   Index userPlaneStride() const { return m_plane_strides; }
516   Index userRowStride() const { return m_row_strides; }
517   Index userColStride() const { return m_col_strides; }
518   Index userInPlaneStride() const { return m_in_plane_strides; }
519   Index userInRowStride() const { return m_in_row_strides; }
520   Index userInColStride() const { return m_in_col_strides; }
521   Index planeInflateStride() const { return m_plane_inflate_strides; }
522   Index rowInflateStride() const { return m_row_inflate_strides; }
523   Index colInflateStride() const { return m_col_inflate_strides; }
524 
525  protected:
526   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetWithPossibleZero(Index index) const
527   {
528     EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
529     for (int i = 0; i < PacketSize; ++i) {
530       values[i] = coeff(index+i);
531     }
532     PacketReturnType rslt = internal::pload<PacketReturnType>(values);
533     return rslt;
534   }
535 
536   Dimensions m_dimensions;
537 
538   // Parameters passed to the costructor.
539   Index m_plane_strides;
540   Index m_row_strides;
541   Index m_col_strides;
542 
543   Index m_outputPlanes;
544   Index m_outputRows;
545   Index m_outputCols;
546 
547   Index m_planePaddingTop;
548   Index m_rowPaddingTop;
549   Index m_colPaddingLeft;
550 
551   Index m_in_plane_strides;
552   Index m_in_row_strides;
553   Index m_in_col_strides;
554 
555   Index m_plane_inflate_strides;
556   Index m_row_inflate_strides;
557   Index m_col_inflate_strides;
558 
559   // Cached input size.
560   Index m_inputDepth;
561   Index m_inputPlanes;
562   Index m_inputRows;
563   Index m_inputCols;
564 
565   // Other cached variables.
566   Index m_outputPlanesRows;
567 
568   // Effective input/patch post-inflation size.
569   Index m_input_planes_eff;
570   Index m_input_rows_eff;
571   Index m_input_cols_eff;
572   Index m_patch_planes_eff;
573   Index m_patch_rows_eff;
574   Index m_patch_cols_eff;
575 
576   // Strides for the output tensor.
577   Index m_otherStride;
578   Index m_patchStride;
579   Index m_rowStride;
580   Index m_colStride;
581 
582   // Strides for the input tensor.
583   Index m_planeInputStride;
584   Index m_rowInputStride;
585   Index m_colInputStride;
586   Index m_otherInputStride;
587 
588   internal::TensorIntDivisor<Index> m_fastOtherStride;
589   internal::TensorIntDivisor<Index> m_fastPatchStride;
590   internal::TensorIntDivisor<Index> m_fastColStride;
591   internal::TensorIntDivisor<Index> m_fastRowStride;
592   internal::TensorIntDivisor<Index> m_fastInputPlaneStride;
593   internal::TensorIntDivisor<Index> m_fastInputRowStride;
594   internal::TensorIntDivisor<Index> m_fastInputColStride;
595   internal::TensorIntDivisor<Index> m_fastInputColsEff;
596   internal::TensorIntDivisor<Index> m_fastOutputPlanesRows;
597   internal::TensorIntDivisor<Index> m_fastOutputPlanes;
598   internal::TensorIntDivisor<Index> m_fastOutputDepth;
599 
600   Scalar m_paddingValue;
601 
602   TensorEvaluator<ArgType, Device> m_impl;
603 };
604 
605 
606 } // end namespace Eigen
607 
608 #endif // EIGEN_CXX11_TENSOR_TENSOR_VOLUME_PATCH_H
609