// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_VOLUME_PATCH_H
#define EIGEN_CXX11_TENSOR_TENSOR_VOLUME_PATCH_H

namespace Eigen {

/** \class TensorVolumePatch
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Patch extraction specialized for processing of volumetric data.
  * This assumes that the input has at least 4 dimensions ordered as follows:
  *  - channels
  *  - planes
  *  - rows
  *  - columns
  *  - (optional) additional dimensions such as time or batch size.
  * Calling the volume patch code with patch_planes, patch_rows, and patch_cols
  * is equivalent to calling the regular patch extraction code with parameters
  * d, patch_planes, patch_rows, patch_cols, and 1 for all the additional
  * dimensions.
  */
namespace internal {

// Expression traits: a volume-patch expression has one more dimension than
// its input (the "number of patches" dimension added by the extraction).
template<DenseIndex Planes, DenseIndex Rows, DenseIndex Cols, typename XprType>
struct traits<TensorVolumePatchOp<Planes, Rows, Cols, XprType> > : public traits<XprType>
{
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef traits<XprType> XprTraits;
  typedef typename XprTraits::StorageKind StorageKind;
  typedef typename XprTraits::Index Index;
  typedef typename XprType::Nested Nested;
  typedef typename remove_reference<Nested>::type _Nested;
  static const int NumDimensions = XprTraits::NumDimensions + 1;
  static const int Layout = XprTraits::Layout;
};

template<DenseIndex Planes, DenseIndex Rows, DenseIndex Cols, typename XprType>
struct eval<TensorVolumePatchOp<Planes, Rows, Cols, XprType>, Eigen::Dense>
{
  typedef const TensorVolumePatchOp<Planes, Rows, Cols, XprType>& type;
};

template<DenseIndex Planes, DenseIndex Rows, DenseIndex Cols, typename XprType>
struct nested<TensorVolumePatchOp<Planes, Rows, Cols, XprType>, 1, typename eval<TensorVolumePatchOp<Planes, Rows, Cols, XprType> >::type>
{
  typedef TensorVolumePatchOp<Planes, Rows, Cols, XprType> type;
};

}  // end namespace internal

// Read-only expression node describing the extraction of 3D (volume) patches
// from an input tensor. This class only records the patch/stride/padding
// parameters; the actual index arithmetic lives in the TensorEvaluator
// specialization further below.
template<DenseIndex Planes, DenseIndex Rows, DenseIndex Cols, typename XprType>
class TensorVolumePatchOp : public TensorBase<TensorVolumePatchOp<Planes, Rows, Cols, XprType>, ReadOnlyAccessors>
{
  public:
  typedef typename Eigen::internal::traits<TensorVolumePatchOp>::Scalar Scalar;
  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename Eigen::internal::nested<TensorVolumePatchOp>::type Nested;
  typedef typename Eigen::internal::traits<TensorVolumePatchOp>::StorageKind StorageKind;
  typedef typename Eigen::internal::traits<TensorVolumePatchOp>::Index Index;

  // Constructor for the "computed padding" variant: the amount of padding is
  // derived from \p padding_type (PADDING_VALID or PADDING_SAME) by the
  // evaluator; all explicit padding amounts are therefore zeroed here.
  //  - patch_{planes,rows,cols}: extent of each extracted patch.
  //  - {plane,row,col}_strides: stride between consecutive patches.
  //  - in_{plane,row,col}_strides: stride between sampled input elements
  //    within a patch (atrous/dilated patches).
  //  - {plane,row,col}_inflate_strides: inflation (zero-insertion) factor
  //    applied to the input before extraction.
  //  - padding_value: value returned for out-of-bounds (padded) coefficients.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorVolumePatchOp(const XprType& expr, DenseIndex patch_planes, DenseIndex patch_rows, DenseIndex patch_cols,
                                                            DenseIndex plane_strides, DenseIndex row_strides, DenseIndex col_strides,
                                                            DenseIndex in_plane_strides, DenseIndex in_row_strides, DenseIndex in_col_strides,
                                                            DenseIndex plane_inflate_strides, DenseIndex row_inflate_strides, DenseIndex col_inflate_strides,
                                                            PaddingType padding_type, Scalar padding_value)
      : m_xpr(expr), m_patch_planes(patch_planes), m_patch_rows(patch_rows), m_patch_cols(patch_cols),
        m_plane_strides(plane_strides), m_row_strides(row_strides), m_col_strides(col_strides),
        m_in_plane_strides(in_plane_strides), m_in_row_strides(in_row_strides), m_in_col_strides(in_col_strides),
        m_plane_inflate_strides(plane_inflate_strides), m_row_inflate_strides(row_inflate_strides), m_col_inflate_strides(col_inflate_strides),
        m_padding_explicit(false), m_padding_top_z(0), m_padding_bottom_z(0), m_padding_top(0), m_padding_bottom(0), m_padding_left(0), m_padding_right(0),
        m_padding_type(padding_type), m_padding_value(padding_value) {}

  // Constructor for the "explicit padding" variant: the caller supplies the
  // exact padding amounts on each side of each spatial dimension
  // (top_z/bottom_z for planes, top/bottom for rows, left/right for columns).
  // m_padding_type is set to PADDING_VALID only as a placeholder; the
  // evaluator ignores it when m_padding_explicit is true.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorVolumePatchOp(const XprType& expr, DenseIndex patch_planes, DenseIndex patch_rows, DenseIndex patch_cols,
                                                            DenseIndex plane_strides, DenseIndex row_strides, DenseIndex col_strides,
                                                            DenseIndex in_plane_strides, DenseIndex in_row_strides, DenseIndex in_col_strides,
                                                            DenseIndex plane_inflate_strides, DenseIndex row_inflate_strides, DenseIndex col_inflate_strides,
                                                            DenseIndex padding_top_z, DenseIndex padding_bottom_z,
                                                            DenseIndex padding_top, DenseIndex padding_bottom,
                                                            DenseIndex padding_left, DenseIndex padding_right,
                                                            Scalar padding_value)
      : m_xpr(expr), m_patch_planes(patch_planes), m_patch_rows(patch_rows), m_patch_cols(patch_cols),
        m_plane_strides(plane_strides), m_row_strides(row_strides), m_col_strides(col_strides),
        m_in_plane_strides(in_plane_strides), m_in_row_strides(in_row_strides), m_in_col_strides(in_col_strides),
        m_plane_inflate_strides(plane_inflate_strides), m_row_inflate_strides(row_inflate_strides), m_col_inflate_strides(col_inflate_strides),
        m_padding_explicit(true), m_padding_top_z(padding_top_z), m_padding_bottom_z(padding_bottom_z), m_padding_top(padding_top), m_padding_bottom(padding_bottom),
        m_padding_left(padding_left), m_padding_right(padding_right),
        m_padding_type(PADDING_VALID), m_padding_value(padding_value) {}

  // Accessors for the stored extraction parameters (used by the evaluator).
  EIGEN_DEVICE_FUNC
  DenseIndex patch_planes() const { return m_patch_planes; }
  EIGEN_DEVICE_FUNC
  DenseIndex patch_rows() const { return m_patch_rows; }
  EIGEN_DEVICE_FUNC
  DenseIndex patch_cols() const { return m_patch_cols; }
  EIGEN_DEVICE_FUNC
  DenseIndex plane_strides() const { return m_plane_strides; }
  EIGEN_DEVICE_FUNC
  DenseIndex row_strides() const { return m_row_strides; }
  EIGEN_DEVICE_FUNC
  DenseIndex col_strides() const { return m_col_strides; }
  EIGEN_DEVICE_FUNC
  DenseIndex in_plane_strides() const { return m_in_plane_strides; }
  EIGEN_DEVICE_FUNC
  DenseIndex in_row_strides() const { return m_in_row_strides; }
  EIGEN_DEVICE_FUNC
  DenseIndex in_col_strides() const { return m_in_col_strides; }
  EIGEN_DEVICE_FUNC
  DenseIndex plane_inflate_strides() const { return m_plane_inflate_strides; }
  EIGEN_DEVICE_FUNC
  DenseIndex row_inflate_strides() const { return m_row_inflate_strides; }
  EIGEN_DEVICE_FUNC
  DenseIndex col_inflate_strides() const { return m_col_inflate_strides; }
  EIGEN_DEVICE_FUNC
  bool padding_explicit() const { return m_padding_explicit; }
  EIGEN_DEVICE_FUNC
  DenseIndex padding_top_z() const { return m_padding_top_z; }
  EIGEN_DEVICE_FUNC
  DenseIndex padding_bottom_z() const { return m_padding_bottom_z; }
  EIGEN_DEVICE_FUNC
  DenseIndex padding_top() const { return m_padding_top; }
  EIGEN_DEVICE_FUNC
  DenseIndex padding_bottom() const { return m_padding_bottom; }
  EIGEN_DEVICE_FUNC
  DenseIndex padding_left() const { return m_padding_left; }
  EIGEN_DEVICE_FUNC
  DenseIndex padding_right() const { return m_padding_right; }
  EIGEN_DEVICE_FUNC
  PaddingType padding_type() const { return m_padding_type; }
  EIGEN_DEVICE_FUNC
  Scalar padding_value() const { return m_padding_value; }

  EIGEN_DEVICE_FUNC
  const typename internal::remove_all<typename XprType::Nested>::type&
  expression() const { return m_xpr; }

  protected:
  typename XprType::Nested m_xpr;
  const DenseIndex m_patch_planes;
  const DenseIndex m_patch_rows;
  const DenseIndex m_patch_cols;
  const DenseIndex m_plane_strides;
  const DenseIndex m_row_strides;
  const DenseIndex m_col_strides;
  const DenseIndex m_in_plane_strides;
  const DenseIndex m_in_row_strides;
  const DenseIndex m_in_col_strides;
  const DenseIndex m_plane_inflate_strides;
  const DenseIndex m_row_inflate_strides;
  const DenseIndex m_col_inflate_strides;
  // True when the explicit-padding constructor was used; in that case the
  // six m_padding_* amounts below are authoritative and m_padding_type is ignored.
  const bool m_padding_explicit;
  const DenseIndex m_padding_top_z;
  const DenseIndex m_padding_bottom_z;
  const DenseIndex m_padding_top;
  const DenseIndex m_padding_bottom;
  const DenseIndex m_padding_left;
  const DenseIndex m_padding_right;
  const PaddingType m_padding_type;
  const Scalar m_padding_value;
};


// Eval as rvalue
template<DenseIndex Planes, DenseIndex Rows, DenseIndex Cols, typename ArgType, typename Device>
struct TensorEvaluator<const TensorVolumePatchOp<Planes, Rows, Cols, ArgType>, Device>
{
  typedef TensorVolumePatchOp<Planes, Rows, Cols, ArgType> XprType;
  typedef typename XprType::Index Index;
  static const int NumInputDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
  // Output rank: the input rank plus one extra "number of patches" dimension.
  static const int NumDims = NumInputDims + 1;
  typedef DSizes<Index, NumDims> Dimensions;
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;

  enum {
    IsAligned = false,
    PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
    BlockAccess = false,
    Layout = TensorEvaluator<ArgType, Device>::Layout,
    CoordAccess = false,
    RawAccess = false
  };

  // Precomputes every stride, padding amount and fast integer divisor needed
  // by coeff()/packet(). All the heavy division setup is done once here so
  // that per-coefficient evaluation only uses multiplications and
  // TensorIntDivisor divisions.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : m_impl(op.expression(), device)
  {
    // Requires depth + 3 spatial dims + at least one patch-index dim.
    EIGEN_STATIC_ASSERT((NumDims >= 5), YOU_MADE_A_PROGRAMMING_MISTAKE);

    m_paddingValue = op.padding_value();

    const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();

    // Cache a few variables.
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      m_inputDepth = input_dims[0];
      m_inputPlanes = input_dims[1];
      m_inputRows = input_dims[2];
      m_inputCols = input_dims[3];
    } else {
      // RowMajor stores dimensions in reverse order.
      m_inputDepth = input_dims[NumInputDims-1];
      m_inputPlanes = input_dims[NumInputDims-2];
      m_inputRows = input_dims[NumInputDims-3];
      m_inputCols = input_dims[NumInputDims-4];
    }

    m_plane_strides = op.plane_strides();
    m_row_strides = op.row_strides();
    m_col_strides = op.col_strides();

    // Input strides and effective input/patch size
    m_in_plane_strides = op.in_plane_strides();
    m_in_row_strides = op.in_row_strides();
    m_in_col_strides = op.in_col_strides();
    m_plane_inflate_strides = op.plane_inflate_strides();
    m_row_inflate_strides = op.row_inflate_strides();
    m_col_inflate_strides = op.col_inflate_strides();

    // The "effective" spatial size after inflating data with zeros.
    m_input_planes_eff = (m_inputPlanes - 1) * m_plane_inflate_strides + 1;
    m_input_rows_eff = (m_inputRows - 1) * m_row_inflate_strides + 1;
    m_input_cols_eff = (m_inputCols - 1) * m_col_inflate_strides + 1;
    // The "effective" patch size after dilating the patch with the in_* strides.
    m_patch_planes_eff = op.patch_planes() + (op.patch_planes() - 1) * (m_in_plane_strides - 1);
    m_patch_rows_eff = op.patch_rows() + (op.patch_rows() - 1) * (m_in_row_strides - 1);
    m_patch_cols_eff = op.patch_cols() + (op.patch_cols() - 1) * (m_in_col_strides - 1);

    if (op.padding_explicit()) {
      // Output size follows directly from the user-supplied padding amounts.
      m_outputPlanes = numext::ceil((m_input_planes_eff + op.padding_top_z() + op.padding_bottom_z() - m_patch_planes_eff + 1.f) / static_cast<float>(m_plane_strides));
      m_outputRows = numext::ceil((m_input_rows_eff + op.padding_top() + op.padding_bottom() - m_patch_rows_eff + 1.f) / static_cast<float>(m_row_strides));
      m_outputCols = numext::ceil((m_input_cols_eff + op.padding_left() + op.padding_right() - m_patch_cols_eff + 1.f) / static_cast<float>(m_col_strides));
      m_planePaddingTop = op.padding_top_z();
      m_rowPaddingTop = op.padding_top();
      m_colPaddingLeft = op.padding_left();
    } else {
      // Computing padding from the type
      switch (op.padding_type()) {
        case PADDING_VALID:
          // No padding: only fully-contained patches are produced.
          m_outputPlanes = numext::ceil((m_input_planes_eff - m_patch_planes_eff + 1.f) / static_cast<float>(m_plane_strides));
          m_outputRows = numext::ceil((m_input_rows_eff - m_patch_rows_eff + 1.f) / static_cast<float>(m_row_strides));
          m_outputCols = numext::ceil((m_input_cols_eff - m_patch_cols_eff + 1.f) / static_cast<float>(m_col_strides));
          m_planePaddingTop = 0;
          m_rowPaddingTop = 0;
          m_colPaddingLeft = 0;
          break;
        case PADDING_SAME: {
          // Output size matches input size (divided by the stride); pad the
          // remainder as evenly as possible, with the extra element on top/left.
          m_outputPlanes = numext::ceil(m_input_planes_eff / static_cast<float>(m_plane_strides));
          m_outputRows = numext::ceil(m_input_rows_eff / static_cast<float>(m_row_strides));
          m_outputCols = numext::ceil(m_input_cols_eff / static_cast<float>(m_col_strides));
          const Index dz = m_outputPlanes * m_plane_strides + m_patch_planes_eff - 1 - m_input_planes_eff;
          const Index dy = m_outputRows * m_row_strides + m_patch_rows_eff - 1 - m_input_rows_eff;
          const Index dx = m_outputCols * m_col_strides + m_patch_cols_eff - 1 - m_input_cols_eff;
          m_planePaddingTop = dz - dz / 2;
          m_rowPaddingTop = dy - dy / 2;
          m_colPaddingLeft = dx - dx / 2;
          break;
        }
        default:
          eigen_assert(false && "unexpected padding");
      }
    }
    eigen_assert(m_outputRows > 0);
    eigen_assert(m_outputCols > 0);
    eigen_assert(m_outputPlanes > 0);

    // Dimensions for result of extraction.
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      // ColMajor
      // 0: depth
      // 1: patch_planes
      // 2: patch_rows
      // 3: patch_cols
      // 4: number of patches
      // 5 and beyond: anything else (such as batch).
      m_dimensions[0] = input_dims[0];
      m_dimensions[1] = op.patch_planes();
      m_dimensions[2] = op.patch_rows();
      m_dimensions[3] = op.patch_cols();
      m_dimensions[4] = m_outputPlanes * m_outputRows * m_outputCols;
      for (int i = 5; i < NumDims; ++i) {
        m_dimensions[i] = input_dims[i-1];
      }
    } else {
      // RowMajor
      // NumDims-1: depth
      // NumDims-2: patch_planes
      // NumDims-3: patch_rows
      // NumDims-4: patch_cols
      // NumDims-5: number of patches
      // NumDims-6 and beyond: anything else (such as batch).
      m_dimensions[NumDims-1] = input_dims[NumInputDims-1];
      m_dimensions[NumDims-2] = op.patch_planes();
      m_dimensions[NumDims-3] = op.patch_rows();
      m_dimensions[NumDims-4] = op.patch_cols();
      m_dimensions[NumDims-5] = m_outputPlanes * m_outputRows * m_outputCols;
      for (int i = NumDims-6; i >= 0; --i) {
        m_dimensions[i] = input_dims[i];
      }
    }

    // Strides for the output tensor.
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      m_rowStride = m_dimensions[1];
      m_colStride = m_dimensions[2] * m_rowStride;
      m_patchStride = m_colStride * m_dimensions[3] * m_dimensions[0];
      m_otherStride = m_patchStride * m_dimensions[4];
    } else {
      m_rowStride = m_dimensions[NumDims-2];
      m_colStride = m_dimensions[NumDims-3] * m_rowStride;
      m_patchStride = m_colStride * m_dimensions[NumDims-4] * m_dimensions[NumDims-1];
      m_otherStride = m_patchStride * m_dimensions[NumDims-5];
    }

    // Strides for navigating through the input tensor.
    m_planeInputStride = m_inputDepth;
    m_rowInputStride = m_inputDepth * m_inputPlanes;
    m_colInputStride = m_inputDepth * m_inputRows * m_inputPlanes;
    m_otherInputStride = m_inputDepth * m_inputRows * m_inputCols * m_inputPlanes;

    m_outputPlanesRows = m_outputPlanes * m_outputRows;

    // Fast representations of different variables.
    // TensorIntDivisor precomputes the magic constants needed to replace the
    // per-coefficient integer divisions with multiplies and shifts.
    m_fastOtherStride = internal::TensorIntDivisor<Index>(m_otherStride);
    m_fastPatchStride = internal::TensorIntDivisor<Index>(m_patchStride);
    m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);
    m_fastRowStride = internal::TensorIntDivisor<Index>(m_rowStride);
    m_fastInputRowStride = internal::TensorIntDivisor<Index>(m_row_inflate_strides);
    m_fastInputColStride = internal::TensorIntDivisor<Index>(m_col_inflate_strides);
    m_fastInputPlaneStride = internal::TensorIntDivisor<Index>(m_plane_inflate_strides);
    m_fastInputColsEff = internal::TensorIntDivisor<Index>(m_input_cols_eff);
    m_fastOutputPlanes = internal::TensorIntDivisor<Index>(m_outputPlanes);
    m_fastOutputPlanesRows = internal::TensorIntDivisor<Index>(m_outputPlanesRows);

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      m_fastOutputDepth = internal::TensorIntDivisor<Index>(m_dimensions[0]);
    } else {
      m_fastOutputDepth = internal::TensorIntDivisor<Index>(m_dimensions[NumDims-1]);
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
    m_impl.evalSubExprsIfNeeded(NULL);
    return true;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_impl.cleanup();
  }

  // Maps a linear output index to the corresponding input coefficient, or to
  // m_paddingValue when the location falls into padding or onto one of the
  // zeros inserted by inflation.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
  {
    // Patch index corresponding to the passed in index.
    const Index patchIndex = index / m_fastPatchStride;

    // Spatial offset within the patch. This has to be translated into 3D
    // coordinates within the patch.
    const Index patchOffset = (index - patchIndex * m_patchStride) / m_fastOutputDepth;

    // Batch, etc.
    const Index otherIndex = (NumDims == 5) ? 0 : index / m_fastOtherStride;
    const Index patch3DIndex = (NumDims == 5) ? patchIndex : (index - otherIndex * m_otherStride) / m_fastPatchStride;

    // Calculate column index in the original input tensor.
    const Index colIndex = patch3DIndex / m_fastOutputPlanesRows;
    const Index colOffset = patchOffset / m_fastColStride;
    const Index inputCol = colIndex * m_col_strides + colOffset * m_in_col_strides - m_colPaddingLeft;
    // With inflation, only every m_col_inflate_strides-th effective column maps
    // back to a real input column; the rest are inserted zeros.
    const Index origInputCol = (m_col_inflate_strides == 1) ? inputCol : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
    if (inputCol < 0 || inputCol >= m_input_cols_eff ||
        ((m_col_inflate_strides != 1) && (inputCol != origInputCol * m_col_inflate_strides))) {
      return Scalar(m_paddingValue);
    }

    // Calculate row index in the original input tensor.
    const Index rowIndex = (patch3DIndex - colIndex * m_outputPlanesRows) / m_fastOutputPlanes;
    const Index rowOffset = (patchOffset - colOffset * m_colStride) / m_fastRowStride;
    const Index inputRow = rowIndex * m_row_strides + rowOffset * m_in_row_strides - m_rowPaddingTop;
    const Index origInputRow = (m_row_inflate_strides == 1) ? inputRow : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
    if (inputRow < 0 || inputRow >= m_input_rows_eff ||
        ((m_row_inflate_strides != 1) && (inputRow != origInputRow * m_row_inflate_strides))) {
      return Scalar(m_paddingValue);
    }

    // Calculate plane index in the original input tensor.
    const Index planeIndex = (patch3DIndex - m_outputPlanes * (colIndex * m_outputRows + rowIndex));
    const Index planeOffset = patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
    const Index inputPlane = planeIndex * m_plane_strides + planeOffset * m_in_plane_strides - m_planePaddingTop;
    const Index origInputPlane = (m_plane_inflate_strides == 1) ? inputPlane : ((inputPlane >= 0) ? (inputPlane / m_fastInputPlaneStride) : 0);
    if (inputPlane < 0 || inputPlane >= m_input_planes_eff ||
        ((m_plane_inflate_strides != 1) && (inputPlane != origInputPlane * m_plane_inflate_strides))) {
      return Scalar(m_paddingValue);
    }

    // Depth (channel) coordinate: the remainder of the index modulo the depth dimension.
    const int depth_index = static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 0 : NumDims - 1;
    const Index depth = index - (index / m_fastOutputDepth) * m_dimensions[depth_index];

    const Index inputIndex = depth +
        origInputRow * m_rowInputStride +
        origInputCol * m_colInputStride +
        origInputPlane * m_planeInputStride +
        otherIndex * m_otherInputStride;

    return m_impl.coeff(inputIndex);
  }

  // Vectorized load: only possible when all PacketSize coefficients fall in
  // the same patch, row, column and plane of the input with no padding in
  // between; otherwise falls back to scalar gathering via
  // packetWithPossibleZero().
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
  {
    EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    // Dilated patches or inflated input: addresses are not contiguous, so
    // take the scalar path.
    if (m_in_row_strides != 1 || m_in_col_strides != 1 || m_row_inflate_strides != 1 || m_col_inflate_strides != 1 ||
        m_in_plane_strides != 1 || m_plane_inflate_strides != 1) {
      return packetWithPossibleZero(index);
    }

    // Check both ends of the packet; identical derived coordinates mean the
    // whole packet shares them (indices in between are monotonic).
    const Index indices[2] = {index, index + PacketSize - 1};
    const Index patchIndex = indices[0] / m_fastPatchStride;
    if (patchIndex != indices[1] / m_fastPatchStride) {
      return packetWithPossibleZero(index);
    }
    const Index otherIndex = (NumDims == 5) ? 0 : indices[0] / m_fastOtherStride;
    eigen_assert(otherIndex == indices[1] / m_fastOtherStride);

    // Find the offset of the element wrt the location of the first element.
    const Index patchOffsets[2] = {(indices[0] - patchIndex * m_patchStride) / m_fastOutputDepth,
                                   (indices[1] - patchIndex * m_patchStride) / m_fastOutputDepth};

    const Index patch3DIndex = (NumDims == 5) ? patchIndex : (indices[0] - otherIndex * m_otherStride) / m_fastPatchStride;
    eigen_assert(patch3DIndex == (indices[1] - otherIndex * m_otherStride) / m_fastPatchStride);

    const Index colIndex = patch3DIndex / m_fastOutputPlanesRows;
    const Index colOffsets[2] = {
        patchOffsets[0] / m_fastColStride,
        patchOffsets[1] / m_fastColStride};

    // Calculate col indices in the original input tensor.
    const Index inputCols[2] = {
        colIndex * m_col_strides + colOffsets[0] - m_colPaddingLeft,
        colIndex * m_col_strides + colOffsets[1] - m_colPaddingLeft};
    if (inputCols[1] < 0 || inputCols[0] >= m_inputCols) {
      // Entire packet lies in the padding region.
      return internal::pset1<PacketReturnType>(Scalar(m_paddingValue));
    }

    if (inputCols[0] != inputCols[1]) {
      // Packet straddles a column boundary.
      return packetWithPossibleZero(index);
    }

    const Index rowIndex = (patch3DIndex - colIndex * m_outputPlanesRows) / m_fastOutputPlanes;
    const Index rowOffsets[2] = {
        (patchOffsets[0] - colOffsets[0] * m_colStride) / m_fastRowStride,
        (patchOffsets[1] - colOffsets[1] * m_colStride) / m_fastRowStride};
    eigen_assert(rowOffsets[0] <= rowOffsets[1]);
    // Calculate row indices in the original input tensor.
    const Index inputRows[2] = {
        rowIndex * m_row_strides + rowOffsets[0] - m_rowPaddingTop,
        rowIndex * m_row_strides + rowOffsets[1] - m_rowPaddingTop};

    if (inputRows[1] < 0 || inputRows[0] >= m_inputRows) {
      return internal::pset1<PacketReturnType>(Scalar(m_paddingValue));
    }

    if (inputRows[0] != inputRows[1]) {
      // Packet straddles a row boundary.
      return packetWithPossibleZero(index);
    }

    const Index planeIndex = (patch3DIndex - m_outputPlanes * (colIndex * m_outputRows + rowIndex));
    const Index planeOffsets[2] = {
        patchOffsets[0] - colOffsets[0] * m_colStride - rowOffsets[0] * m_rowStride,
        patchOffsets[1] - colOffsets[1] * m_colStride - rowOffsets[1] * m_rowStride};
    eigen_assert(planeOffsets[0] <= planeOffsets[1]);
    const Index inputPlanes[2] = {
        planeIndex * m_plane_strides + planeOffsets[0] - m_planePaddingTop,
        planeIndex * m_plane_strides + planeOffsets[1] - m_planePaddingTop};

    if (inputPlanes[1] < 0 || inputPlanes[0] >= m_inputPlanes) {
      return internal::pset1<PacketReturnType>(Scalar(m_paddingValue));
    }

    if (inputPlanes[0] >= 0 && inputPlanes[1] < m_inputPlanes) {
      // no padding
      const int depth_index = static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 0 : NumDims - 1;
      const Index depth = index - (index / m_fastOutputDepth) * m_dimensions[depth_index];
      const Index inputIndex = depth +
          inputRows[0] * m_rowInputStride +
          inputCols[0] * m_colInputStride +
          m_planeInputStride * inputPlanes[0] +
          otherIndex * m_otherInputStride;
      return m_impl.template packet<Unaligned>(inputIndex);
    }

    return packetWithPossibleZero(index);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    // Rough operation count of the index arithmetic performed by coeff();
    // the counts are estimates, not an exact tally.
    const double compute_cost =
        10 * TensorOpCost::DivCost<Index>() + 21 * TensorOpCost::MulCost<Index>() +
        8 * TensorOpCost::AddCost<Index>();
    return TensorOpCost(0, 0, compute_cost, vectorized, PacketSize);
  }

  // No backing buffer: coefficients are computed on the fly.
  EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }

  const TensorEvaluator<ArgType, Device>& impl() const { return m_impl; }

  // Accessors exposing the derived geometry (used by device-specific code).
  Index planePaddingTop() const { return m_planePaddingTop; }
  Index rowPaddingTop() const { return m_rowPaddingTop; }
  Index colPaddingLeft() const { return m_colPaddingLeft; }
  Index outputPlanes() const { return m_outputPlanes; }
  Index outputRows() const { return m_outputRows; }
  Index outputCols() const { return m_outputCols; }
  Index userPlaneStride() const { return m_plane_strides; }
  Index userRowStride() const { return m_row_strides; }
  Index userColStride() const { return m_col_strides; }
  Index userInPlaneStride() const { return m_in_plane_strides; }
  Index userInRowStride() const { return m_in_row_strides; }
  Index userInColStride() const { return m_in_col_strides; }
  Index planeInflateStride() const { return m_plane_inflate_strides; }
  Index rowInflateStride() const { return m_row_inflate_strides; }
  Index colInflateStride() const { return m_col_inflate_strides; }

  protected:
  // Scalar fallback for packet(): gathers each coefficient individually
  // (each may independently resolve to a real value or the padding value)
  // and assembles them into a packet.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetWithPossibleZero(Index index) const
  {
    EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
    for (int i = 0; i < PacketSize; ++i) {
      values[i] = coeff(index+i);
    }
    PacketReturnType rslt = internal::pload<PacketReturnType>(values);
    return rslt;
  }

  Dimensions m_dimensions;

  // Parameters passed to the constructor.
  Index m_plane_strides;
  Index m_row_strides;
  Index m_col_strides;

  // Number of patch positions along each spatial dimension.
  Index m_outputPlanes;
  Index m_outputRows;
  Index m_outputCols;

  // Padding applied before the first plane/row/column.
  Index m_planePaddingTop;
  Index m_rowPaddingTop;
  Index m_colPaddingLeft;

  Index m_in_plane_strides;
  Index m_in_row_strides;
  Index m_in_col_strides;

  Index m_plane_inflate_strides;
  Index m_row_inflate_strides;
  Index m_col_inflate_strides;

  // Cached input size.
  Index m_inputDepth;
  Index m_inputPlanes;
  Index m_inputRows;
  Index m_inputCols;

  // Other cached variables.
  Index m_outputPlanesRows;

  // Effective input/patch post-inflation size.
  Index m_input_planes_eff;
  Index m_input_rows_eff;
  Index m_input_cols_eff;
  Index m_patch_planes_eff;
  Index m_patch_rows_eff;
  Index m_patch_cols_eff;

  // Strides for the output tensor.
  Index m_otherStride;
  Index m_patchStride;
  Index m_rowStride;
  Index m_colStride;

  // Strides for the input tensor.
  Index m_planeInputStride;
  Index m_rowInputStride;
  Index m_colInputStride;
  Index m_otherInputStride;

  // Precomputed fast divisors for the strides above.
  internal::TensorIntDivisor<Index> m_fastOtherStride;
  internal::TensorIntDivisor<Index> m_fastPatchStride;
  internal::TensorIntDivisor<Index> m_fastColStride;
  internal::TensorIntDivisor<Index> m_fastRowStride;
  internal::TensorIntDivisor<Index> m_fastInputPlaneStride;
  internal::TensorIntDivisor<Index> m_fastInputRowStride;
  internal::TensorIntDivisor<Index> m_fastInputColStride;
  internal::TensorIntDivisor<Index> m_fastInputColsEff;
  internal::TensorIntDivisor<Index> m_fastOutputPlanesRows;
  internal::TensorIntDivisor<Index> m_fastOutputPlanes;
  internal::TensorIntDivisor<Index> m_fastOutputDepth;

  Scalar m_paddingValue;

  TensorEvaluator<ArgType, Device> m_impl;
};


}  // end namespace Eigen

#endif  // EIGEN_CXX11_TENSOR_TENSOR_VOLUME_PATCH_H