/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_
#define TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/eigen_volume_patch.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

namespace Eigen {

namespace internal {

// WARNING: Most of the code here implicitly assumes that the matrix is in
// ColMajor layout. This is guaranteed by the tensor contraction (see
// TensorContraction.h).
//
// Inside Eigen a tensor contraction is represented by a matrix multiplication.
// We don't want to actually extract volume patches and reshape the result into
// a matrix (this involves allocating huge extra memory), so the patch
// extraction and reshape operations are implicit.
//
// TensorContractionInputMapper takes a matrix index and returns the
// coefficient (or the packet) of the "virtual tensor" that would be at that
// index if we were to actually reshape the result of patch extraction.
//
// TensorContractionSubMapper provides a similar view into the "virtual matrix"
// at the given vertical and horizontal offsets.
//
// "Virtual matrix" dimensions:
//   *0: kernelChannels * kernelPlanes * kernelRows * kernelCols
//    1: out_planes * out_height * out_width * OTHERS (e.g. batches, etc...)
49 // 50 // *) extracted patches are continuous in memory (innermost dimension assuming 51 // col major layout) 52 // 53 // With this dimensions: 54 // row - offset within a single patch (in code: patchId) 55 // col - index of the extracted patch (in code: patchIndex) 56 // patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions) 57 // 58 template <typename NewDimension, Index Planes, Index Rows, Index Cols, 59 typename ArgType, typename Device, typename Scalar_, typename Index, 60 typename nocontract_t, typename contract_t, int Side, int packet_size, 61 bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment> 62 class TensorContractionInputMapper< 63 Scalar_, Index, Side, 64 TensorEvaluator<const TensorReshapingOp<NewDimension, 65 const TensorVolumePatchOp< 66 Planes, Rows, Cols, ArgType> >, 67 Device>, 68 nocontract_t, contract_t, packet_size, inner_dim_contiguous, 69 inner_dim_reordered, Alignment> { 70 public: 71 typedef Scalar_ Scalar; 72 typedef TensorContractionInputMapper< 73 Scalar, Index, Side, 74 TensorEvaluator<const TensorReshapingOp< 75 NewDimension, const TensorVolumePatchOp< 76 Planes, Rows, Cols, ArgType> >, 77 Device>, 78 nocontract_t, contract_t, packet_size, inner_dim_contiguous, 79 inner_dim_reordered, Alignment> 80 Self; 81 typedef TensorContractionSubMapper< 82 Scalar, Index, Side, 83 TensorEvaluator<const TensorReshapingOp< 84 NewDimension, const TensorVolumePatchOp< 85 Planes, Rows, Cols, ArgType> >, 86 Device>, 87 nocontract_t, contract_t, packet_size, inner_dim_contiguous, 88 inner_dim_reordered, Alignment> 89 SubMapper; 90 typedef SubMapper VectorMapper; 91 typedef SubMapper LinearMapper; 92 typedef typename packet_traits<Scalar>::type Packet; 93 94 EIGEN_DEVICE_FUNC TensorContractionInputMapper(const TensorEvaluator<const TensorReshapingOp<NewDimension,const TensorVolumePatchOp<Planes,Rows,Cols,ArgType>>,Device> & tensor,const nocontract_t &,const nocontract_t &,const contract_t &,const contract_t &)95 TensorContractionInputMapper( 96 const TensorEvaluator< 97 const TensorReshapingOp< 98 NewDimension, 99 const TensorVolumePatchOp<Planes, Rows, Cols, ArgType> >, 100 Device>& tensor, 101 const nocontract_t&, const nocontract_t&, const contract_t&, 102 const contract_t&) 103 : m_impl(tensor.impl().impl()) { 104 if (internal::traits<ArgType>::Layout == ColMajor) { 105 m_patch_depth = tensor.impl().dimensions()[0]; 106 m_patch_planes = tensor.impl().dimensions()[1]; 107 m_patch_rows = tensor.impl().dimensions()[2]; 108 m_patch_cols = tensor.impl().dimensions()[3]; 109 m_num_patches = tensor.impl().dimensions()[4]; 110 } else { 111 const int NumDims = tensor.impl().dimensions().size(); 112 m_patch_depth = tensor.impl().dimensions()[NumDims - 1]; 113 m_patch_planes = tensor.impl().dimensions()[NumDims - 2]; 114 m_patch_rows = tensor.impl().dimensions()[NumDims - 3]; 115 m_patch_cols = tensor.impl().dimensions()[NumDims - 4]; 116 m_num_patches = tensor.impl().dimensions()[NumDims - 5]; 117 } 118 119 // Strides for navigating through the single patch. 120 m_patch_plane_stride = m_patch_depth; 121 m_patch_row_stride = m_patch_planes * m_patch_plane_stride; 122 m_patch_col_stride = m_patch_rows * m_patch_row_stride; 123 124 // Strides for the output tensor. 125 // IMPORTANT: These strides are used to locate an element in a patch at a 126 // depth zero (channel), which is not quite the same as "traditional" 127 // stride. 
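    // A quick worked example of the strides computed below (illustrative
    // values only, not taken from any real model): with m_patch_depth = 4,
    // m_patch_planes = 2, m_patch_rows = 3, m_patch_cols = 3 and
    // m_num_patches = 10 this produces
    //   m_rowStride   = 2                 // (plane) positions per patch row
    //   m_colStride   = 3 * 2 = 6         // (plane, row) positions per column
    //   m_patchStride = 6 * 3 * 4 = 72    // scalars per extracted patch
    //   m_otherStride = 72 * 10 = 720     // scalars per batch (other) entry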
128 m_rowStride = m_patch_planes; 129 m_colStride = m_patch_rows * m_rowStride; 130 m_patchStride = m_colStride * m_patch_cols * m_patch_depth; 131 m_otherStride = m_patchStride * m_num_patches; 132 133 m_outputPlanes = tensor.impl().outputPlanes(); 134 m_outputRows = tensor.impl().outputRows(); 135 m_outputCols = tensor.impl().outputCols(); 136 137 m_outputPlanesRows = m_outputPlanes * m_outputRows; 138 139 m_plane_strides = tensor.impl().userPlaneStride(); 140 m_row_strides = tensor.impl().userRowStride(); 141 m_col_strides = tensor.impl().userColStride(); 142 143 m_in_plane_strides = tensor.impl().userInPlaneStride(); 144 m_in_row_strides = tensor.impl().userInRowStride(); 145 m_in_col_strides = tensor.impl().userInColStride(); 146 147 m_patch_plane_inflate_strides = tensor.impl().planeInflateStride(); 148 m_patch_row_inflate_strides = tensor.impl().rowInflateStride(); 149 m_patch_col_inflate_strides = tensor.impl().colInflateStride(); 150 151 if (internal::traits<ArgType>::Layout == ColMajor) { 152 m_inputDepth = tensor.impl().impl().dimensions()[0]; 153 m_inputPlanes = tensor.impl().impl().dimensions()[1]; 154 m_inputRows = tensor.impl().impl().dimensions()[2]; 155 m_inputCols = tensor.impl().impl().dimensions()[3]; 156 } else { 157 const int NumDims = tensor.impl().impl().dimensions().size(); 158 m_inputDepth = tensor.impl().impl().dimensions()[NumDims - 1]; 159 m_inputPlanes = tensor.impl().impl().dimensions()[NumDims - 2]; 160 m_inputRows = tensor.impl().impl().dimensions()[NumDims - 3]; 161 m_inputCols = tensor.impl().impl().dimensions()[NumDims - 4]; 162 } 163 164 // Strides for navigating through the input tensor. 165 m_planeInputStride = m_inputDepth; 166 m_rowInputStride = m_inputDepth * m_inputPlanes; 167 m_colInputStride = m_inputDepth * m_inputRows * m_inputPlanes; 168 m_patchInputStride = 169 m_inputDepth * m_inputRows * m_inputCols * m_inputPlanes; 170 171 m_planePaddingTop = tensor.impl().planePaddingTop(); 172 m_rowPaddingTop = tensor.impl().rowPaddingTop(); 173 m_colPaddingLeft = tensor.impl().colPaddingLeft(); 174 175 m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches); 176 177 m_fastPatchPlaneStride = 178 internal::TensorIntDivisor<Index>(m_patch_plane_stride); 179 m_fastPatchRowStride = 180 internal::TensorIntDivisor<Index>(m_patch_row_stride); 181 m_fastPatchColStride = 182 internal::TensorIntDivisor<Index>(m_patch_col_stride); 183 184 m_fastInputPlaneStride = 185 internal::TensorIntDivisor<Index>(m_patch_plane_inflate_strides); 186 m_fastInputRowStride = 187 internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides); 188 m_fastInputColStride = 189 internal::TensorIntDivisor<Index>(m_patch_col_inflate_strides); 190 191 m_fastRowStride = internal::TensorIntDivisor<Index>(m_rowStride); 192 m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride); 193 194 m_fastDimZero = internal::TensorIntDivisor<Index>(m_patch_depth); 195 m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows); 196 m_fastOutputPlanes = internal::TensorIntDivisor<Index>(m_outputPlanes); 197 m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows); 198 m_fastOutputCols = internal::TensorIntDivisor<Index>(m_outputCols); 199 200 m_fastOutputPlanesRows = 201 internal::TensorIntDivisor<Index>(m_outputPlanesRows); 202 } 203 204 EIGEN_DEVICE_FUNC TensorContractionInputMapper(const TensorContractionInputMapper & base_mapper)205 TensorContractionInputMapper(const TensorContractionInputMapper& base_mapper) 206 : m_impl(base_mapper.m_impl) { 207 m_patch_depth = 
base_mapper.m_patch_depth; 208 m_patch_planes = base_mapper.m_patch_planes; 209 m_patch_rows = base_mapper.m_patch_rows; 210 m_patch_cols = base_mapper.m_patch_cols; 211 m_num_patches = base_mapper.m_num_patches; 212 213 m_patch_plane_stride = base_mapper.m_patch_plane_stride; 214 m_patch_row_stride = base_mapper.m_patch_row_stride; 215 m_patch_col_stride = base_mapper.m_patch_col_stride; 216 217 m_rowStride = base_mapper.m_rowStride; 218 m_colStride = base_mapper.m_colStride; 219 m_patchStride = base_mapper.m_patchStride; 220 m_otherStride = base_mapper.m_otherStride; 221 222 m_planeInputStride = base_mapper.m_planeInputStride; 223 m_rowInputStride = base_mapper.m_rowInputStride; 224 m_colInputStride = base_mapper.m_colInputStride; 225 m_patchInputStride = base_mapper.m_patchInputStride; 226 m_otherInputStride = base_mapper.m_otherInputStride; 227 228 m_inputDepth = base_mapper.m_inputDepth; 229 m_inputPlanes = base_mapper.m_inputPlanes; 230 m_inputRows = base_mapper.m_inputRows; 231 m_inputCols = base_mapper.m_inputCols; 232 233 m_outputPlanes = base_mapper.m_outputPlanes; 234 m_outputRows = base_mapper.m_outputRows; 235 m_outputCols = base_mapper.m_outputCols; 236 237 m_plane_strides = base_mapper.m_plane_strides; 238 m_row_strides = base_mapper.m_row_strides; 239 m_col_strides = base_mapper.m_col_strides; 240 241 m_in_plane_strides = base_mapper.m_in_plane_strides; 242 m_in_row_strides = base_mapper.m_in_row_strides; 243 m_in_col_strides = base_mapper.m_in_col_strides; 244 245 m_patch_plane_inflate_strides = base_mapper.m_patch_plane_inflate_strides; 246 m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides; 247 m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides; 248 249 m_planePaddingTop = base_mapper.m_planePaddingTop; 250 m_rowPaddingTop = base_mapper.m_rowPaddingTop; 251 m_colPaddingLeft = base_mapper.m_colPaddingLeft; 252 253 m_outputPlanesRows = base_mapper.m_outputPlanesRows; 254 255 m_fastNumPatches = base_mapper.m_fastNumPatches; 256 m_fastPatchPlaneStride = base_mapper.m_fastPatchPlaneStride; 257 m_fastPatchRowStride = base_mapper.m_fastPatchRowStride; 258 m_fastPatchColStride = base_mapper.m_fastPatchColStride; 259 m_fastInputPlaneStride = base_mapper.m_fastInputPlaneStride; 260 m_fastInputRowStride = base_mapper.m_fastInputRowStride; 261 m_fastInputColStride = base_mapper.m_fastInputColStride; 262 m_fastRowStride = base_mapper.m_fastRowStride; 263 m_fastColStride = base_mapper.m_fastColStride; 264 m_fastOutputPlanes = base_mapper.m_fastOutputPlanes; 265 m_fastOutputRows = base_mapper.m_fastOutputRows; 266 m_fastOutputCols = base_mapper.m_fastOutputCols; 267 m_fastDimZero = base_mapper.m_fastDimZero; 268 m_fastOutputPlanesRows = base_mapper.m_fastOutputPlanesRows; 269 } 270 271 // If true, turns off some optimizations for loading packets since the image 272 // patches are "non-standard" such as there are non-trivial strides or 273 // inflations in the input. 
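  // For example (hypothetical configuration): an atrous/dilated cuboid
  // convolution evaluated with m_in_row_strides == 2 makes the patches
  // non-standard, so loadPacket() below falls back to the element-wise
  // packetWithPossibleZero() path instead of the vectorized fast paths.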
274 EIGEN_DEVICE_FUNC nonStandardPatches()275 EIGEN_ALWAYS_INLINE bool nonStandardPatches() const { 276 return m_in_plane_strides != 1 || m_in_row_strides != 1 || 277 m_in_col_strides != 1 || m_patch_plane_inflate_strides != 1 || 278 m_patch_row_inflate_strides != 1 || m_patch_col_inflate_strides != 1; 279 } 280 281 EIGEN_DEVICE_FUNC getSubMapper(Index i,Index j)282 EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const { 283 return SubMapper(*this, i, j); 284 } 285 286 EIGEN_DEVICE_FUNC getLinearMapper(Index i,Index j)287 EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const { 288 return LinearMapper(*this, i, j); 289 } 290 291 EIGEN_DEVICE_FUNC operator()292 EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const { 293 Index planeIndex, rowIndex, colIndex, otherIndex; 294 computeBaseIndices(0, planeIndex, rowIndex, colIndex, otherIndex); 295 return loadCoeff(row, planeIndex, rowIndex, colIndex, otherIndex); 296 } 297 298 // Load the coefficient at the patchIndex location instead of the usual 299 // m_rowIndex, m_colIndex, m_otherIndex. This is currently only used by the 300 // gpu code. 301 EIGEN_DEVICE_FUNC operator()302 EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const { 303 Index planeIndex, rowIndex, colIndex, otherIndex; 304 computeBaseIndices(patchIndex, planeIndex, rowIndex, colIndex, otherIndex); 305 return loadCoeff(row, planeIndex, rowIndex, colIndex, otherIndex); 306 } 307 308 EIGEN_DEVICE_FUNC loadPacket(Index row)309 EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const { 310 Index planeIndex, rowIndex, colIndex, otherIndex; 311 computeBaseIndices(0, planeIndex, rowIndex, colIndex, otherIndex); 312 return loadPacket(row, planeIndex, rowIndex, colIndex, otherIndex); 313 } 314 315 // Load the packet at the patchIndex location instead of the usual m_rowIndex, 316 // m_colIndex, m_otherIndex. This is currently only used by the gpu code. 317 EIGEN_DEVICE_FUNC loadPacket(Index row,Index patchIndex)318 EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const { 319 Index planeIndex, rowIndex, colIndex, otherIndex; 320 computeBaseIndices(patchIndex, planeIndex, rowIndex, colIndex, otherIndex); 321 return loadPacket(row, planeIndex, rowIndex, colIndex, otherIndex); 322 } 323 324 EIGEN_DEVICE_FUNC impl()325 EIGEN_ALWAYS_INLINE const TensorEvaluator<ArgType, Device>& impl() const { 326 return m_impl; 327 } 328 329 EIGEN_DEVICE_FUNC patchDepth()330 EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_planeInputStride; } 331 EIGEN_DEVICE_FUNC patchPlanes()332 EIGEN_ALWAYS_INLINE Index patchPlanes() const { return m_rowStride; } 333 EIGEN_DEVICE_FUNC patchRows()334 EIGEN_ALWAYS_INLINE Index patchRows() const { return m_patch_rows; } 335 EIGEN_DEVICE_FUNC patchCols()336 EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; } 337 338 private: 339 friend class TensorContractionSubMapper< 340 Scalar, Index, Side, 341 TensorEvaluator<const TensorReshapingOp< 342 NewDimension, const TensorVolumePatchOp< 343 Planes, Rows, Cols, ArgType> >, 344 Device>, 345 nocontract_t, contract_t, packet_size, inner_dim_contiguous, 346 inner_dim_reordered, Alignment>; 347 348 // Load coefficient from a patch specified by the "within patch offset" 349 // (patchId) and the precomputed indices of the first element of the patch. 
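  // Worked example of the decomposition done below (illustrative sizes only):
  // with patch_depth = 4, patch_planes = 2, patch_rows = 3, patch_cols = 3
  // (so m_rowStride = 2 and m_colStride = 6), patchId = 50 splits into
  //   patchOffset = 50 / 4 = 12,   colOffset   = 12 / 6 = 2,
  //   rowOffset   = 0,             planeOffset = 0,
  //   depth       = 50 - 12 * 4 = 2,
  // i.e. channel 2 of the patch element at (plane 0, row 0, col 2).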
350 EIGEN_DEVICE_FUNC loadCoeff(Index patchId,Index planeIndex,Index rowIndex,Index colIndex,Index otherIndex)351 EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index planeIndex, 352 Index rowIndex, Index colIndex, 353 Index otherIndex) const { 354 // Find the offset of the element wrt the location of the first element. 355 const Index patchOffset = patchId / m_fastDimZero; 356 357 const Index colOffset = patchOffset / m_fastColStride; 358 const Index inputCol = colIndex + colOffset * m_in_col_strides; 359 const Index origInputCol = 360 (m_patch_col_inflate_strides == 1) 361 ? inputCol 362 : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0); 363 364 const Index rowOffset = 365 (patchOffset - colOffset * m_colStride) / m_fastRowStride; 366 const Index inputRow = rowIndex + rowOffset * m_in_row_strides; 367 const Index origInputRow = 368 (m_patch_row_inflate_strides == 1) 369 ? inputRow 370 : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0); 371 372 const Index planeOffset = 373 patchOffset - colOffset * m_colStride - rowOffset * m_rowStride; 374 const Index inputPlane = planeIndex + planeOffset * m_in_plane_strides; 375 const Index origInputPlane = 376 (m_patch_plane_inflate_strides == 1) 377 ? inputPlane 378 : ((inputPlane >= 0) ? (inputPlane / m_fastInputPlaneStride) : 0); 379 380 if (origInputCol < 0 || origInputRow < 0 || origInputPlane < 0 || 381 origInputCol >= m_inputCols || origInputRow >= m_inputRows || 382 origInputPlane >= m_inputPlanes || 383 (inputCol != origInputCol * m_patch_col_inflate_strides) || 384 (inputRow != origInputRow * m_patch_row_inflate_strides) || 385 (inputPlane != origInputPlane * m_patch_plane_inflate_strides)) { 386 return Scalar(0); 387 } 388 389 const Index depth = patchId - patchOffset * patchDepth(); 390 const Index inputIndex = depth + origInputPlane * m_planeInputStride + 391 origInputRow * m_rowInputStride + 392 origInputCol * m_colInputStride + otherIndex; 393 394 return m_impl.coeff(inputIndex); 395 } 396 397 // This is the same as loadCoeff(...), but optimized for all `inflate_strides` 398 // and `in_strides` equal to 1 (template specialization without templates). 399 EIGEN_DEVICE_FUNC loadCoeffStandard(Index patchId,Index planeIndex,Index rowIndex,Index colIndex,Index otherIndex)400 EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index planeIndex, 401 Index rowIndex, Index colIndex, 402 Index otherIndex) const { 403 eigen_assert(!nonStandardPatches()); 404 405 // Find the offset of the element wrt the location of the first element. 
406 const Index patchOffset = patchId / m_fastDimZero; 407 408 const Index colOffset = patchOffset / m_fastColStride; 409 const Index rowOffset = 410 (patchOffset - colOffset * m_colStride) / m_fastRowStride; 411 const Index planeOffset = 412 patchOffset - colOffset * m_colStride - rowOffset * m_rowStride; 413 414 const Index inputCol = colIndex + colOffset; 415 const Index inputRow = rowIndex + rowOffset; 416 const Index inputPlane = planeIndex + planeOffset; 417 418 if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 || 419 inputRow >= m_inputRows || inputPlane < 0 || 420 inputPlane >= m_inputPlanes) { 421 return Scalar(0); 422 } 423 424 const Index depth = patchId - patchOffset * patchDepth(); 425 const Index inputIndex = depth + inputPlane * m_planeInputStride + 426 inputRow * m_rowInputStride + 427 inputCol * m_colInputStride + otherIndex; 428 429 return m_impl.coeff(inputIndex); 430 } 431 432 // Load packet from a patch specified by the "within patch offset" 433 // (patchId) and the precomputed indices of the first element of the patch. 434 EIGEN_DEVICE_FUNC loadPacket(Index patchId,Index planeIndex,Index rowIndex,Index colIndex,Index otherIndex)435 EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index planeIndex, 436 Index rowIndex, Index colIndex, 437 Index otherIndex) const { 438 const Index packetSize = internal::unpacket_traits<Packet>::size; 439 440 EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) 441 eigen_assert(patchId < 442 patchDepth() * patchPlanes() * patchRows() * patchCols()); 443 444 if (nonStandardPatches()) { 445 return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex, 446 otherIndex); 447 } 448 return loadPacketStandard(patchId, planeIndex, rowIndex, colIndex, 449 otherIndex); 450 } 451 452 EIGEN_DEVICE_FUNC loadPacketStandard(Index patchId,Index planeIndex,Index rowIndex,Index colIndex,Index otherIndex)453 EIGEN_ALWAYS_INLINE Packet loadPacketStandard(Index patchId, Index planeIndex, 454 Index rowIndex, Index colIndex, 455 Index otherIndex) const { 456 const Index packetSize = internal::unpacket_traits<Packet>::size; 457 EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) 458 eigen_assert(patchId < 459 patchDepth() * patchPlanes() * patchRows() * patchCols()); 460 eigen_assert(!nonStandardPatches()); 461 462 if ((patchDepth() % packetSize) == 0) { 463 return loadPacketFast(patchId, planeIndex, rowIndex, colIndex, 464 otherIndex); 465 } else { 466 // Offsets and input calculation here are identical to 467 // loadCoeffStandard(...), but repeated twice. 
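      // The two offsets below describe the first and the last scalar covered
      // by this packet. If both fall into the same column, row and plane of
      // the patch, the packet is contiguous in the input tensor and can be
      // fetched with a single unaligned packet load; otherwise we fall back
      // to the element-by-element packetWithPossibleZero() at the bottom.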
468 469 const Index patchOffsets[2] = { 470 patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero}; 471 472 const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride, 473 patchOffsets[1] / m_fastColStride}; 474 eigen_assert(colOffsets[0] <= colOffsets[1]); 475 476 const Index inputCols[2] = {colIndex + colOffsets[0], 477 colIndex + colOffsets[1]}; 478 if (inputCols[0] >= m_inputCols || inputCols[1] < 0) { 479 return internal::pset1<Packet>(Scalar(0)); 480 } 481 482 if (inputCols[0] == inputCols[1]) { 483 const Index rowOffsets[2] = { 484 (patchOffsets[0] - colOffsets[0] * m_colStride) / m_fastRowStride, 485 (patchOffsets[1] - colOffsets[1] * m_colStride) / m_fastRowStride}; 486 eigen_assert(rowOffsets[0] <= rowOffsets[1]); 487 const Index inputRows[2] = {rowIndex + rowOffsets[0], 488 rowIndex + rowOffsets[1]}; 489 490 if (inputRows[0] >= m_inputRows || inputRows[1] < 0) { 491 return internal::pset1<Packet>(Scalar(0)); 492 } 493 494 if (inputRows[0] == inputRows[1]) { 495 const Index planeOffsets[2] = { 496 patchOffsets[0] - colOffsets[0] * m_colStride - 497 rowOffsets[0] * m_rowStride, 498 patchOffsets[1] - colOffsets[1] * m_colStride - 499 rowOffsets[1] * m_rowStride}; 500 eigen_assert(planeOffsets[0] <= planeOffsets[1]); 501 const Index inputPlanes[2] = {planeIndex + planeOffsets[0], 502 planeIndex + planeOffsets[1]}; 503 504 if (inputPlanes[0] >= m_inputPlanes || inputPlanes[1] < 0) { 505 return internal::pset1<Packet>(Scalar(0)); 506 } 507 508 if (inputPlanes[0] >= 0 && inputPlanes[1] < m_inputPlanes) { 509 const Index depth = patchId - patchOffsets[0] * patchDepth(); 510 const Index inputIndex = 511 depth + inputPlanes[0] * m_planeInputStride + 512 inputRows[0] * m_rowInputStride + 513 inputCols[0] * m_colInputStride + otherIndex; 514 return m_impl.template packet<Unaligned>(inputIndex); 515 } 516 } 517 } 518 } 519 520 return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex, 521 otherIndex); 522 } 523 524 EIGEN_DEVICE_FUNC loadPacketFast(Index patchId,Index planeIndex,Index rowIndex,Index colIndex,Index otherIndex)525 EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index planeIndex, 526 Index rowIndex, Index colIndex, 527 Index otherIndex) const { 528 const Index packetSize = internal::unpacket_traits<Packet>::size; 529 EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) 530 eigen_assert(patchId < 531 patchDepth() * patchPlanes() * patchRows() * patchCols()); 532 533 eigen_assert(!nonStandardPatches()); 534 eigen_assert((patchDepth() % packetSize) == 0); 535 536 // Find the offset of the element wrt the location of the first element. 
537 const Index patchOffset = patchId / m_fastDimZero; 538 eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset); 539 540 const Index colOffset = patchOffset / m_fastColStride; 541 const Index rowOffset = 542 (patchOffset - colOffset * m_colStride) / m_fastRowStride; 543 const Index planeOffset = 544 patchOffset - colOffset * m_colStride - rowOffset * m_rowStride; 545 546 const Index inputCol = colIndex + colOffset; 547 const Index inputRow = rowIndex + rowOffset; 548 const Index inputPlane = planeIndex + planeOffset; 549 550 if (inputCol < 0 || inputRow < 0 || inputPlane < 0 || 551 inputCol >= m_inputCols || inputRow >= m_inputRows || 552 inputPlane >= m_inputPlanes) { 553 return internal::pset1<Packet>(Scalar(0)); 554 } 555 556 const Index depth = patchId - patchOffset * patchDepth(); 557 const Index inputIndex = depth + inputPlane * m_planeInputStride + 558 inputRow * m_rowInputStride + 559 inputCol * m_colInputStride + otherIndex; 560 return m_impl.template packet<Unaligned>(inputIndex); 561 } 562 563 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet packetWithPossibleZero(Index patchId,Index planeIndex,Index rowIndex,Index colIndex,Index otherIndex)564 packetWithPossibleZero(Index patchId, Index planeIndex, Index rowIndex, 565 Index colIndex, Index otherIndex) const { 566 const int packetSize = internal::unpacket_traits<Packet>::size; 567 EIGEN_ALIGN_MAX 568 typename internal::remove_const<Scalar>::type values[packetSize]; 569 for (int i = 0; i < packetSize; ++i) { 570 values[i] = 571 loadCoeff(patchId + i, planeIndex, rowIndex, colIndex, otherIndex); 572 } 573 Packet rslt = internal::pload<Packet>(values); 574 return rslt; 575 } 576 577 // Precompute the indices (plane, row, col, other) of the first element of 578 // the given patch index, within the output tensor of the TensorVolumePatchOp. computeBaseIndices(Index patchIndex,Index & planeIndex,Index & rowIndex,Index & colIndex,Index & otherIndex)579 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices( 580 Index patchIndex, Index& planeIndex, Index& rowIndex, Index& colIndex, 581 Index& otherIndex) const { 582 const size_t NumInputDims = array_size< 583 typename TensorEvaluator<ArgType, Device>::Dimensions>::value; 584 585 // Check if patchIndex might contain batch and other dimensions. 586 otherIndex = (NumInputDims == 4) ? 0 : patchIndex / m_fastNumPatches; 587 588 // Compute index of the patch within the batch (and other dimensions). 589 const Index patch3DIndex = (NumInputDims == 4) 590 ? patchIndex 591 : (patchIndex - otherIndex * m_num_patches); 592 593 otherIndex *= m_patchInputStride; 594 595 colIndex = patch3DIndex / m_fastOutputPlanesRows; 596 rowIndex = 597 (patch3DIndex - colIndex * m_outputPlanesRows) / m_fastOutputPlanes; 598 planeIndex = 599 patch3DIndex - (colIndex * m_outputRows + rowIndex) * m_outputPlanes; 600 601 colIndex = colIndex * m_col_strides - m_colPaddingLeft; 602 rowIndex = rowIndex * m_row_strides - m_rowPaddingTop; 603 planeIndex = planeIndex * m_plane_strides - m_planePaddingTop; 604 } 605 606 Index m_patch_depth; // number of channels in the patch 607 Index m_patch_planes; // number of planes in the patch 608 Index m_patch_rows; // number of rows in the patch 609 Index m_patch_cols; // number of columns in the patch 610 Index m_num_patches; // number of patches to extract 611 612 // Strides for navigating through the single patch. 
613 Index m_patch_plane_stride; 614 Index m_patch_row_stride; 615 Index m_patch_col_stride; 616 617 // Strides for the output tensor (depth is not the part of the stride). 618 Index m_rowStride; 619 Index m_colStride; 620 Index m_patchStride; 621 Index m_otherStride; 622 623 Index m_planeInputStride; // Plane stride in the input tensor 624 Index m_rowInputStride; // Row stride in the input tensor 625 Index m_colInputStride; // Col stride in the input tensor 626 Index m_patchInputStride; // Patch stride in the input tensor 627 Index m_otherInputStride; 628 629 Index m_inputDepth; // Depth of the input tensor 630 Index m_inputPlanes; // Number of planes in the input tensor 631 Index m_inputRows; // Number of rows in the input tensor 632 Index m_inputCols; // Number of cols in the input tensor 633 634 Index m_outputPlanes; // Number of output planes 635 Index m_outputRows; // Number of output rows 636 Index m_outputCols; // Number of output cols 637 Index m_outputPlanesRows; // Cached outputPlanes * outputRows. 638 639 Index m_plane_strides; // User specified plane stride 640 Index m_row_strides; // User specified row stride 641 Index m_col_strides; // User specified col stride 642 643 // User specified plane/row/col atrous convolution strides. 644 Index m_in_plane_strides; 645 Index m_in_row_strides; 646 Index m_in_col_strides; 647 648 // User specified plane/row/col inflation strides in the image patch. 649 Index m_patch_plane_inflate_strides; 650 Index m_patch_row_inflate_strides; 651 Index m_patch_col_inflate_strides; 652 653 Index m_planePaddingTop; // Plane padding 654 Index m_rowPaddingTop; // Row padding 655 Index m_colPaddingLeft; // Column padding 656 657 // Fast representation of various divisors. 658 internal::TensorIntDivisor<Index> m_fastNumPatches; 659 660 internal::TensorIntDivisor<Index> m_fastPatchPlaneStride; 661 internal::TensorIntDivisor<Index> m_fastPatchRowStride; 662 internal::TensorIntDivisor<Index> m_fastPatchColStride; 663 664 internal::TensorIntDivisor<Index> m_fastInputPlaneStride; 665 internal::TensorIntDivisor<Index> m_fastInputRowStride; 666 internal::TensorIntDivisor<Index> m_fastInputColStride; 667 668 internal::TensorIntDivisor<Index> m_fastRowStride; 669 internal::TensorIntDivisor<Index> m_fastColStride; 670 671 internal::TensorIntDivisor<Index> m_fastDimZero; // aka output depth 672 internal::TensorIntDivisor<Index> m_fastOutputPlanes; 673 internal::TensorIntDivisor<Index> m_fastOutputRows; 674 internal::TensorIntDivisor<Index> m_fastOutputCols; 675 internal::TensorIntDivisor<Index> m_fastOutputPlanesRows; 676 677 const TensorEvaluator<ArgType, Device> m_impl; 678 }; 679 680 template <typename NewDimension, Index Planes, Index Rows, Index Cols, 681 typename ArgType, typename Device, typename Scalar, typename Index, 682 typename nocontract_t, typename contract_t, int Side, int packet_size, 683 bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment> 684 class TensorContractionSubMapper< 685 Scalar, Index, Side, 686 TensorEvaluator<const TensorReshapingOp<NewDimension, 687 const TensorVolumePatchOp< 688 Planes, Rows, Cols, ArgType> >, 689 Device>, 690 nocontract_t, contract_t, packet_size, inner_dim_contiguous, 691 inner_dim_reordered, Alignment> { 692 public: 693 typedef typename packet_traits<Scalar>::type Packet; 694 typedef typename packet_traits<Scalar>::half HalfPacket; 695 696 typedef TensorContractionInputMapper< 697 Scalar, Index, Side, 698 TensorEvaluator<const TensorReshapingOp< 699 NewDimension, const TensorVolumePatchOp< 700 Planes, 
Rows, Cols, ArgType> >, 701 Device>, 702 nocontract_t, contract_t, packet_size, inner_dim_contiguous, 703 inner_dim_reordered, Alignment> 704 ParentMapper; 705 typedef TensorContractionSubMapper< 706 Scalar, Index, Side, 707 TensorEvaluator<const TensorReshapingOp< 708 NewDimension, const TensorVolumePatchOp< 709 Planes, Rows, Cols, ArgType> >, 710 Device>, 711 nocontract_t, contract_t, packet_size, inner_dim_contiguous, 712 inner_dim_reordered, Alignment> 713 Self; 714 typedef Self LinearMapper; 715 TensorContractionSubMapper(const ParentMapper & base_mapper,Index vert_offset,Index horiz_offset)716 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper( 717 const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset) 718 : m_base_mapper(base_mapper), 719 m_depth_offset(vert_offset), 720 m_col_offset(horiz_offset) { 721 m_base_mapper.computeBaseIndices(m_col_offset, m_planeIndex, m_rowIndex, 722 m_colIndex, m_otherIndex); 723 } TensorContractionSubMapper(const Self & base_mapper,Index vert_offset,Index horiz_offset)724 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper( 725 const Self& base_mapper, Index vert_offset, Index horiz_offset) 726 : m_base_mapper(base_mapper.m_base_mapper), 727 m_depth_offset(vert_offset + base_mapper.m_depth_offset), 728 m_col_offset(horiz_offset + base_mapper.m_col_offset) { 729 m_base_mapper.computeBaseIndices(m_col_offset, m_planeIndex, m_rowIndex, 730 m_colIndex, m_otherIndex); 731 } operator()732 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const { 733 return m_base_mapper.loadCoeff(i + m_depth_offset, m_planeIndex, m_rowIndex, 734 m_colIndex, m_otherIndex); 735 } operator()736 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, 737 Index j) const { 738 return m_base_mapper(i + m_depth_offset, j + m_col_offset); 739 } 740 loadPacket(Index i)741 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const { 742 return m_base_mapper.loadPacket(i + m_depth_offset, m_planeIndex, 743 m_rowIndex, m_colIndex, m_otherIndex); 744 } loadPacket(Index i,Index j)745 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, 746 Index j) const { 747 return m_base_mapper.template loadPacket<Alignment>(i + m_depth_offset, 748 j + m_col_offset); 749 } 750 751 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar loadCoeffStandard(Index i)752 loadCoeffStandard(Index i) const { 753 return m_base_mapper.loadCoeffStandard( 754 i + m_depth_offset, m_planeIndex, m_rowIndex, m_colIndex, m_otherIndex); 755 } 756 loadPacketFast(Index i)757 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const { 758 return m_base_mapper.loadPacketFast(i + m_depth_offset, m_planeIndex, 759 m_rowIndex, m_colIndex, m_otherIndex); 760 } 761 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketStandard(Index i)762 loadPacketStandard(Index i) const { 763 return m_base_mapper.loadPacketStandard( 764 i + m_depth_offset, m_planeIndex, m_rowIndex, m_colIndex, m_otherIndex); 765 } 766 template <typename Packet> aligned(Index)767 EIGEN_DEVICE_FUNC bool aligned(Index) const { 768 return false; 769 } 770 771 EIGEN_DEVICE_FUNC nonStandardPatches()772 EIGEN_ALWAYS_INLINE bool nonStandardPatches() const { 773 return m_base_mapper.nonStandardPatches(); 774 } 775 776 // Max(Col|Row|Plane|Depth): compute the upper limit for the column, row, 777 // plane and depth index respectively that fits into the peeled_k elements 778 // starting at m_depth_offset. 
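  // For example (illustrative sizes only): with patch_depth = 4,
  // patch_planes = 2, patch_rows = 3 and patch_cols = 3, patchColStride() is
  // 3 * 2 * 4 = 24. For m_depth_offset = 0 and peeled_k = 40, maxCol(40)
  // returns min(1 + 40 / 24, 3) = 2, i.e. the peeled elements only touch
  // patch columns 0 and 1.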
779 780 EIGEN_DEVICE_FUNC maxCol(const Index peeled_k)781 EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const { 782 const Index max_col = 783 fastPatchColStride().divide(m_depth_offset + peeled_k); 784 return std::min<Index>(1 + max_col, patchCols()); 785 } 786 787 EIGEN_DEVICE_FUNC maxRow(const Index peeled_k,const Index col)788 EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k, 789 const Index col) const { 790 const Index max_row = fastPatchRowStride().divide( 791 m_depth_offset + peeled_k - col * patchColStride()); 792 return std::min<Index>(1 + max_row, patchRows()); 793 } 794 795 EIGEN_DEVICE_FUNC maxPlane(const Index peeled_k,const Index col,const Index row)796 EIGEN_ALWAYS_INLINE Index maxPlane(const Index peeled_k, const Index col, 797 const Index row) const { 798 const Index max_plane = fastPatchPlaneStride().divide( 799 m_depth_offset + peeled_k - col * patchColStride() - 800 row * patchRowStride()); 801 return std::min<Index>(1 + max_plane, patchPlanes()); 802 } 803 804 // MaxDepth uses only the remaining number of elements in the peeled_k. 805 EIGEN_DEVICE_FUNC maxDepth(const Index num_elements,const Index start_depth)806 EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements, 807 const Index start_depth) const { 808 return std::min<Index>(start_depth + num_elements, patchDepth()); 809 } 810 811 // Every register matters in this code, so sometimes to prevent register 812 // spilling, instead of the variable that you would expect to see, we use 813 // another one, that is guaranteed to have the same value. E.g. patch depth is 814 // always the same as input depth, and it's also the same as input plane 815 // stride. Bunch of other parameters have similar relations. 816 817 typedef internal::TensorIntDivisor<Index> IndexDivisor; 818 819 EIGEN_DEVICE_FUNC patchDepth()820 EIGEN_ALWAYS_INLINE Index patchDepth() const { 821 eigen_assert(m_base_mapper.m_patch_depth == 822 m_base_mapper.m_planeInputStride && 823 "Patch depth must be equal to plane input stride."); 824 return m_base_mapper.m_planeInputStride; 825 } 826 827 EIGEN_DEVICE_FUNC patchPlanes()828 EIGEN_ALWAYS_INLINE Index patchPlanes() const { 829 eigen_assert(m_base_mapper.m_patch_planes == m_base_mapper.m_rowStride && 830 "Patch planes must be equal to row stride."); 831 return m_base_mapper.m_rowStride; 832 } 833 EIGEN_DEVICE_FUNC patchRows()834 EIGEN_ALWAYS_INLINE Index patchRows() const { 835 return m_base_mapper.m_patch_rows; 836 } 837 EIGEN_DEVICE_FUNC patchCols()838 EIGEN_ALWAYS_INLINE Index patchCols() const { 839 return m_base_mapper.m_patch_cols; 840 } 841 842 EIGEN_DEVICE_FUNC patchPlaneStride()843 EIGEN_ALWAYS_INLINE Index patchPlaneStride() const { 844 eigen_assert(patchDepth() == m_base_mapper.m_patch_plane_stride && 845 "Patch depth must be equal to patch plane stride."); 846 return patchDepth(); 847 } 848 EIGEN_DEVICE_FUNC patchRowStride()849 EIGEN_ALWAYS_INLINE Index patchRowStride() const { 850 return m_base_mapper.m_patch_row_stride; 851 } 852 EIGEN_DEVICE_FUNC patchColStride()853 EIGEN_ALWAYS_INLINE Index patchColStride() const { 854 return m_base_mapper.m_patch_col_stride; 855 } 856 857 EIGEN_DEVICE_FUNC fastPatchPlaneStride()858 EIGEN_ALWAYS_INLINE IndexDivisor fastPatchPlaneStride() const { 859 eigen_assert(patchDepth() == m_base_mapper.m_patch_plane_stride && 860 "Patch depth must be equal to patch plane stride."); 861 return m_base_mapper.m_fastDimZero; // patch_depth 862 } 863 EIGEN_DEVICE_FUNC fastPatchRowStride()864 EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const { 
865 return m_base_mapper.m_fastPatchRowStride; 866 } 867 EIGEN_DEVICE_FUNC fastPatchColStride()868 EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const { 869 return m_base_mapper.m_fastPatchColStride; 870 } 871 872 EIGEN_DEVICE_FUNC packetNoPadding(const Index depth,const Index baseIndex)873 EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth, 874 const Index baseIndex) const { 875 const Index inputIndex = depth + baseIndex; 876 return m_base_mapper.m_impl.template packet<Unaligned>(inputIndex); 877 } 878 EIGEN_DEVICE_FUNC coeffNoPadding(const Index depth,const Index baseIndex)879 EIGEN_ALWAYS_INLINE Scalar coeffNoPadding(const Index depth, 880 const Index baseIndex) const { 881 const Index inputIndex = depth + baseIndex; 882 return m_base_mapper.m_impl.coeff(inputIndex); 883 } 884 885 EIGEN_DEVICE_FUNC padPlane(const Index plane)886 EIGEN_ALWAYS_INLINE bool padPlane(const Index plane) const { 887 const Index p = m_planeIndex + plane; 888 return p < 0 || p >= m_base_mapper.m_inputPlanes; 889 } 890 EIGEN_DEVICE_FUNC padRow(const Index row)891 EIGEN_ALWAYS_INLINE bool padRow(const Index row) const { 892 const Index r = m_rowIndex + row; 893 return r < 0 || r >= m_base_mapper.m_inputRows; 894 } 895 EIGEN_DEVICE_FUNC padCol(const Index col)896 EIGEN_ALWAYS_INLINE bool padCol(const Index col) const { 897 const Index c = m_colIndex + col; 898 return c < 0 || c >= m_base_mapper.m_inputCols; 899 } 900 EIGEN_DEVICE_FUNC baseIndex(const Index plane,const Index row,const Index col)901 EIGEN_ALWAYS_INLINE Index baseIndex(const Index plane, const Index row, 902 const Index col) const { 903 const Index p = m_planeIndex + plane; 904 const Index r = m_rowIndex + row; 905 const Index c = m_colIndex + col; 906 return p * m_base_mapper.m_planeInputStride + 907 r * m_base_mapper.m_rowInputStride + 908 c * m_base_mapper.m_colInputStride + m_otherIndex; 909 } 910 911 EIGEN_DEVICE_FUNC planeOffset()912 EIGEN_ALWAYS_INLINE Index planeOffset() const { 913 const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero; 914 const Index colOffset = patchOffset / m_base_mapper.m_fastColStride; 915 const Index rowOffset = 916 (patchOffset - colOffset * m_base_mapper.m_colStride) / 917 m_base_mapper.m_fastRowStride; 918 const Index planeOffset = patchOffset - 919 colOffset * m_base_mapper.m_colStride - 920 rowOffset * m_base_mapper.m_rowStride; 921 return planeOffset; 922 } 923 924 EIGEN_DEVICE_FUNC rowOffset()925 EIGEN_ALWAYS_INLINE Index rowOffset() const { 926 const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero; 927 const Index colOffset = patchOffset / m_base_mapper.m_fastColStride; 928 const Index rowOffset = 929 (patchOffset - colOffset * m_base_mapper.m_colStride) / 930 m_base_mapper.m_fastRowStride; 931 return rowOffset; 932 } 933 934 EIGEN_DEVICE_FUNC colOffset()935 EIGEN_ALWAYS_INLINE Index colOffset() const { 936 const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero; 937 const Index colOffset = patchOffset / m_base_mapper.m_fastColStride; 938 return colOffset; 939 } 940 941 EIGEN_DEVICE_FUNC depthOffset()942 EIGEN_ALWAYS_INLINE Index depthOffset() const { 943 return m_depth_offset % patchDepth(); 944 } 945 946 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i,Index j)947 getLinearMapper(Index i, Index j) const { 948 return LinearMapper(m_base_mapper, i + m_depth_offset, j + m_col_offset); 949 } 950 951 private: 952 const ParentMapper m_base_mapper; // Keeping a copy instead of a reference 953 // performs better in benchmarks. 

  Index m_depth_offset;  // First row in the input matrix
  Index m_col_offset;    // First col in the input matrix

  // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base
  // indices for the first element in a patch specified by col_offset
  // (see computeBaseIndices(...) for details).
  Index m_planeIndex;
  Index m_rowIndex;
  Index m_colIndex;
  Index m_otherIndex;
};

// Arrange a block of the right input matrix (in our case it's always a
// "virtual matrix" constructed from extracted volume patches) in contiguous
// memory.
//
// Given column major input (A0 beside A1 in memory):
// A0 B0 C0 D0 E0 F0 G0 H0 ... Z0
// A1 B1 C1 D1 E1 F1 G1 H1 ... Z1
// A2 B2 C2 D2 E2 F2 G2 H2 ... Z2
// A3 B3 C3 D3 E3 F3 G3 H3 ... Z3
// A4 B4 C4 D4 E4 F4 G4 H4 ... Z4
// A5 B5 C5 D5 E5 F5 G5 H5 ... Z5
// A6 B6 C6 D6 E6 F6 G6 H6 ... Z6
// A7 B7 C7 D7 E7 F7 G7 H7 ... Z7
// A8 ...
// ...
//
// *) A, B, C, ... - patches extracted from the original input.
// *) A0, A1, A2 ... - values from the same patch at different offsets.
//
// The traversal (packed rhs memory) order (B0 beside A0 in memory):
// A0 B0 C0 D0 A1 B1 C1 D1 ...
// E0 F0 G0 H0 E1 F1 G1 H1 ...
// ...
// Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4)
//
// This traversal order must be the same as in default gemm_pack_rhs defined in
// GeneralBlockPanelKernel.h.
//
// *) nr - number of registers along the 'n' dimension.
//    See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix
//    Multiplication" paper.
//
// TODO(ezhulenev): Add support for squeezing reads along two innermost
// dimensions (see eigen_spatial_convolutions).
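// For example (assuming a hypothetical packet_size of 4): the packer below
// loads one packet from each of the four patches A, B, C, D at the same
// within-patch offset, i.e. (A0..A3), (B0..B3), (C0..C3), (D0..D3), and
// ptranspose()s the resulting 4x4 PacketBlock so that the packed block
// receives A0 B0 C0 D0 A1 B1 C1 D1 A2 B2 C2 D2 A3 B3 C3 D3, which is exactly
// the traversal order shown above.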
1000 template <typename NewDimension, Index Planes, Index Rows, Index Cols, 1001 typename ArgType, typename Device, typename Scalar, typename Index, 1002 typename nocontract_t, typename contract_t, int packet_size, 1003 bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, 1004 int nr> 1005 struct gemm_pack_rhs< 1006 Scalar, Index, 1007 TensorContractionSubMapper< 1008 Scalar, Index, Rhs, 1009 TensorEvaluator<const TensorReshapingOp< 1010 NewDimension, const TensorVolumePatchOp< 1011 Planes, Rows, Cols, ArgType> >, 1012 Device>, 1013 nocontract_t, contract_t, packet_size, inner_dim_contiguous, 1014 inner_dim_reordered, Alignment>, 1015 nr, ColMajor, false, false> { 1016 typedef TensorContractionSubMapper< 1017 Scalar, Index, Rhs, 1018 TensorEvaluator<const TensorReshapingOp< 1019 NewDimension, const TensorVolumePatchOp< 1020 Planes, Rows, Cols, ArgType> >, 1021 Device>, 1022 nocontract_t, contract_t, packet_size, inner_dim_contiguous, 1023 inner_dim_reordered, Alignment> 1024 SubMapper; 1025 1026 typedef SubMapper DataMapper; 1027 typedef typename packet_traits<Scalar>::type Packet; 1028 1029 EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE); 1030 1031 EIGEN_DEVICE_FUNC 1032 EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs, 1033 Index depth, Index cols, Index stride = 0, 1034 Index offset = 0) const { 1035 eigen_assert(stride == 0); 1036 eigen_assert(offset == 0); 1037 1038 const Index packet_cols4 = (cols / 4) * 4; 1039 const Index peeled_k = (depth / packet_size) * packet_size; 1040 const bool non_standard_patches = rhs.nonStandardPatches(); 1041 1042 for (Index j2 = 0; j2 < packet_cols4; j2 += 4) { 1043 const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0); 1044 const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1); 1045 const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2); 1046 const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3); 1047 1048 Index k = 0; 1049 if ((packet_size % 4) == 0 && !non_standard_patches) { 1050 // FAST PATH: 1051 // Iterate over patch columns, rows and planes if we know that a single 1052 // packet do not span across multiple planes, rows or columns. 1053 if ((rhs.patchDepth() % packet_size) == 0) { 1054 const Index start_col = rhs.colOffset(); 1055 const Index max_col = rhs.maxCol(peeled_k); 1056 1057 for (Index c = start_col; c < max_col; ++c) { 1058 eigen_assert(k <= peeled_k); 1059 1060 const Index start_row = (c == start_col) ? rhs.rowOffset() : 0; 1061 const Index max_row = rhs.maxRow(peeled_k, c); 1062 1063 const bool pad_col0 = dm0.padCol(c); 1064 const bool pad_col1 = dm1.padCol(c); 1065 const bool pad_col2 = dm2.padCol(c); 1066 const bool pad_col3 = dm3.padCol(c); 1067 1068 for (Index r = start_row; r < max_row; ++r) { 1069 eigen_assert(k <= peeled_k); 1070 1071 const Index start_plane = ((c == start_col) && (r == start_row)) 1072 ? 
rhs.planeOffset() 1073 : 0; 1074 const Index max_plane = rhs.maxPlane(peeled_k, c, r); 1075 1076 const bool pad_row0 = pad_col0 || dm0.padRow(r); 1077 const bool pad_row1 = pad_col1 || dm1.padRow(r); 1078 const bool pad_row2 = pad_col2 || dm2.padRow(r); 1079 const bool pad_row3 = pad_col3 || dm3.padRow(r); 1080 1081 for (Index p = start_plane; p < max_plane; ++p) { 1082 eigen_assert(k <= peeled_k); 1083 1084 const bool pad0 = pad_row0 || dm0.padPlane(p); 1085 const bool pad1 = pad_row1 || dm1.padPlane(p); 1086 const bool pad2 = pad_row2 || dm2.padPlane(p); 1087 const bool pad3 = pad_row3 || dm3.padPlane(p); 1088 1089 const Index idx0 = dm0.baseIndex(p, r, c); 1090 const Index idx1 = dm1.baseIndex(p, r, c); 1091 const Index idx2 = dm2.baseIndex(p, r, c); 1092 const Index idx3 = dm3.baseIndex(p, r, c); 1093 1094 const Index start_depth = 1095 ((c == start_col) && (r == start_row) && (p == start_plane)) 1096 ? rhs.depthOffset() 1097 : 0; 1098 const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth); 1099 eigen_assert((max_depth - start_depth) % packet_size == 0); 1100 1101 for (Index d = start_depth; d < max_depth; d += packet_size) { 1102 eigen_assert(k < peeled_k); 1103 PacketBlock<Packet, 4> kernel; 1104 kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0)) 1105 : rhs.packetNoPadding(d, idx0); 1106 kernel.packet[1] = pad1 ? pset1<Packet>(Scalar(0)) 1107 : rhs.packetNoPadding(d, idx1); 1108 kernel.packet[2] = pad2 ? pset1<Packet>(Scalar(0)) 1109 : rhs.packetNoPadding(d, idx2); 1110 kernel.packet[3] = pad3 ? pset1<Packet>(Scalar(0)) 1111 : rhs.packetNoPadding(d, idx3); 1112 ptranspose(kernel); 1113 pstoreu(block + 0 * packet_size, kernel.packet[0]); 1114 pstoreu(block + 1 * packet_size, kernel.packet[1]); 1115 pstoreu(block + 2 * packet_size, kernel.packet[2]); 1116 pstoreu(block + 3 * packet_size, kernel.packet[3]); 1117 block += 4 * packet_size; 1118 k += packet_size; 1119 } 1120 } 1121 } 1122 } 1123 1124 // The loop above should fill peeled_k elements. 1125 eigen_assert(peeled_k == k); 1126 1127 } else { 1128 // Packet can span multiple planes, rows or columns, so we have to go 1129 // though the slower "standard" path. 1130 for (; k < peeled_k; k += packet_size) { 1131 PacketBlock<Packet, 4> kernel; 1132 kernel.packet[0] = dm0.loadPacketStandard(k); 1133 kernel.packet[1] = dm1.loadPacketStandard(k); 1134 kernel.packet[2] = dm2.loadPacketStandard(k); 1135 kernel.packet[3] = dm3.loadPacketStandard(k); 1136 ptranspose(kernel); 1137 pstoreu(block + 0 * packet_size, kernel.packet[0]); 1138 pstoreu(block + 1 * packet_size, kernel.packet[1]); 1139 pstoreu(block + 2 * packet_size, kernel.packet[2]); 1140 pstoreu(block + 3 * packet_size, kernel.packet[3]); 1141 block += 4 * packet_size; 1142 } 1143 } 1144 } 1145 1146 // Copy the remaining coefficients of the column block after the peeled_k. 1147 if (!non_standard_patches) { 1148 for (; k < depth; k++) { 1149 block[0] = dm0.loadCoeffStandard(k); 1150 block[1] = dm1.loadCoeffStandard(k); 1151 block[2] = dm2.loadCoeffStandard(k); 1152 block[3] = dm3.loadCoeffStandard(k); 1153 block += 4; 1154 } 1155 } else { 1156 for (; k < depth; k++) { 1157 block[0] = dm0(k); 1158 block[1] = dm1(k); 1159 block[2] = dm2(k); 1160 block[3] = dm3(k); 1161 block += 4; 1162 } 1163 } 1164 } 1165 1166 // Copy the remaining columns one at a time (nr==1). 
1167 for (Index j2 = packet_cols4; j2 < cols; ++j2) { 1168 const SubMapper dm0 = rhs.getLinearMapper(0, j2); 1169 for (Index k = 0; k < depth; k++) { 1170 *block = dm0(k); 1171 block += 1; 1172 } 1173 } 1174 } 1175 }; 1176 1177 // Template specialization for packet_size = 2. We must special-case packet 1178 // blocks with nr > packet_size, e.g. PacketBlock<Packet2d, 4>. 1179 // 1180 // TODO(ezhulenev): Add support for squeezing reads along two innermost 1181 // dimensions (see eigen_spatial_convolutions). 1182 template <typename NewDimension, Index Planes, Index Rows, Index Cols, 1183 typename ArgType, typename Device, typename Scalar, typename Index, 1184 typename nocontract_t, typename contract_t, bool inner_dim_contiguous, 1185 bool inner_dim_reordered, int Alignment, int nr> 1186 struct gemm_pack_rhs< 1187 Scalar, Index, 1188 TensorContractionSubMapper< 1189 Scalar, Index, Rhs, 1190 TensorEvaluator<const TensorReshapingOp< 1191 NewDimension, const TensorVolumePatchOp< 1192 Planes, Rows, Cols, ArgType> >, 1193 Device>, 1194 nocontract_t, contract_t, /*packet_size*/ 2, inner_dim_contiguous, 1195 inner_dim_reordered, Alignment>, 1196 nr, ColMajor, false, false> { 1197 typedef TensorContractionSubMapper< 1198 Scalar, Index, Rhs, 1199 TensorEvaluator<const TensorReshapingOp< 1200 NewDimension, const TensorVolumePatchOp< 1201 Planes, Rows, Cols, ArgType> >, 1202 Device>, 1203 nocontract_t, contract_t, /*packet_size*/ 2, inner_dim_contiguous, 1204 inner_dim_reordered, Alignment> 1205 SubMapper; 1206 typedef SubMapper DataMapper; 1207 typedef typename packet_traits<Scalar>::type Packet; 1208 1209 EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE); 1210 1211 EIGEN_DEVICE_FUNC 1212 EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs, 1213 Index depth, Index cols, Index stride = 0, 1214 Index offset = 0) const { 1215 eigen_assert(stride == 0); 1216 eigen_assert(offset == 0); 1217 1218 const int packet_size = 2; 1219 1220 const Index packet_cols4 = (cols / 4) * 4; 1221 const Index peeled_k = (depth / packet_size) * packet_size; 1222 const bool non_standard_patches = rhs.nonStandardPatches(); 1223 1224 for (Index j2 = 0; j2 < packet_cols4; j2 += 4) { 1225 const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0); 1226 const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1); 1227 const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2); 1228 const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3); 1229 1230 Index k = 0; 1231 if (!non_standard_patches) { 1232 // FAST PATH: 1233 // Iterate over patch columns, rows and planes if we know that a single 1234 // packet do not span across multiple planes, rows or columns. 1235 if ((rhs.patchDepth() % packet_size) == 0) { 1236 const Index start_col = rhs.colOffset(); 1237 const Index max_col = rhs.maxCol(peeled_k); 1238 1239 for (Index c = start_col; c < max_col; ++c) { 1240 eigen_assert(k <= peeled_k); 1241 1242 const Index start_row = (c == start_col) ? rhs.rowOffset() : 0; 1243 const Index max_row = rhs.maxRow(peeled_k, c); 1244 1245 const bool pad_col0 = dm0.padCol(c); 1246 const bool pad_col1 = dm1.padCol(c); 1247 const bool pad_col2 = dm2.padCol(c); 1248 const bool pad_col3 = dm3.padCol(c); 1249 1250 for (Index r = start_row; r < max_row; ++r) { 1251 eigen_assert(k <= peeled_k); 1252 1253 const Index start_plane = ((c == start_col) && (r == start_row)) 1254 ? 
rhs.planeOffset() 1255 : 0; 1256 const Index max_plane = rhs.maxPlane(peeled_k, c, r); 1257 1258 const bool pad_row0 = dm0.padRow(r); 1259 const bool pad_row1 = dm1.padRow(r); 1260 const bool pad_row2 = dm2.padRow(r); 1261 const bool pad_row3 = dm3.padRow(r); 1262 1263 for (Index p = start_plane; p < max_plane; ++p) { 1264 eigen_assert(k <= peeled_k); 1265 1266 const bool pad0 = pad_col0 || pad_row0 || dm0.padPlane(p); 1267 const bool pad1 = pad_col1 || pad_row1 || dm1.padPlane(p); 1268 const bool pad2 = pad_col2 || pad_row2 || dm2.padPlane(p); 1269 const bool pad3 = pad_col3 || pad_row3 || dm3.padPlane(p); 1270 1271 const Index idx0 = dm0.baseIndex(p, r, c); 1272 const Index idx1 = dm1.baseIndex(p, r, c); 1273 const Index idx2 = dm2.baseIndex(p, r, c); 1274 const Index idx3 = dm3.baseIndex(p, r, c); 1275 1276 const Index start_depth = 1277 ((c == start_col) && (r == start_row) && (p == start_plane)) 1278 ? rhs.depthOffset() 1279 : 0; 1280 const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth); 1281 eigen_assert((max_depth - start_depth) % packet_size == 0); 1282 1283 for (Index d = start_depth; d < max_depth; d += packet_size) { 1284 eigen_assert(k < peeled_k); 1285 PacketBlock<Packet, 2> kernel0; 1286 PacketBlock<Packet, 2> kernel1; 1287 kernel0.packet[0] = pad0 ? pset1<Packet>(Scalar(0)) 1288 : rhs.packetNoPadding(d, idx0); 1289 kernel0.packet[1] = pad1 ? pset1<Packet>(Scalar(0)) 1290 : rhs.packetNoPadding(d, idx1); 1291 kernel1.packet[0] = pad2 ? pset1<Packet>(Scalar(0)) 1292 : rhs.packetNoPadding(d, idx2); 1293 kernel1.packet[1] = pad3 ? pset1<Packet>(Scalar(0)) 1294 : rhs.packetNoPadding(d, idx3); 1295 ptranspose(kernel0); 1296 ptranspose(kernel1); 1297 pstoreu(block + 0 * packet_size, kernel0.packet[0]); 1298 pstoreu(block + 1 * packet_size, kernel1.packet[0]); 1299 pstoreu(block + 2 * packet_size, kernel0.packet[1]); 1300 pstoreu(block + 3 * packet_size, kernel1.packet[1]); 1301 block += 4 * packet_size; 1302 k += packet_size; 1303 } 1304 } 1305 } 1306 } 1307 1308 // The loop above should fill peeled_k elements. 1309 eigen_assert(peeled_k == k); 1310 1311 } else { 1312 for (; k < peeled_k; k += packet_size) { 1313 PacketBlock<Packet, 2> kernel0; 1314 PacketBlock<Packet, 2> kernel1; 1315 kernel0.packet[0] = dm0.loadPacketStandard(k); 1316 kernel0.packet[1] = dm1.loadPacketStandard(k); 1317 kernel1.packet[0] = dm2.loadPacketStandard(k); 1318 kernel1.packet[1] = dm3.loadPacketStandard(k); 1319 ptranspose(kernel0); 1320 ptranspose(kernel1); 1321 pstoreu(block + 0 * packet_size, kernel0.packet[0]); 1322 pstoreu(block + 1 * packet_size, kernel1.packet[0]); 1323 pstoreu(block + 2 * packet_size, kernel0.packet[1]); 1324 pstoreu(block + 3 * packet_size, kernel1.packet[1]); 1325 block += 4 * packet_size; 1326 } 1327 } 1328 } 1329 1330 // Copy the remaining coefficients of the column block after the peeled_k. 1331 if (!rhs.nonStandardPatches()) { 1332 for (; k < depth; k++) { 1333 block[0] = dm0.loadCoeffStandard(k); 1334 block[1] = dm1.loadCoeffStandard(k); 1335 block[2] = dm2.loadCoeffStandard(k); 1336 block[3] = dm3.loadCoeffStandard(k); 1337 block += 4; 1338 } 1339 } else { 1340 for (; k < depth; k++) { 1341 block[0] = dm0(k); 1342 block[1] = dm1(k); 1343 block[2] = dm2(k); 1344 block[3] = dm3(k); 1345 block += 4; 1346 } 1347 } 1348 } 1349 1350 // Copy the remaining columns one at a time (nr==1). 
1351 for (Index j2 = packet_cols4; j2 < cols; ++j2) { 1352 const SubMapper dm0 = rhs.getLinearMapper(0, j2); 1353 for (Index k = 0; k < depth; k++) { 1354 *block = dm0(k); 1355 block += 1; 1356 } 1357 } 1358 } 1359 }; 1360 1361 // Special case for non-vectorized types such as float16 (packet_size = 1). 1362 template <typename NewDimension, Index Planes, Index Rows, Index Cols, 1363 typename ArgType, typename Device, typename Scalar, typename Index, 1364 typename nocontract_t, typename contract_t, bool inner_dim_contiguous, 1365 bool inner_dim_reordered, int Alignment, int nr> 1366 struct gemm_pack_rhs< 1367 Scalar, Index, 1368 TensorContractionSubMapper< 1369 Scalar, Index, Rhs, 1370 TensorEvaluator<const TensorReshapingOp< 1371 NewDimension, const TensorVolumePatchOp< 1372 Planes, Rows, Cols, ArgType> >, 1373 Device>, 1374 nocontract_t, contract_t, /*packet_size*/ 1, inner_dim_contiguous, 1375 inner_dim_reordered, Alignment>, 1376 nr, ColMajor, false, false> { 1377 typedef TensorContractionSubMapper< 1378 Scalar, Index, Rhs, 1379 TensorEvaluator<const TensorReshapingOp< 1380 NewDimension, const TensorVolumePatchOp< 1381 Planes, Rows, Cols, ArgType> >, 1382 Device>, 1383 nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, 1384 Alignment> 1385 SubMapper; 1386 typedef SubMapper DataMapper; 1387 1388 EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE); 1389 1390 EIGEN_DEVICE_FUNC 1391 EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs, 1392 Index depth, Index cols, Index stride = 0, 1393 Index offset = 0) const { 1394 eigen_assert(stride == 0); 1395 eigen_assert(offset == 0); 1396 1397 const Index packet_cols4 = (cols / 4) * 4; 1398 1399 for (Index j2 = 0; j2 < packet_cols4; j2 += 4) { 1400 const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0); 1401 const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1); 1402 const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2); 1403 const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3); 1404 1405 if (!rhs.nonStandardPatches()) { 1406 for (Index k = 0; k < depth; k++) { 1407 block[0] = dm0.loadCoeffStandard(k); 1408 block[1] = dm1.loadCoeffStandard(k); 1409 block[2] = dm2.loadCoeffStandard(k); 1410 block[3] = dm3.loadCoeffStandard(k); 1411 block += 4; 1412 } 1413 } else { 1414 for (Index k = 0; k < depth; k++) { 1415 block[0] = dm0(k); 1416 block[1] = dm1(k); 1417 block[2] = dm2(k); 1418 block[3] = dm3(k); 1419 block += 4; 1420 } 1421 } 1422 } 1423 1424 // Copy the remaining columns one at a time (nr==1). 1425 for (Index j2 = packet_cols4; j2 < cols; ++j2) { 1426 const SubMapper dm0 = rhs.getLinearMapper(0, j2); 1427 for (Index k = 0; k < depth; k++) { 1428 *block = dm0(k); 1429 block += 1; 1430 } 1431 } 1432 } 1433 }; 1434 1435 #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) 1436 // Pack a block of the right input matrix (in our case it's always a "virtual 1437 // matrix" constructed from extracted image patches) in contiguous block in 1438 // column-major storage order. Knowing the properties of the original patch op 1439 // we can do it more efficient than the default gemm_pack_colmajor_block. 1440 // 1441 // TODO(ezhulenev): gemm_pack_colmajor_block for spatial convolutions supports 1442 // squeezing reads along the 2 innermost dimensions, add it here if needed. 
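// For example (illustrative): unlike the gemm_pack_rhs specializations above,
// which interleave groups of nr = 4 columns, this packer simply writes each
// column of the "virtual matrix" out contiguously, one after another
// (A0 A1 A2 ... then B0 B1 B2 ...), vectorizing over the depth dimension of
// each patch whenever the patches are standard.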
template <typename NewDimension, Index Planes, Index Rows, Index Cols,
          typename ArgType, typename Device, typename Scalar,
          typename StorageIndex, typename nocontract_t, typename contract_t,
          int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered,
          int Alignment>
struct gemm_pack_colmajor_block<
    Scalar, StorageIndex,
    TensorContractionSubMapper<
        Scalar, StorageIndex, Rhs,
        TensorEvaluator<const TensorReshapingOp<
                            NewDimension, const TensorVolumePatchOp<
                                              Planes, Rows, Cols, ArgType> >,
                        Device>,
        nocontract_t, contract_t, packet_size, inner_dim_contiguous,
        inner_dim_reordered, Alignment>,
    ColMajor> {
  typedef TensorContractionSubMapper<
      Scalar, StorageIndex, Rhs,
      TensorEvaluator<const TensorReshapingOp<
                          NewDimension, const TensorVolumePatchOp<
                                            Planes, Rows, Cols, ArgType> >,
                      Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      SubMapper;

  typedef SubMapper DataMapper;
  typedef typename packet_traits<Scalar>::type Packet;

  EIGEN_DONT_INLINE
  void operator()(Scalar* block, const DataMapper& rhs, StorageIndex rows,
                  StorageIndex cols) {
    const bool standard_patches = !rhs.nonStandardPatches();

    if (standard_patches && rhs.patchDepth() % packet_size == 0) {
      packStandardPatches<true>(block, rhs, rows, cols);

    } else if (standard_patches) {
      packStandardPatches<false>(block, rhs, rows, cols);

    } else {
      // With non-standard patches we don't do any vectorized loads.
      // TODO(ezhulenev): It doesn't look like we should completely give up on
      // packets. Make this code path faster!
      for (StorageIndex col = 0; col < cols; ++col) {
        SubMapper lm = rhs.getLinearMapper(0, col);
        for (StorageIndex i = 0; i < rows; ++i) {
          *block = lm(i);
          ++block;
        }
      }
    }
  }

 private:
  // Pack standard volume patches:
  //
  // - patch_depth_is_multiple_of_packet_size=true: the depth dimension size is
  //   guaranteed to be a multiple of the packet size, so we can skip all
  //   non-vectorized loads and checks.
  //
  template <bool patch_depth_is_multiple_of_packet_size>
  EIGEN_ALWAYS_INLINE void packStandardPatches(Scalar* block,
                                               const DataMapper& rhs,
                                               StorageIndex rows,
                                               StorageIndex cols) {
    eigen_assert(!rhs.nonStandardPatches());

    // Give the number of rows we can process with vectorized loads the name
    // `peeled_k`, as used in all other gemm_pack_rhs specializations above.
    const Index peeled_k = (rows / packet_size) * packet_size;

    const Index start_col = rhs.colOffset();
    const Index max_col = rhs.maxCol(peeled_k);

    for (StorageIndex col = 0; col < cols; ++col) {
      SubMapper lm = rhs.getLinearMapper(0, col);

      Index k = 0;
      for (Index c = start_col; c < max_col; ++c) {
        eigen_assert(k <= peeled_k);

        const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
        const Index max_row = rhs.maxRow(peeled_k, c);
        const bool pad_col = lm.padCol(c);

        for (Index r = start_row; r < max_row; ++r) {
          eigen_assert(k <= peeled_k);

          const Index start_plane =
              ((c == start_col) && (r == start_row)) ?
                  rhs.planeOffset() : 0;
          const Index max_plane = rhs.maxPlane(peeled_k, c, r);
          const bool pad_row = pad_col || lm.padRow(r);

          for (Index p = start_plane; p < max_plane; ++p) {
            eigen_assert(k <= peeled_k);

            const Index start_depth =
                ((c == start_col) && (r == start_row) && (p == start_plane))
                    ? rhs.depthOffset()
                    : 0;
            const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);

            const bool pad = pad_col || pad_row || lm.padPlane(p);
            const Index base_idx = lm.baseIndex(p, r, c);

            if (patch_depth_is_multiple_of_packet_size)
              eigen_assert((max_depth - start_depth) % packet_size == 0);

            // If the patch depth is a multiple of the packet size, it is
            // guaranteed that we can process all values in the depth dimension
            // with packets.
            const Index max_vectorized_depth =
                patch_depth_is_multiple_of_packet_size
                    ? max_depth
                    : max_depth - packet_size;

            Index d = start_depth;

            // 1. Process the depth dimension with vectorized instructions.
            for (; d < max_vectorized_depth; d += packet_size) {
              eigen_assert(k < peeled_k);
              const Packet packet = pad ? pset1<Packet>(Scalar(0))
                                        : rhs.packetNoPadding(d, base_idx);
              internal::pstoreu(block, packet);
              block += packet_size;
              k += packet_size;
            }

            // 2. Finish with individual coefficients.
            if (!patch_depth_is_multiple_of_packet_size) {
              for (; d < max_depth; d++) {
                eigen_assert(k < peeled_k);
                *block = pad ? Scalar(0) : rhs.coeffNoPadding(d, base_idx);
                ++block;
                ++k;
              }
            }
          }
        }
      }

      // The loop above should fill peeled_k elements.
      eigen_assert(peeled_k == k);

      // Fill the remaining elements using loadCoeffStandard.
      for (; k < rows; ++k) {
        *block = lm.loadCoeffStandard(k);
        ++block;
      }
    }
  }
};
#endif  // defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)

}  // namespace internal

/** CuboidConvolution
 * \ingroup CXX11_NeuralNetworks_Module
 *
 * \brief Applies a 3D convolution over a multichannel input voxel block.
 *
 * The input parameter is expected to be a tensor with a rank of 4 or more
 * (channels, depth, height, width, and optionally others).
 * The kernel parameter is expected to be a 5D tensor (filters, channels,
 * kernel_depth, kernel_height, kernel_width).
 * The result can be assigned to a tensor of rank equal to the rank of the
 * input. The dimensions of the result will be filters, depth, height, width
 * (and others if applicable).
 *
 * The input and kernel have to be in the same layout, and both row-major and
 * col-major are supported. The shapes given above are for col-major layout.
 * For row-major, all dimensions should be reversed.
 *
 * It is possible to swap the order of the depth, width, and height dimensions
 * provided that the same order is used in the input, the kernel, and the
 * output.
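 *
 * A minimal usage sketch (illustrative only: the concrete sizes below are
 * made-up assumptions, and the Eigen Tensor headers are presumed to be
 * included):
 *
 *   // Col-major input of shape (channels, planes, rows, cols, batch).
 *   Eigen::Tensor<float, 5> input(4, 8, 16, 16, 2);
 *   // Kernel of shape (filters, channels, kernel_planes, kernel_rows,
 *   // kernel_cols).
 *   Eigen::Tensor<float, 5> kernel(7, 4, 3, 3, 3);
 *   input.setRandom();
 *   kernel.setRandom();
 *
 *   // Strides default to 1 and padding defaults to PADDING_SAME, so the
 *   // spatial output dimensions match the input and the result has shape
 *   // (7, 8, 16, 16, 2): filters, planes, rows, cols, batch.
 *   Eigen::Tensor<float, 5> output = Eigen::CuboidConvolution(input, kernel);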
 */
template <typename Input, typename Kernel>
EIGEN_ALWAYS_INLINE static const typename internal::conditional<
    internal::traits<Input>::Layout == ColMajor,
    TensorReshapingOp<
        const DSizes<typename internal::traits<Input>::Index,
                     internal::traits<Input>::NumDimensions>,
        const TensorContractionOp<
            const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<Input>::Index, 2>,
                const Kernel>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<Input>::Index, 2>,
                const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
                                          const Input> > > >,
    TensorReshapingOp<
        const DSizes<typename internal::traits<Input>::Index,
                     internal::traits<Input>::NumDimensions>,
        const TensorContractionOp<
            const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<Input>::Index, 2>,
                const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
                                          const Input> >,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<Input>::Index, 2>,
                const Kernel> > > >::type
CuboidConvolution(const Input& input, const Kernel& kernel,
                  const Index stridePlanes = 1, const Index strideRows = 1,
                  const Index strideCols = 1,
                  const PaddingType padding_type = PADDING_SAME) {
  typedef typename internal::traits<Input>::Index TensorIndex;
  TensorRef<Tensor<typename internal::traits<Input>::Scalar,
                   internal::traits<Input>::NumDimensions,
                   internal::traits<Input>::Layout, TensorIndex> >
      in(input);
  TensorRef<Tensor<typename internal::traits<Kernel>::Scalar,
                   internal::traits<Kernel>::NumDimensions,
                   internal::traits<Kernel>::Layout, TensorIndex> >
      kern(kernel);

  EIGEN_STATIC_ASSERT(
      internal::traits<Input>::Layout == internal::traits<Kernel>::Layout,
      YOU_MADE_A_PROGRAMMING_MISTAKE);
  static const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);
  static const int NumDims = internal::traits<Input>::NumDimensions;

  // Number of filters to apply. This is the same as the output depth of the
  // result.
  const TensorIndex kernelFilters =
      isColMajor ? kern.dimensions()[0] : kern.dimensions()[4];
  const TensorIndex kernelChannels =
      isColMajor ? kern.dimensions()[1] : kern.dimensions()[3];

  // Spatial size of the kernel. (For a 5D kernel the plane dimension sits at
  // index 2 in both layouts.)
  const TensorIndex kernelPlanes =
      isColMajor ? kern.dimensions()[2] : kern.dimensions()[2];
  const TensorIndex kernelRows =
      isColMajor ? kern.dimensions()[3] : kern.dimensions()[1];
  const TensorIndex kernelCols =
      isColMajor ? kern.dimensions()[4] : kern.dimensions()[0];

  if (isColMajor) {
    eigen_assert(kernelChannels == in.dimension(0));
  } else {
    eigen_assert(kernelChannels == in.dimension(NumDims - 1));
  }

  const TensorIndex inputPlanes =
      isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
  const TensorIndex inputRows =
      isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);
  const TensorIndex inputCols =
      isColMajor ?
          in.dimension(3) : in.dimension(NumDims - 4);

  TensorIndex out_planes;
  TensorIndex out_height;
  TensorIndex out_width;
  switch (padding_type) {
    case PADDING_VALID:
      out_planes = Eigen::divup(inputPlanes - kernelPlanes + 1,
                                static_cast<TensorIndex>(stridePlanes));
      out_height = Eigen::divup(inputRows - kernelRows + 1,
                                static_cast<TensorIndex>(strideRows));
      out_width = Eigen::divup(inputCols - kernelCols + 1,
                               static_cast<TensorIndex>(strideCols));
      break;
    case PADDING_SAME:
      out_planes =
          Eigen::divup(inputPlanes, static_cast<TensorIndex>(stridePlanes));
      out_height =
          Eigen::divup(inputRows, static_cast<TensorIndex>(strideRows));
      out_width = Eigen::divup(inputCols, static_cast<TensorIndex>(strideCols));
      break;
    default:
      out_planes = 0;
      out_height = 0;
      out_width = 0;
      eigen_assert(false && "unexpected padding");
  }

  DSizes<TensorIndex, 2> kernel_dims;
  if (isColMajor) {
    kernel_dims[0] = kernelFilters;
    kernel_dims[1] = kernelChannels * kernelPlanes * kernelRows * kernelCols;
  } else {
    kernel_dims[0] = kernelChannels * kernelPlanes * kernelRows * kernelCols;
    kernel_dims[1] = kernelFilters;
  }

  // Molds the output of the patch extraction into a 2D tensor:
  // - the first dimension (dims[0]): the patch values to be multiplied with
  //   the kernels
  // - the second dimension (dims[1]): everything else
  DSizes<TensorIndex, 2> pre_contract_dims;
  if (isColMajor) {
    pre_contract_dims[0] =
        kernelChannels * kernelPlanes * kernelRows * kernelCols;
    pre_contract_dims[1] = out_planes * out_height * out_width;
    for (int i = 4; i < NumDims; ++i) {
      pre_contract_dims[1] *= in.dimension(i);
    }
  } else {
    pre_contract_dims[1] =
        kernelChannels * kernelPlanes * kernelRows * kernelCols;
    pre_contract_dims[0] = out_planes * out_height * out_width;
    for (int i = 0; i < NumDims - 4; ++i) {
      pre_contract_dims[0] *= in.dimension(i);
    }
  }

  array<IndexPair<TensorIndex>, 1> contract_dims;
  contract_dims[0] = IndexPair<TensorIndex>(1, 0);

  // Molds the output of the contraction into the shape expected by the user
  // (assuming ColMajor):
  // - 1st dim: kernel filters
  // - 2nd dim: output depth (planes)
  // - 3rd dim: output height
  // - 4th dim: output width
  // - 5th dim and beyond: everything else, including the batch size
  DSizes<TensorIndex, NumDims> post_contract_dims;
  if (isColMajor) {
    post_contract_dims[0] = kernelFilters;
    post_contract_dims[1] = out_planes;
    post_contract_dims[2] = out_height;
    post_contract_dims[3] = out_width;
    for (int i = 4; i < NumDims; ++i) {
      post_contract_dims[i] = in.dimension(i);
    }
  } else {
    post_contract_dims[NumDims - 1] = kernelFilters;
    post_contract_dims[NumDims - 2] = out_planes;
    post_contract_dims[NumDims - 3] = out_height;
    post_contract_dims[NumDims - 4] = out_width;
    for (int i = 0; i < NumDims - 4; ++i) {
      post_contract_dims[i] = in.dimension(i);
    }
  }

  return choose(
      Cond<internal::traits<Input>::Layout == ColMajor>(),
      kernel.reshape(kernel_dims)
          .contract(input
                        .extract_volume_patches(
                            kernelPlanes, kernelRows, kernelCols, stridePlanes,
                            strideRows, strideCols, padding_type)
                        .reshape(pre_contract_dims),
                    contract_dims)
          .reshape(post_contract_dims),
      input
          .extract_volume_patches(kernelPlanes, kernelRows, kernelCols,
                                  stridePlanes, strideRows, strideCols,
                                  padding_type)
          .reshape(pre_contract_dims)
          .contract(kernel.reshape(kernel_dims), contract_dims)
          .reshape(post_contract_dims));
}

}  // end namespace Eigen

#endif  // TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_