/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_VOLUME_PATCH_H_
#define TENSORFLOW_CORE_KERNELS_EIGEN_VOLUME_PATCH_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

namespace Eigen {
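
// Illustrative usage sketch (not part of this header; the shapes and
// parameter values are made up): on CPU devices, the evaluator below is what
// runs expressions such as
//
//   Eigen::Tensor<float, 5> input(depth, planes, rows, cols, batch);
//   Eigen::Tensor<float, 6> patches =
//       input.extract_volume_patches(/*patch_planes=*/2, /*patch_rows=*/2,
//                                    /*patch_cols=*/2, /*plane_stride=*/1,
//                                    /*row_stride=*/1, /*col_stride=*/1,
//                                    PADDING_SAME);
//
// With PADDING_SAME, the padding interpretation described below determines
// which output coefficients read the padding value instead of the input.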
// Changes the interpretation of padding in TensorVolumePatchOp to be compatible
// with the rest of TensorFlow (odd padding is split so that more padding is put
// on the right end of the tensor).
template <DenseIndex Planes, DenseIndex Rows, DenseIndex Cols, typename ArgType,
          typename Device>
struct CustomTensorEvaluator {
  typedef TensorVolumePatchOp<Planes, Rows, Cols, ArgType> XprType;
  typedef typename XprType::Index Index;
  static constexpr int NumInputDims = internal::array_size<
      typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
  static constexpr int NumDims = NumInputDims + 1;
  typedef DSizes<Index, NumDims> Dimensions;
  typedef
      typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static constexpr Index PacketSize =
      internal::unpacket_traits<PacketReturnType>::size;

  enum {
    IsAligned = false,
    PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
    BlockAccess = false,
    PreferBlockAccess = false,
    Layout = TensorEvaluator<ArgType, Device>::Layout,
    CoordAccess = NumDims == 6,
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CustomTensorEvaluator(
      const XprType& op, const Device& device)
      : m_impl(op.expression(), device) {
    EIGEN_STATIC_ASSERT(NumDims >= 5, YOU_MADE_A_PROGRAMMING_MISTAKE);

    m_paddingValue = op.padding_value();

    const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims =
        m_impl.dimensions();

    // Cache a few variables.
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      m_inputDepth = input_dims[0];
      m_inputPlanes = input_dims[1];
      m_inputRows = input_dims[2];
      m_inputCols = input_dims[3];
    } else {
      m_inputDepth = input_dims[NumInputDims - 1];
      m_inputPlanes = input_dims[NumInputDims - 2];
      m_inputRows = input_dims[NumInputDims - 3];
      m_inputCols = input_dims[NumInputDims - 4];
    }

    m_plane_strides = op.plane_strides();
    m_row_strides = op.row_strides();
    m_col_strides = op.col_strides();

    // Input strides and effective input/patch size.
    m_in_plane_strides = op.in_plane_strides();
    m_in_row_strides = op.in_row_strides();
    m_in_col_strides = op.in_col_strides();
    m_plane_inflate_strides = op.plane_inflate_strides();
    m_row_inflate_strides = op.row_inflate_strides();
    m_col_inflate_strides = op.col_inflate_strides();

    // The "effective" spatial size after inflating data with zeros.
    m_input_planes_eff = (m_inputPlanes - 1) * m_plane_inflate_strides + 1;
    m_input_rows_eff = (m_inputRows - 1) * m_row_inflate_strides + 1;
    m_input_cols_eff = (m_inputCols - 1) * m_col_inflate_strides + 1;
    m_patch_planes_eff =
        op.patch_planes() + (op.patch_planes() - 1) * (m_in_plane_strides - 1);
    m_patch_rows_eff =
        op.patch_rows() + (op.patch_rows() - 1) * (m_in_row_strides - 1);
    m_patch_cols_eff =
        op.patch_cols() + (op.patch_cols() - 1) * (m_in_col_strides - 1);

    if (op.padding_explicit()) {
      m_outputPlanes = Eigen::divup(
          m_input_planes_eff +
              static_cast<Index>(op.padding_top_z() + op.padding_bottom_z()) -
              m_patch_planes_eff + 1,
          m_plane_strides);
      m_outputRows = Eigen::divup(
          m_input_rows_eff +
              static_cast<Index>(op.padding_top() + op.padding_bottom()) -
              m_patch_rows_eff + 1,
          m_row_strides);
      m_outputCols = Eigen::divup(
          m_input_cols_eff +
              static_cast<Index>(op.padding_left() + op.padding_right()) -
              m_patch_cols_eff + 1,
          m_col_strides);
      m_planePaddingTop = op.padding_top_z();
      m_rowPaddingTop = op.padding_top();
      m_colPaddingLeft = op.padding_left();
    } else {
      // Computing padding from the type.
      switch (op.padding_type()) {
        case PADDING_VALID:
          m_outputPlanes = Eigen::divup(
              m_input_planes_eff - m_patch_planes_eff + 1, m_plane_strides);
          m_outputRows = Eigen::divup(m_input_rows_eff - m_patch_rows_eff + 1,
                                      m_row_strides);
          m_outputCols = Eigen::divup(m_input_cols_eff - m_patch_cols_eff + 1,
                                      m_col_strides);
          m_planePaddingTop = 0;
          m_rowPaddingTop = 0;
          m_colPaddingLeft = 0;
          break;
        case PADDING_SAME: {
          m_outputPlanes = Eigen::divup(m_input_planes_eff, m_plane_strides);
          m_outputRows = Eigen::divup(m_input_rows_eff, m_row_strides);
          m_outputCols = Eigen::divup(m_input_cols_eff, m_col_strides);
          const Index dz = numext::maxi<DenseIndex>(
              0, (m_outputPlanes - 1) * m_plane_strides + m_patch_planes_eff -
                     m_input_planes_eff);
          const Index dy = numext::maxi<DenseIndex>(
              0, (m_outputRows - 1) * m_row_strides + m_patch_rows_eff -
                     m_input_rows_eff);
          const Index dx = numext::maxi<DenseIndex>(
              0, (m_outputCols - 1) * m_col_strides + m_patch_cols_eff -
                     m_input_cols_eff);
          m_planePaddingTop = dz / 2;
          m_rowPaddingTop = dy / 2;
          m_colPaddingLeft = dx / 2;
          break;
        }
        default:
          eigen_assert(false && "unexpected padding");
      }
    }
    eigen_assert(m_outputRows > 0);
    eigen_assert(m_outputCols > 0);
    eigen_assert(m_outputPlanes > 0);
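
    // Worked example of the padding split (illustrative numbers,
    // PADDING_SAME): with m_input_planes_eff = 5, m_patch_planes_eff = 2 and
    // m_plane_strides = 1 we get m_outputPlanes = 5 and
    // dz = (5 - 1) * 1 + 2 - 5 = 1, so m_planePaddingTop = 1 / 2 = 0 and the
    // single padded plane lands at the bottom end of the tensor -- the
    // TensorFlow convention this evaluator exists to implement.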
    // Dimensions for the result of extraction.
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      // ColMajor
      // 0: depth
      // 1: patch_planes
      // 2: patch_rows
      // 3: patch_cols
      // 4: number of patches
      // 5 and beyond: anything else (such as batch).
      m_dimensions[0] = input_dims[0];
      m_dimensions[1] = op.patch_planes();
      m_dimensions[2] = op.patch_rows();
      m_dimensions[3] = op.patch_cols();
      m_dimensions[4] = m_outputPlanes * m_outputRows * m_outputCols;
      for (int i = 5; i < NumDims; ++i) {
        m_dimensions[i] = input_dims[i - 1];
      }
    } else {
      // RowMajor
      // NumDims-1: depth
      // NumDims-2: patch_planes
      // NumDims-3: patch_rows
      // NumDims-4: patch_cols
      // NumDims-5: number of patches
      // NumDims-6 and beyond: anything else (such as batch).
      m_dimensions[NumDims - 1] = input_dims[NumInputDims - 1];
      m_dimensions[NumDims - 2] = op.patch_planes();
      m_dimensions[NumDims - 3] = op.patch_rows();
      m_dimensions[NumDims - 4] = op.patch_cols();
      m_dimensions[NumDims - 5] = m_outputPlanes * m_outputRows * m_outputCols;
      for (int i = NumDims - 6; i >= 0; --i) {
        m_dimensions[i] = input_dims[i];
      }
    }

    // Strides for the output tensor.
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      m_rowStride = m_dimensions[1];
      m_colStride = m_dimensions[2] * m_rowStride;
      m_patchStride = m_colStride * m_dimensions[3] * m_dimensions[0];
      m_otherStride = m_patchStride * m_dimensions[4];
    } else {
      m_rowStride = m_dimensions[NumDims - 2];
      m_colStride = m_dimensions[NumDims - 3] * m_rowStride;
      m_patchStride =
          m_colStride * m_dimensions[NumDims - 4] * m_dimensions[NumDims - 1];
      m_otherStride = m_patchStride * m_dimensions[NumDims - 5];
    }

    // Strides for navigating through the input tensor.
    m_planeInputStride = m_inputDepth;
    m_rowInputStride = m_inputDepth * m_inputPlanes;
    m_colInputStride = m_inputDepth * m_inputRows * m_inputPlanes;
    m_otherInputStride =
        m_inputDepth * m_inputRows * m_inputCols * m_inputPlanes;

    m_outputPlanesRows = m_outputPlanes * m_outputRows;

    // Fast representations of different variables.
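    // internal::TensorIntDivisor precomputes a multiplicative inverse so that
    // an expression such as `index / m_fastPatchStride` divides by
    // m_patchStride without a hardware integer division; coeff() and packet()
    // below rely on this heavily.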
    m_fastOtherStride = internal::TensorIntDivisor<Index>(m_otherStride);
    m_fastPatchStride = internal::TensorIntDivisor<Index>(m_patchStride);
    m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);
    m_fastRowStride = internal::TensorIntDivisor<Index>(m_rowStride);
    m_fastInputRowStride =
        internal::TensorIntDivisor<Index>(m_row_inflate_strides);
    m_fastInputColStride =
        internal::TensorIntDivisor<Index>(m_col_inflate_strides);
    m_fastInputPlaneStride =
        internal::TensorIntDivisor<Index>(m_plane_inflate_strides);
    m_fastInputColsEff = internal::TensorIntDivisor<Index>(m_input_cols_eff);
    m_fastOutputPlanes = internal::TensorIntDivisor<Index>(m_outputPlanes);
    m_fastOutputPlanesRows =
        internal::TensorIntDivisor<Index>(m_outputPlanesRows);

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      m_fastOutputDepth = internal::TensorIntDivisor<Index>(m_dimensions[0]);
    } else {
      m_fastOutputDepth =
          internal::TensorIntDivisor<Index>(m_dimensions[NumDims - 1]);
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const {
    return m_dimensions;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(
      Scalar* /*data*/) {
    m_impl.evalSubExprsIfNeeded(NULL);
    return true;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { m_impl.cleanup(); }
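
  // How coeff() recovers coordinates from a linear output index (ColMajor
  // case shown; RowMajor mirrors it): the index factors as
  //
  //   index = otherIndex * m_otherStride + patch3DIndex * m_patchStride
  //         + (colOffset * m_colStride + rowOffset * m_rowStride + planeOffset)
  //               * m_dimensions[0]
  //         + depth
  //
  // Each component is peeled off with the fast divisors above and mapped back
  // to an input coordinate; any coordinate that falls into the padded (or
  // inflated) region short-circuits to m_paddingValue.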
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType
  coeff(Index index) const {
    // Patch index corresponding to the passed in index.
    const Index patchIndex = index / m_fastPatchStride;

    // Spatial offset within the patch. This has to be translated into 3D
    // coordinates within the patch.
    const Index patchOffset =
        (index - patchIndex * m_patchStride) / m_fastOutputDepth;

    // Batch, etc.
    const Index otherIndex = (NumDims == 5) ? 0 : index / m_fastOtherStride;
    const Index patch3DIndex =
        (NumDims == 5)
            ? patchIndex
            : (index - otherIndex * m_otherStride) / m_fastPatchStride;

    // Calculate column index in the original input tensor.
    const Index colIndex = patch3DIndex / m_fastOutputPlanesRows;
    const Index colOffset = patchOffset / m_fastColStride;
    const Index inputCol = colIndex * m_col_strides +
                           colOffset * m_in_col_strides - m_colPaddingLeft;
    const Index origInputCol =
        (m_col_inflate_strides == 1)
            ? inputCol
            : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
    if (inputCol < 0 || inputCol >= m_input_cols_eff ||
        ((m_col_inflate_strides != 1) &&
         (inputCol != origInputCol * m_col_inflate_strides))) {
      return Scalar(m_paddingValue);
    }

    // Calculate row index in the original input tensor.
    const Index rowIndex =
        (patch3DIndex - colIndex * m_outputPlanesRows) / m_fastOutputPlanes;
    const Index rowOffset =
        (patchOffset - colOffset * m_colStride) / m_fastRowStride;
    const Index inputRow = rowIndex * m_row_strides +
                           rowOffset * m_in_row_strides - m_rowPaddingTop;
    const Index origInputRow =
        (m_row_inflate_strides == 1)
            ? inputRow
            : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
    if (inputRow < 0 || inputRow >= m_input_rows_eff ||
        ((m_row_inflate_strides != 1) &&
         (inputRow != origInputRow * m_row_inflate_strides))) {
      return Scalar(m_paddingValue);
    }

    // Calculate plane index in the original input tensor.
    const Index planeIndex =
        (patch3DIndex - m_outputPlanes * (colIndex * m_outputRows + rowIndex));
    const Index planeOffset =
        patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
    const Index inputPlane = planeIndex * m_plane_strides +
                             planeOffset * m_in_plane_strides -
                             m_planePaddingTop;
    const Index origInputPlane =
        (m_plane_inflate_strides == 1)
            ? inputPlane
            : ((inputPlane >= 0) ? (inputPlane / m_fastInputPlaneStride) : 0);
    if (inputPlane < 0 || inputPlane >= m_input_planes_eff ||
        ((m_plane_inflate_strides != 1) &&
         (inputPlane != origInputPlane * m_plane_inflate_strides))) {
      return Scalar(m_paddingValue);
    }

    const int depth_index =
        static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 0
                                                               : NumDims - 1;
    const Index depth =
        index - (index / m_fastOutputDepth) * m_dimensions[depth_index];

    const Index inputIndex = depth + origInputRow * m_rowInputStride +
                             origInputCol * m_colInputStride +
                             origInputPlane * m_planeInputStride +
                             otherIndex * m_otherInputStride;

    return m_impl.coeff(inputIndex);
  }
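
  // packet() produces PacketSize consecutive output coefficients at once. The
  // vectorized load below is only attempted when the first and last lane
  // provably fall in the same patch, column, row and plane with no padding in
  // between; anything else falls back to the scalar gather in
  // packetWithPossibleZero().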
  template <int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType
  packet(Index index) const {
    EIGEN_STATIC_ASSERT(PacketSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index + PacketSize - 1 < dimensions().TotalSize());

    if (m_in_row_strides != 1 || m_in_col_strides != 1 ||
        m_row_inflate_strides != 1 || m_col_inflate_strides != 1 ||
        m_in_plane_strides != 1 || m_plane_inflate_strides != 1) {
      return packetWithPossibleZero(index);
    }

    const Index indices[2] = {index, index + PacketSize - 1};
    const Index patchIndex = indices[0] / m_fastPatchStride;
    if (patchIndex != indices[1] / m_fastPatchStride) {
      return packetWithPossibleZero(index);
    }
    const Index otherIndex =
        (NumDims == 5) ? 0 : indices[0] / m_fastOtherStride;
    eigen_assert(otherIndex == indices[1] / m_fastOtherStride);

    // Find the offset of the element wrt the location of the first element.
    const Index patchOffsets[2] = {
        (indices[0] - patchIndex * m_patchStride) / m_fastOutputDepth,
        (indices[1] - patchIndex * m_patchStride) / m_fastOutputDepth};

    const Index patch3DIndex =
        (NumDims == 5)
            ? patchIndex
            : (indices[0] - otherIndex * m_otherStride) / m_fastPatchStride;
    eigen_assert(patch3DIndex ==
                 (indices[1] - otherIndex * m_otherStride) / m_fastPatchStride);

    const Index colIndex = patch3DIndex / m_fastOutputPlanesRows;
    const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
                                 patchOffsets[1] / m_fastColStride};

    // Calculate col indices in the original input tensor.
    const Index inputCols[2] = {
        colIndex * m_col_strides + colOffsets[0] - m_colPaddingLeft,
        colIndex * m_col_strides + colOffsets[1] - m_colPaddingLeft};
    if (inputCols[1] < 0 || inputCols[0] >= m_inputCols) {
      return internal::pset1<PacketReturnType>(Scalar(m_paddingValue));
    }

    if (inputCols[0] != inputCols[1]) {
      return packetWithPossibleZero(index);
    }

    const Index rowIndex =
        (patch3DIndex - colIndex * m_outputPlanesRows) / m_fastOutputPlanes;
    const Index rowOffsets[2] = {
        (patchOffsets[0] - colOffsets[0] * m_colStride) / m_fastRowStride,
        (patchOffsets[1] - colOffsets[1] * m_colStride) / m_fastRowStride};
    eigen_assert(rowOffsets[0] <= rowOffsets[1]);
    // Calculate row indices in the original input tensor.
    const Index inputRows[2] = {
        rowIndex * m_row_strides + rowOffsets[0] - m_rowPaddingTop,
        rowIndex * m_row_strides + rowOffsets[1] - m_rowPaddingTop};

    if (inputRows[1] < 0 || inputRows[0] >= m_inputRows) {
      return internal::pset1<PacketReturnType>(Scalar(m_paddingValue));
    }

    if (inputRows[0] != inputRows[1]) {
      return packetWithPossibleZero(index);
    }

    const Index planeIndex =
        (patch3DIndex - m_outputPlanes * (colIndex * m_outputRows + rowIndex));
    const Index planeOffsets[2] = {
        patchOffsets[0] - colOffsets[0] * m_colStride -
            rowOffsets[0] * m_rowStride,
        patchOffsets[1] - colOffsets[1] * m_colStride -
            rowOffsets[1] * m_rowStride};
    eigen_assert(planeOffsets[0] <= planeOffsets[1]);
    const Index inputPlanes[2] = {
        planeIndex * m_plane_strides + planeOffsets[0] - m_planePaddingTop,
        planeIndex * m_plane_strides + planeOffsets[1] - m_planePaddingTop};

    if (inputPlanes[1] < 0 || inputPlanes[0] >= m_inputPlanes) {
      return internal::pset1<PacketReturnType>(Scalar(m_paddingValue));
    }

    if (inputPlanes[0] >= 0 && inputPlanes[1] < m_inputPlanes) {
      // No padding.
      const int depth_index =
          static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 0
                                                                 : NumDims - 1;
      const Index depth =
          index - (index / m_fastOutputDepth) * m_dimensions[depth_index];
      const Index inputIndex = depth + inputRows[0] * m_rowInputStride +
                               inputCols[0] * m_colInputStride +
                               m_planeInputStride * inputPlanes[0] +
                               otherIndex * m_otherInputStride;
      return m_impl.template packet<Unaligned>(inputIndex);
    }

    return packetWithPossibleZero(index);
  }
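
  // The constants below approximate the per-coefficient index arithmetic done
  // in coeff() above (the integer divisions dominate); memory traffic is
  // deliberately reported as zero here.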
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    const double compute_cost = 10 * TensorOpCost::DivCost<Index>() +
                                21 * TensorOpCost::MulCost<Index>() +
                                8 * TensorOpCost::AddCost<Index>();
    return TensorOpCost(0, 0, compute_cost, vectorized, PacketSize);
  }

  EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }

  const TensorEvaluator<ArgType, Device>& impl() const { return m_impl; }

  Index planePaddingTop() const { return m_planePaddingTop; }
  Index rowPaddingTop() const { return m_rowPaddingTop; }
  Index colPaddingLeft() const { return m_colPaddingLeft; }
  Index outputPlanes() const { return m_outputPlanes; }
  Index outputRows() const { return m_outputRows; }
  Index outputCols() const { return m_outputCols; }
  Index userPlaneStride() const { return m_plane_strides; }
  Index userRowStride() const { return m_row_strides; }
  Index userColStride() const { return m_col_strides; }
  Index userInPlaneStride() const { return m_in_plane_strides; }
  Index userInRowStride() const { return m_in_row_strides; }
  Index userInColStride() const { return m_in_col_strides; }
  Index planeInflateStride() const { return m_plane_inflate_strides; }
  Index rowInflateStride() const { return m_row_inflate_strides; }
  Index colInflateStride() const { return m_col_inflate_strides; }
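
  // Coordinate-based lookup, used when CoordAccess is true (NumDims == 6,
  // i.e. exactly one batch dimension): callers hand in a full set of output
  // coordinates instead of a linear index, skipping most of the divisor
  // arithmetic above.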
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType
  coeff(const array<Index, NumDims>& coords) const {
    // ColMajor
    // 0: depth, 1: patch_planes, 2: patch_rows, 3: patch_cols, 4: number of
    // patches, 5: batches
    // RowMajor
    // 0: batches, 1: number of patches, 2: patch_cols, 3: patch_rows, 4:
    // patch_planes, 5: depth
    const Index patch3DIndex =
        coords[static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 4 : 1];
    const Index colOffset =
        coords[static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 3 : 2];
    const Index rowOffset =
        coords[static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 2 : 3];
    const Index planeOffset =
        coords[static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 1 : 4];

    array<Index, NumDims - 1> inputCoords;

    const Index colIndex = patch3DIndex / m_fastOutputPlanesRows;
    const Index inputCol = colIndex * m_col_strides +
                           colOffset * m_in_col_strides - m_colPaddingLeft;
    const Index origInputCol =
        (m_col_inflate_strides == 1)
            ? inputCol
            : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
    if (inputCol < 0 || inputCol >= m_input_cols_eff ||
        ((m_col_inflate_strides != 1) &&
         (inputCol != origInputCol * m_col_inflate_strides))) {
      return Scalar(m_paddingValue);
    }

    const Index rowIndex =
        (patch3DIndex - colIndex * m_outputPlanesRows) / m_fastOutputPlanes;
    const Index inputRow = rowIndex * m_row_strides +
                           rowOffset * m_in_row_strides - m_rowPaddingTop;
    const Index origInputRow =
        (m_row_inflate_strides == 1)
            ? inputRow
            : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
    if (inputRow < 0 || inputRow >= m_input_rows_eff ||
        ((m_row_inflate_strides != 1) &&
         (inputRow != origInputRow * m_row_inflate_strides))) {
      return Scalar(m_paddingValue);
    }

    // patch3DIndex enumerates (plane, row, col) with plane fastest, so the
    // residual after removing the col and row components is the plane index.
    const Index planeIndex = patch3DIndex - colIndex * m_outputPlanesRows -
                             rowIndex * m_outputPlanes;
    const Index inputPlane = planeIndex * m_plane_strides +
                             planeOffset * m_in_plane_strides -
                             m_planePaddingTop;
    const Index origInputPlane =
        (m_plane_inflate_strides == 1)
            ? inputPlane
            : ((inputPlane >= 0) ? (inputPlane / m_fastInputPlaneStride) : 0);
    if (inputPlane < 0 || inputPlane >= m_input_planes_eff ||
        ((m_plane_inflate_strides != 1) &&
         (inputPlane != origInputPlane * m_plane_inflate_strides))) {
      return Scalar(m_paddingValue);
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      inputCoords[0] = coords[0];  // depth
      inputCoords[1] = origInputPlane;
      inputCoords[2] = origInputRow;
      inputCoords[3] = origInputCol;
      inputCoords[4] = coords[5];  // batch
    } else {
      inputCoords[4] = coords[5];  // depth
      inputCoords[3] = origInputPlane;
      inputCoords[2] = origInputRow;
      inputCoords[1] = origInputCol;
      inputCoords[0] = coords[0];  // batch
    }
    if (TensorEvaluator<ArgType, Device>::CoordAccess) {
      return m_impl.coeff(inputCoords);
    } else {
      Index inputIndex;
      if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
        inputIndex = inputCoords[4] * m_otherInputStride +
                     inputCoords[3] * m_colInputStride +
                     inputCoords[2] * m_rowInputStride +
                     inputCoords[1] * m_planeInputStride + inputCoords[0];
      } else {
        inputIndex = inputCoords[0] * m_otherInputStride +
                     inputCoords[1] * m_colInputStride +
                     inputCoords[2] * m_rowInputStride +
                     inputCoords[3] * m_planeInputStride + inputCoords[4];
      }
      return m_impl.coeff(inputIndex);
    }
  }

 protected:
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType
  packetWithPossibleZero(Index index) const {
    EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type
        values[PacketSize];
    for (int i = 0; i < PacketSize; ++i) {
      values[i] = coeff(index + i);
    }
    PacketReturnType rslt = internal::pload<PacketReturnType>(values);
    return rslt;
  }
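
  // Everything below is computed once in the constructor and treated as
  // read-only afterwards.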
  Dimensions m_dimensions;

  // Parameters passed to the constructor.
  Index m_plane_strides;
  Index m_row_strides;
  Index m_col_strides;

  Index m_outputPlanes;
  Index m_outputRows;
  Index m_outputCols;

  Index m_planePaddingTop;
  Index m_rowPaddingTop;
  Index m_colPaddingLeft;

  Index m_in_plane_strides;
  Index m_in_row_strides;
  Index m_in_col_strides;

  Index m_plane_inflate_strides;
  Index m_row_inflate_strides;
  Index m_col_inflate_strides;

  // Cached input size.
  Index m_inputDepth;
  Index m_inputPlanes;
  Index m_inputRows;
  Index m_inputCols;

  // Other cached variables.
  Index m_outputPlanesRows;

  // Effective input/patch post-inflation size.
  Index m_input_planes_eff;
  Index m_input_rows_eff;
  Index m_input_cols_eff;
  Index m_patch_planes_eff;
  Index m_patch_rows_eff;
  Index m_patch_cols_eff;

  // Strides for the output tensor.
  Index m_otherStride;
  Index m_patchStride;
  Index m_rowStride;
  Index m_colStride;

  // Strides for the input tensor.
  Index m_planeInputStride;
  Index m_rowInputStride;
  Index m_colInputStride;
  Index m_otherInputStride;

  internal::TensorIntDivisor<Index> m_fastOtherStride;
  internal::TensorIntDivisor<Index> m_fastPatchStride;
  internal::TensorIntDivisor<Index> m_fastColStride;
  internal::TensorIntDivisor<Index> m_fastRowStride;
  internal::TensorIntDivisor<Index> m_fastInputPlaneStride;
  internal::TensorIntDivisor<Index> m_fastInputRowStride;
  internal::TensorIntDivisor<Index> m_fastInputColStride;
  internal::TensorIntDivisor<Index> m_fastInputColsEff;
  internal::TensorIntDivisor<Index> m_fastOutputPlanesRows;
  internal::TensorIntDivisor<Index> m_fastOutputPlanes;
  internal::TensorIntDivisor<Index> m_fastOutputDepth;

  Scalar m_paddingValue;

  TensorEvaluator<ArgType, Device> m_impl;
};

// Override the default TensorEvaluator for TensorVolumePatchOp for CPU.
#define OVERRIDE_EVALUATOR(Device)                                          \
  template <DenseIndex Planes, DenseIndex Rows, DenseIndex Cols,            \
            typename ArgType>                                               \
  struct TensorEvaluator<                                                   \
      const TensorVolumePatchOp<Planes, Rows, Cols, ArgType>, Device>       \
      : public CustomTensorEvaluator<Planes, Rows, Cols, ArgType, Device> { \
    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(                  \
        const typename CustomTensorEvaluator<Planes, Rows, Cols, ArgType,   \
                                             Device>::XprType& op,          \
        const Device& device)                                               \
        : CustomTensorEvaluator<Planes, Rows, Cols, ArgType, Device>(       \
              op, device) {}                                                \
  };

OVERRIDE_EVALUATOR(Eigen::ThreadPoolDevice);
OVERRIDE_EVALUATOR(Eigen::DefaultDevice);

#undef OVERRIDE_EVALUATOR

}  // namespace Eigen

#endif  // TENSORFLOW_CORE_KERNELS_EIGEN_VOLUME_PATCH_H_