/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_EIGEN_VOLUME_PATCH_H_
17 #define TENSORFLOW_CORE_KERNELS_EIGEN_VOLUME_PATCH_H_
18 
19 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
20 
21 namespace Eigen {
22 
23 // Changes the interpretation of padding in TensorVolumePatchOp to be compatible
24 // with the rest of TensorFlow (odd padding is split so that more padding is put
25 // on the right end of the tensor).
26 template <DenseIndex Planes, DenseIndex Rows, DenseIndex Cols, typename ArgType,
27           typename Device>
28 struct CustomTensorEvaluator {
29   typedef TensorVolumePatchOp<Planes, Rows, Cols, ArgType> XprType;
30   typedef typename XprType::Index Index;
31   static constexpr int NumInputDims = internal::array_size<
32       typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
33   static constexpr int NumDims = NumInputDims + 1;
34   typedef DSizes<Index, NumDims> Dimensions;
35   typedef
36       typename internal::remove_const<typename XprType::Scalar>::type Scalar;
37   typedef typename XprType::CoeffReturnType CoeffReturnType;
38   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
39   static constexpr Index PacketSize =
40       internal::unpacket_traits<PacketReturnType>::size;
41 
42   enum {
43     IsAligned = false,
44     PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
45     BlockAccess = false,
46     PreferBlockAccess = false,
47     Layout = TensorEvaluator<ArgType, Device>::Layout,
48     CoordAccess = NumDims == 6,
49     RawAccess = false
50   };
51 
52   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
CustomTensorEvaluatorCustomTensorEvaluator53   CustomTensorEvaluator(const XprType& op, const Device& device)
54       : m_impl(op.expression(), device) {
55     EIGEN_STATIC_ASSERT(NumDims >= 5, YOU_MADE_A_PROGRAMMING_MISTAKE);
56 
57     m_paddingValue = op.padding_value();
58 
59     const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims =
60         m_impl.dimensions();
61 
62     // Cache a few variables.
63     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
64       m_inputDepth = input_dims[0];
65       m_inputPlanes = input_dims[1];
66       m_inputRows = input_dims[2];
67       m_inputCols = input_dims[3];
68     } else {
69       m_inputDepth = input_dims[NumInputDims - 1];
70       m_inputPlanes = input_dims[NumInputDims - 2];
71       m_inputRows = input_dims[NumInputDims - 3];
72       m_inputCols = input_dims[NumInputDims - 4];
73     }
74 
75     m_plane_strides = op.plane_strides();
76     m_row_strides = op.row_strides();
77     m_col_strides = op.col_strides();
78 
79     // Input strides and effective input/patch size
80     m_in_plane_strides = op.in_plane_strides();
81     m_in_row_strides = op.in_row_strides();
82     m_in_col_strides = op.in_col_strides();
83     m_plane_inflate_strides = op.plane_inflate_strides();
84     m_row_inflate_strides = op.row_inflate_strides();
85     m_col_inflate_strides = op.col_inflate_strides();
86 
87     // The "effective" spatial size after inflating data with zeros.
88     m_input_planes_eff = (m_inputPlanes - 1) * m_plane_inflate_strides + 1;
89     m_input_rows_eff = (m_inputRows - 1) * m_row_inflate_strides + 1;
90     m_input_cols_eff = (m_inputCols - 1) * m_col_inflate_strides + 1;
91     m_patch_planes_eff =
92         op.patch_planes() + (op.patch_planes() - 1) * (m_in_plane_strides - 1);
93     m_patch_rows_eff =
94         op.patch_rows() + (op.patch_rows() - 1) * (m_in_row_strides - 1);
95     m_patch_cols_eff =
96         op.patch_cols() + (op.patch_cols() - 1) * (m_in_col_strides - 1);
97 
98     if (op.padding_explicit()) {
99       m_outputPlanes = Eigen::divup(
100           m_input_planes_eff +
101               static_cast<Index>(op.padding_top_z() + op.padding_bottom_z()) -
102               m_patch_planes_eff + 1,
103           m_plane_strides);
104       m_outputRows = Eigen::divup(
105           m_input_rows_eff +
106               static_cast<Index>(op.padding_top() + op.padding_bottom()) -
107               m_patch_rows_eff + 1,
108           m_row_strides);
109       m_outputCols = Eigen::divup(
110           m_input_cols_eff +
111               static_cast<Index>(op.padding_left() + op.padding_right()) -
112               m_patch_cols_eff + 1,
113           m_col_strides);
114       m_planePaddingTop = op.padding_top_z();
115       m_rowPaddingTop = op.padding_top();
116       m_colPaddingLeft = op.padding_left();
117     } else {
118       // Computing padding from the type
119       switch (op.padding_type()) {
120         case PADDING_VALID:
121           m_outputPlanes = Eigen::divup(
122               m_input_planes_eff - m_patch_planes_eff + 1, m_plane_strides);
123           m_outputRows = Eigen::divup(m_input_rows_eff - m_patch_rows_eff + 1,
124                                       m_row_strides);
125           m_outputCols = Eigen::divup(m_input_cols_eff - m_patch_cols_eff + 1,
126                                       m_col_strides);
127           m_planePaddingTop = 0;
128           m_rowPaddingTop = 0;
129           m_colPaddingLeft = 0;
130           break;
131         case PADDING_SAME: {
132           m_outputPlanes = Eigen::divup(m_input_planes_eff, m_plane_strides);
133           m_outputRows = Eigen::divup(m_input_rows_eff, m_row_strides);
134           m_outputCols = Eigen::divup(m_input_cols_eff, m_col_strides);
135           const Index dz = numext::maxi<DenseIndex>(
136               0, (m_outputPlanes - 1) * m_plane_strides + m_patch_planes_eff -
137                      m_input_planes_eff);
138           const Index dy = numext::maxi<DenseIndex>(
139               0, (m_outputRows - 1) * m_row_strides + m_patch_rows_eff -
140                      m_input_rows_eff);
141           const Index dx = numext::maxi<DenseIndex>(
142               0, (m_outputCols - 1) * m_col_strides + m_patch_cols_eff -
143                      m_input_cols_eff);
144           m_planePaddingTop = dz / 2;
145           m_rowPaddingTop = dy / 2;
146           m_colPaddingLeft = dx / 2;
147           break;
148         }
149         default:
150           eigen_assert(false && "unexpected padding");
151       }
152     }
153     eigen_assert(m_outputRows > 0);
154     eigen_assert(m_outputCols > 0);
155     eigen_assert(m_outputPlanes > 0);
156 
157     // Dimensions for result of extraction.
158     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
159       // ColMajor
160       // 0: depth
161       // 1: patch_planes
162       // 2: patch_rows
163       // 3: patch_cols
164       // 4: number of patches
165       // 5 and beyond: anything else (such as batch).
166       m_dimensions[0] = input_dims[0];
167       m_dimensions[1] = op.patch_planes();
168       m_dimensions[2] = op.patch_rows();
169       m_dimensions[3] = op.patch_cols();
170       m_dimensions[4] = m_outputPlanes * m_outputRows * m_outputCols;
171       for (int i = 5; i < NumDims; ++i) {
172         m_dimensions[i] = input_dims[i - 1];
173       }
174     } else {
175       // RowMajor
176       // NumDims-1: depth
177       // NumDims-2: patch_planes
178       // NumDims-3: patch_rows
179       // NumDims-4: patch_cols
180       // NumDims-5: number of patches
181       // NumDims-6 and beyond: anything else (such as batch).
182       m_dimensions[NumDims - 1] = input_dims[NumInputDims - 1];
183       m_dimensions[NumDims - 2] = op.patch_planes();
184       m_dimensions[NumDims - 3] = op.patch_rows();
185       m_dimensions[NumDims - 4] = op.patch_cols();
186       m_dimensions[NumDims - 5] = m_outputPlanes * m_outputRows * m_outputCols;
187       for (int i = NumDims - 6; i >= 0; --i) {
188         m_dimensions[i] = input_dims[i];
189       }
190     }
191 
192     // Strides for the output tensor.
193     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
194       m_rowStride = m_dimensions[1];
195       m_colStride = m_dimensions[2] * m_rowStride;
196       m_patchStride = m_colStride * m_dimensions[3] * m_dimensions[0];
197       m_otherStride = m_patchStride * m_dimensions[4];
198     } else {
199       m_rowStride = m_dimensions[NumDims - 2];
200       m_colStride = m_dimensions[NumDims - 3] * m_rowStride;
201       m_patchStride =
202           m_colStride * m_dimensions[NumDims - 4] * m_dimensions[NumDims - 1];
203       m_otherStride = m_patchStride * m_dimensions[NumDims - 5];
204     }
205 
206     // Strides for navigating through the input tensor.
207     m_planeInputStride = m_inputDepth;
208     m_rowInputStride = m_inputDepth * m_inputPlanes;
209     m_colInputStride = m_inputDepth * m_inputRows * m_inputPlanes;
210     m_otherInputStride =
211         m_inputDepth * m_inputRows * m_inputCols * m_inputPlanes;
212 
213     m_outputPlanesRows = m_outputPlanes * m_outputRows;
214 
215     // Fast representations of different variables.
216     m_fastOtherStride = internal::TensorIntDivisor<Index>(m_otherStride);
217     m_fastPatchStride = internal::TensorIntDivisor<Index>(m_patchStride);
218     m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);
219     m_fastRowStride = internal::TensorIntDivisor<Index>(m_rowStride);
220     m_fastInputRowStride =
221         internal::TensorIntDivisor<Index>(m_row_inflate_strides);
222     m_fastInputColStride =
223         internal::TensorIntDivisor<Index>(m_col_inflate_strides);
224     m_fastInputPlaneStride =
225         internal::TensorIntDivisor<Index>(m_plane_inflate_strides);
226     m_fastInputColsEff = internal::TensorIntDivisor<Index>(m_input_cols_eff);
227     m_fastOutputPlanes = internal::TensorIntDivisor<Index>(m_outputPlanes);
228     m_fastOutputPlanesRows =
229         internal::TensorIntDivisor<Index>(m_outputPlanesRows);
230 
231     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
232       m_fastOutputDepth = internal::TensorIntDivisor<Index>(m_dimensions[0]);
233     } else {
234       m_fastOutputDepth =
235           internal::TensorIntDivisor<Index>(m_dimensions[NumDims - 1]);
236     }
237   }
238 
dimensionsCustomTensorEvaluator239   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const {
240     return m_dimensions;
241   }
242 
evalSubExprsIfNeededCustomTensorEvaluator243   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(
244       Scalar* /*data*/) {
245     m_impl.evalSubExprsIfNeeded(NULL);
246     return true;
247   }
248 
cleanupCustomTensorEvaluator249   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { m_impl.cleanup(); }
250 
251   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType
coeffCustomTensorEvaluator252   coeff(Index index) const {
253     // Patch index corresponding to the passed in index.
254     const Index patchIndex = index / m_fastPatchStride;
255 
256     // Spatial offset within the patch. This has to be translated into 3D
257     // coordinates within the patch.
258     const Index patchOffset =
259         (index - patchIndex * m_patchStride) / m_fastOutputDepth;
260 
261     // Batch, etc.
262     const Index otherIndex = (NumDims == 5) ? 0 : index / m_fastOtherStride;
263     const Index patch3DIndex =
264         (NumDims == 5)
265             ? patchIndex
266             : (index - otherIndex * m_otherStride) / m_fastPatchStride;
267 
268     // Calculate column index in the input original tensor.
269     const Index colIndex = patch3DIndex / m_fastOutputPlanesRows;
270     const Index colOffset = patchOffset / m_fastColStride;
271     const Index inputCol = colIndex * m_col_strides +
272                            colOffset * m_in_col_strides - m_colPaddingLeft;
273     const Index origInputCol =
274         (m_col_inflate_strides == 1)
275             ? inputCol
276             : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
277     if (inputCol < 0 || inputCol >= m_input_cols_eff ||
278         ((m_col_inflate_strides != 1) &&
279          (inputCol != origInputCol * m_col_inflate_strides))) {
280       return Scalar(m_paddingValue);
281     }
282 
283     // Calculate row index in the original input tensor.
284     const Index rowIndex =
285         (patch3DIndex - colIndex * m_outputPlanesRows) / m_fastOutputPlanes;
286     const Index rowOffset =
287         (patchOffset - colOffset * m_colStride) / m_fastRowStride;
288     const Index inputRow = rowIndex * m_row_strides +
289                            rowOffset * m_in_row_strides - m_rowPaddingTop;
290     const Index origInputRow =
291         (m_row_inflate_strides == 1)
292             ? inputRow
293             : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
294     if (inputRow < 0 || inputRow >= m_input_rows_eff ||
295         ((m_row_inflate_strides != 1) &&
296          (inputRow != origInputRow * m_row_inflate_strides))) {
297       return Scalar(m_paddingValue);
298     }
299 
300     // Calculate plane index in the original input tensor.
301     const Index planeIndex =
302         (patch3DIndex - m_outputPlanes * (colIndex * m_outputRows + rowIndex));
303     const Index planeOffset =
304         patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
305     const Index inputPlane = planeIndex * m_plane_strides +
306                              planeOffset * m_in_plane_strides -
307                              m_planePaddingTop;
308     const Index origInputPlane =
309         (m_plane_inflate_strides == 1)
310             ? inputPlane
311             : ((inputPlane >= 0) ? (inputPlane / m_fastInputPlaneStride) : 0);
312     if (inputPlane < 0 || inputPlane >= m_input_planes_eff ||
313         ((m_plane_inflate_strides != 1) &&
314          (inputPlane != origInputPlane * m_plane_inflate_strides))) {
315       return Scalar(m_paddingValue);
316     }
317 
318     const int depth_index =
319         static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 0
320                                                                : NumDims - 1;
321     const Index depth =
322         index - (index / m_fastOutputDepth) * m_dimensions[depth_index];
323 
324     const Index inputIndex = depth + origInputRow * m_rowInputStride +
325                              origInputCol * m_colInputStride +
326                              origInputPlane * m_planeInputStride +
327                              otherIndex * m_otherInputStride;
328 
329     return m_impl.coeff(inputIndex);
330   }
331 
332   template <int LoadMode>
333   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType
packetCustomTensorEvaluator334   packet(Index index) const {
335     EIGEN_STATIC_ASSERT(PacketSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
336     eigen_assert(index + PacketSize - 1 < dimensions().TotalSize());
337 
338     if (m_in_row_strides != 1 || m_in_col_strides != 1 ||
339         m_row_inflate_strides != 1 || m_col_inflate_strides != 1 ||
340         m_in_plane_strides != 1 || m_plane_inflate_strides != 1) {
341       return packetWithPossibleZero(index);
342     }
343 
344     const Index indices[2] = {index, index + PacketSize - 1};
345     const Index patchIndex = indices[0] / m_fastPatchStride;
346     if (patchIndex != indices[1] / m_fastPatchStride) {
347       return packetWithPossibleZero(index);
348     }
349     const Index otherIndex =
350         (NumDims == 5) ? 0 : indices[0] / m_fastOtherStride;
351     eigen_assert(otherIndex == indices[1] / m_fastOtherStride);
352 
353     // Find the offset of the element wrt the location of the first element.
354     const Index patchOffsets[2] = {
355         (indices[0] - patchIndex * m_patchStride) / m_fastOutputDepth,
356         (indices[1] - patchIndex * m_patchStride) / m_fastOutputDepth};
357 
358     const Index patch3DIndex =
359         (NumDims == 5)
360             ? patchIndex
361             : (indices[0] - otherIndex * m_otherStride) / m_fastPatchStride;
362     eigen_assert(patch3DIndex ==
363                  (indices[1] - otherIndex * m_otherStride) / m_fastPatchStride);
364 
365     const Index colIndex = patch3DIndex / m_fastOutputPlanesRows;
366     const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
367                                  patchOffsets[1] / m_fastColStride};
368 
369     // Calculate col indices in the original input tensor.
370     const Index inputCols[2] = {
371         colIndex * m_col_strides + colOffsets[0] - m_colPaddingLeft,
372         colIndex * m_col_strides + colOffsets[1] - m_colPaddingLeft};
373     if (inputCols[1] < 0 || inputCols[0] >= m_inputCols) {
374       return internal::pset1<PacketReturnType>(Scalar(m_paddingValue));
375     }
376 
377     if (inputCols[0] != inputCols[1]) {
378       return packetWithPossibleZero(index);
379     }
380 
381     const Index rowIndex =
382         (patch3DIndex - colIndex * m_outputPlanesRows) / m_fastOutputPlanes;
383     const Index rowOffsets[2] = {
384         (patchOffsets[0] - colOffsets[0] * m_colStride) / m_fastRowStride,
385         (patchOffsets[1] - colOffsets[1] * m_colStride) / m_fastRowStride};
386     eigen_assert(rowOffsets[0] <= rowOffsets[1]);
387     // Calculate col indices in the original input tensor.
388     const Index inputRows[2] = {
389         rowIndex * m_row_strides + rowOffsets[0] - m_rowPaddingTop,
390         rowIndex * m_row_strides + rowOffsets[1] - m_rowPaddingTop};
391 
392     if (inputRows[1] < 0 || inputRows[0] >= m_inputRows) {
393       return internal::pset1<PacketReturnType>(Scalar(m_paddingValue));
394     }
395 
396     if (inputRows[0] != inputRows[1]) {
397       return packetWithPossibleZero(index);
398     }
399 
400     const Index planeIndex =
401         (patch3DIndex - m_outputPlanes * (colIndex * m_outputRows + rowIndex));
402     const Index planeOffsets[2] = {
403         patchOffsets[0] - colOffsets[0] * m_colStride -
404             rowOffsets[0] * m_rowStride,
405         patchOffsets[1] - colOffsets[1] * m_colStride -
406             rowOffsets[1] * m_rowStride};
407     eigen_assert(planeOffsets[0] <= planeOffsets[1]);
408     const Index inputPlanes[2] = {
409         planeIndex * m_plane_strides + planeOffsets[0] - m_planePaddingTop,
410         planeIndex * m_plane_strides + planeOffsets[1] - m_planePaddingTop};
411 
412     if (inputPlanes[1] < 0 || inputPlanes[0] >= m_inputPlanes) {
413       return internal::pset1<PacketReturnType>(Scalar(m_paddingValue));
414     }
415 
416     if (inputPlanes[0] >= 0 && inputPlanes[1] < m_inputPlanes) {
417       // no padding
418       const int depth_index =
419           static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 0
420                                                                  : NumDims - 1;
421       const Index depth =
422           index - (index / m_fastOutputDepth) * m_dimensions[depth_index];
423       const Index inputIndex = depth + inputRows[0] * m_rowInputStride +
424                                inputCols[0] * m_colInputStride +
425                                m_planeInputStride * inputPlanes[0] +
426                                otherIndex * m_otherInputStride;
427       return m_impl.template packet<Unaligned>(inputIndex);
428     }
429 
430     return packetWithPossibleZero(index);
431   }
432 
433   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
costPerCoeffCustomTensorEvaluator434   costPerCoeff(bool vectorized) const {
435     const double compute_cost = 10 * TensorOpCost::DivCost<Index>() +
436                                 21 * TensorOpCost::MulCost<Index>() +
437                                 8 * TensorOpCost::AddCost<Index>();
438     return TensorOpCost(0, 0, compute_cost, vectorized, PacketSize);
439   }
440 
dataCustomTensorEvaluator441   EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }
442 
implCustomTensorEvaluator443   const TensorEvaluator<ArgType, Device>& impl() const { return m_impl; }
444 
planePaddingTopCustomTensorEvaluator445   Index planePaddingTop() const { return m_planePaddingTop; }
rowPaddingTopCustomTensorEvaluator446   Index rowPaddingTop() const { return m_rowPaddingTop; }
colPaddingLeftCustomTensorEvaluator447   Index colPaddingLeft() const { return m_colPaddingLeft; }
outputPlanesCustomTensorEvaluator448   Index outputPlanes() const { return m_outputPlanes; }
outputRowsCustomTensorEvaluator449   Index outputRows() const { return m_outputRows; }
outputColsCustomTensorEvaluator450   Index outputCols() const { return m_outputCols; }
userPlaneStrideCustomTensorEvaluator451   Index userPlaneStride() const { return m_plane_strides; }
userRowStrideCustomTensorEvaluator452   Index userRowStride() const { return m_row_strides; }
userColStrideCustomTensorEvaluator453   Index userColStride() const { return m_col_strides; }
userInPlaneStrideCustomTensorEvaluator454   Index userInPlaneStride() const { return m_in_plane_strides; }
userInRowStrideCustomTensorEvaluator455   Index userInRowStride() const { return m_in_row_strides; }
userInColStrideCustomTensorEvaluator456   Index userInColStride() const { return m_in_col_strides; }
planeInflateStrideCustomTensorEvaluator457   Index planeInflateStride() const { return m_plane_inflate_strides; }
rowInflateStrideCustomTensorEvaluator458   Index rowInflateStride() const { return m_row_inflate_strides; }
colInflateStrideCustomTensorEvaluator459   Index colInflateStride() const { return m_col_inflate_strides; }
460 
461   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType
coeffCustomTensorEvaluator462   coeff(const array<Index, NumDims>& coords) const {
463     // ColMajor
464     //   0: depth, 1: patch_planes, 2: patch_rows, 3: patch_cols, 4: number of
465     //   patches, 5: batches
466     // RowMajor
467     //   0: batches, 1: number of patches, 2: patch_cols , 3: patch_rows, 4:
468     //   patch_planes, 5: depth
469     const Index patch3DIndex =
470         coords[static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 4 : 1];
471     const Index colOffset =
472         coords[static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 3 : 2];
473     const Index rowOffset =
474         coords[static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 2 : 3];
475     const Index planeOffset =
476         coords[static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 1 : 4];
477 
478     array<Index, NumDims - 1> inputCoords;
479 
480     const Index colIndex = patch3DIndex / m_fastOutputPlanesRows;
481     const Index inputCol = colIndex * m_col_strides +
482                            colOffset * m_in_col_strides - m_colPaddingLeft;
483     const Index origInputCol =
484         (m_col_inflate_strides == 1)
485             ? inputCol
486             : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
487     if (inputCol < 0 || inputCol >= m_input_cols_eff ||
488         ((m_col_inflate_strides != 1) &&
489          (inputCol != origInputCol * m_col_inflate_strides))) {
490       return Scalar(m_paddingValue);
491     }
492 
493     const Index rowIndex =
494         (patch3DIndex - colIndex * m_outputPlanesRows) / m_fastOutputPlanes;
495     const Index inputRow = rowIndex * m_row_strides +
496                            rowOffset * m_in_row_strides - m_rowPaddingTop;
497     const Index origInputRow =
498         (m_row_inflate_strides == 1)
499             ? inputRow
500             : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
501     if (inputRow < 0 || inputRow >= m_input_rows_eff ||
502         ((m_row_inflate_strides != 1) &&
503          (inputRow != origInputRow * m_row_inflate_strides))) {
504       return Scalar(m_paddingValue);
505     }
506 
507     const Index planeIndex =
508         patch3DIndex - colIndex * m_outputPlanesRows - rowIndex * m_outputRows;
509     const Index inputPlane = planeIndex * m_plane_strides +
510                              planeOffset * m_in_plane_strides -
511                              m_planePaddingTop;
512     const Index origInputPlane =
513         (m_plane_inflate_strides == 1)
514             ? inputPlane
515             : ((inputPlane >= 0) ? (inputPlane / m_fastInputPlaneStride) : 0);
516     if (inputPlane < 0 || inputPlane >= m_input_planes_eff ||
517         ((m_plane_inflate_strides != 1) &&
518          (inputPlane != origInputPlane * m_plane_inflate_strides))) {
519       return Scalar(m_paddingValue);
520     }
521 
522     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
523       inputCoords[0] = coords[0];  // depth
524       inputCoords[1] = origInputPlane;
525       inputCoords[2] = origInputRow;
526       inputCoords[3] = origInputCol;
527       inputCoords[4] = coords[5];  // batch
528     } else {
529       inputCoords[4] = coords[5];  // depth
530       inputCoords[3] = origInputPlane;
531       inputCoords[2] = origInputRow;
532       inputCoords[1] = origInputCol;
533       inputCoords[0] = coords[0];  // batch
534     }
535     if (TensorEvaluator<ArgType, Device>::CoordAccess) {
536       return m_impl.coeff(inputCoords);
537     } else {
538       Index inputIndex;
539       if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
540         inputIndex = inputCoords[4] * m_otherInputStride +
541                      inputCoords[3] * m_colInputStride +
542                      inputCoords[2] * m_rowInputStride +
543                      inputCoords[1] * m_planeInputStride + inputCoords[0];
544       } else {
545         inputIndex = inputCoords[0] * m_otherInputStride +
546                      inputCoords[1] * m_colInputStride +
547                      inputCoords[2] * m_rowInputStride +
548                      inputCoords[3] * m_planeInputStride + inputCoords[4];
549       }
550       return m_impl.coeff(inputIndex);
551     }
552   }
553 
554  protected:
555   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType
packetWithPossibleZeroCustomTensorEvaluator556   packetWithPossibleZero(Index index) const {
557     EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type
558         values[PacketSize];
559     for (int i = 0; i < PacketSize; ++i) {
560       values[i] = coeff(index + i);
561     }
562     PacketReturnType rslt = internal::pload<PacketReturnType>(values);
563     return rslt;
564   }
565 
566   Dimensions m_dimensions;
567 
568   // Parameters passed to the constructor.
569   Index m_plane_strides;
570   Index m_row_strides;
571   Index m_col_strides;
572 
573   Index m_outputPlanes;
574   Index m_outputRows;
575   Index m_outputCols;
576 
577   Index m_planePaddingTop;
578   Index m_rowPaddingTop;
579   Index m_colPaddingLeft;
580 
581   Index m_in_plane_strides;
582   Index m_in_row_strides;
583   Index m_in_col_strides;
584 
585   Index m_plane_inflate_strides;
586   Index m_row_inflate_strides;
587   Index m_col_inflate_strides;
588 
589   // Cached input size.
590   Index m_inputDepth;
591   Index m_inputPlanes;
592   Index m_inputRows;
593   Index m_inputCols;
594 
595   // Other cached variables.
596   Index m_outputPlanesRows;
597 
598   // Effective input/patch post-inflation size.
599   Index m_input_planes_eff;
600   Index m_input_rows_eff;
601   Index m_input_cols_eff;
602   Index m_patch_planes_eff;
603   Index m_patch_rows_eff;
604   Index m_patch_cols_eff;
605 
606   // Strides for the output tensor.
607   Index m_otherStride;
608   Index m_patchStride;
609   Index m_rowStride;
610   Index m_colStride;
611 
612   // Strides for the input tensor.
613   Index m_planeInputStride;
614   Index m_rowInputStride;
615   Index m_colInputStride;
616   Index m_otherInputStride;
617 
618   internal::TensorIntDivisor<Index> m_fastOtherStride;
619   internal::TensorIntDivisor<Index> m_fastPatchStride;
620   internal::TensorIntDivisor<Index> m_fastColStride;
621   internal::TensorIntDivisor<Index> m_fastRowStride;
622   internal::TensorIntDivisor<Index> m_fastInputPlaneStride;
623   internal::TensorIntDivisor<Index> m_fastInputRowStride;
624   internal::TensorIntDivisor<Index> m_fastInputColStride;
625   internal::TensorIntDivisor<Index> m_fastInputColsEff;
626   internal::TensorIntDivisor<Index> m_fastOutputPlanesRows;
627   internal::TensorIntDivisor<Index> m_fastOutputPlanes;
628   internal::TensorIntDivisor<Index> m_fastOutputDepth;
629 
630   Scalar m_paddingValue;
631 
632   TensorEvaluator<ArgType, Device> m_impl;
633 };
634 
635 // Override the default TensorEvaluator for TensorVolumePatchOp for CPU.
636 #define OVERRIDE_EVALUATOR(Device)                                          \
637   template <DenseIndex Planes, DenseIndex Rows, DenseIndex Cols,            \
638             typename ArgType>                                               \
639   struct TensorEvaluator<                                                   \
640       const TensorVolumePatchOp<Planes, Rows, Cols, ArgType>, Device>       \
641       : public CustomTensorEvaluator<Planes, Rows, Cols, ArgType, Device> { \
642     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(                  \
643         const typename CustomTensorEvaluator<Planes, Rows, Cols, ArgType,   \
644                                              Device>::XprType& op,          \
645         const Device& device)                                               \
646         : CustomTensorEvaluator<Planes, Rows, Cols, ArgType, Device>(       \
647               op, device) {}                                                \
648   };
649 
650 OVERRIDE_EVALUATOR(Eigen::ThreadPoolDevice);
651 OVERRIDE_EVALUATOR(Eigen::DefaultDevice);
652 
653 #undef OVERRIDE_EVALUATOR
654 
655 };  // namespace Eigen
656 
657 #endif  // TENSORFLOW_CORE_KERNELS_EIGEN_VOLUME_PATCH_H_
658