/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_INL_H_
#define TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_INL_H_

// Note this header is used in both TF and TFLite.
namespace Eigen {

namespace internal {

// WARNING: Most of the code here implicitly assumes that the matrix is in
// ColMajor layout. This is guaranteed by the tensor contraction (see
// TensorContraction.h).
//
// Inside Eigen a tensor contraction is represented by a matrix multiplication.
// We don't want to actually extract image patches and reshape the result into
// a matrix (this involves allocating huge extra memory), so the patch
// extraction and reshape operations are implicit.
//
// TensorContractionInputMapper takes a matrix index and returns the coefficient
// (or the packet) of the "virtual tensor", that would be at that index if we
// were to actually reshape the result of patch extraction.
//
// TensorContractionSubMapper provides a similar view into the "virtual matrix"
// at the given vertical and horizontal offsets.
//
// "Virtual matrix" dimensions:
//   *0: kernelChannels * kernelRows * kernelCols;
//    1: out_height * out_width; * OTHERS (e.g batches, etc...)
//
// *) extracted patches are continuous in memory (innermost dimension assuming
//    col major layout)
//
// With this dimensions:
//   row - offset within a single patch (in code: patchId)
//   col - index of the extracted patch (in code: patchIndex)
//         patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions)
//
// TODO(ezhulenev): Consolidate this part of the code with the image patch
// extraction code since they are both very similar.

template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
          typename Device, typename Scalar_, typename Index,
          typename nocontract_t, typename contract_t, int Side, int packet_size,
          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionInputMapper<
    Scalar_, Index, Side,
    TensorEvaluator<
        const TensorReshapingOp<NewDimension,
                                const TensorImagePatchOp<Rows, Cols, ArgType> >,
        Device>,
    nocontract_t, contract_t, packet_size, inner_dim_contiguous,
    inner_dim_reordered, Alignment> {
 public:
  typedef Scalar_ Scalar;

  typedef TensorContractionInputMapper<
      Scalar, Index, Side,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      Self;

  typedef TensorContractionSubMapper<
      Scalar, Index, Side,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      SubMapper;

  typedef SubMapper VectorMapper;
  typedef SubMapper LinearMapper;
  typedef typename packet_traits<Scalar>::type Packet;

  // Caches patch geometry (depth/rows/cols, strides, padding) from the
  // image-patch evaluator so that coefficient/packet loads can be computed
  // with plain index arithmetic. The nocontract/contract stride arguments
  // required by the generic mapper interface are unused here.
  EIGEN_DEVICE_FUNC
  TensorContractionInputMapper(
      const TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>& tensor,
      const nocontract_t&, const nocontract_t&, const contract_t&,
      const contract_t&)
      : m_impl(tensor.impl().impl()) {
    Index patch_rows;
    Index patch_depth;
    if (internal::traits<ArgType>::Layout == ColMajor) {
      patch_depth = tensor.impl().dimensions()[0];
      patch_rows = tensor.impl().dimensions()[1];
      m_patch_cols = tensor.impl().dimensions()[2];
      m_num_patches = tensor.impl().dimensions()[3];
    } else {
      // RowMajor: same logical dimensions, reversed order.
      const size_t NumDims = tensor.impl().dimensions().size();
      patch_depth = tensor.impl().dimensions()[NumDims - 1];
      patch_rows = tensor.impl().dimensions()[NumDims - 2];
      m_patch_cols = tensor.impl().dimensions()[NumDims - 3];
      m_num_patches = tensor.impl().dimensions()[NumDims - 4];
    }

    // Strides for navigating through the single patch.
    m_patch_row_stride = patch_depth;
    m_patch_col_stride = patch_rows * m_patch_row_stride;

    m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
    m_patch_col_inflate_strides = tensor.impl().colInflateStride();

    m_colStride = patch_rows;

    m_outputRows = tensor.impl().outputRows();
    m_row_strides = tensor.impl().userRowStride();
    m_col_strides = tensor.impl().userColStride();

    m_in_row_strides = tensor.impl().userInRowStride();
    m_in_col_strides = tensor.impl().userInColStride();

    // Dimensions of the underlying input tensor (one level below the patch
    // evaluator).
    if (internal::traits<ArgType>::Layout == ColMajor) {
      m_inputRows = tensor.impl().impl().dimensions()[1];
      m_inputCols = tensor.impl().impl().dimensions()[2];
    } else {
      const int NumDims = tensor.impl().impl().dimensions().size();
      m_inputRows = tensor.impl().impl().dimensions()[NumDims - 2];
      m_inputCols = tensor.impl().impl().dimensions()[NumDims - 3];
    }

    m_rowInputStride = patch_depth;
    m_colInputStride = patch_depth * m_inputRows;
    m_patchInputStride = patch_depth * m_inputRows * m_inputCols;

    m_rowPaddingTop = tensor.impl().rowPaddingTop();
    m_colPaddingLeft = tensor.impl().colPaddingLeft();

    // Precompute fast integer divisors for every stride we divide by in the
    // load paths below.
    m_fastPatchRowStride =
        internal::TensorIntDivisor<Index>(m_patch_row_stride);
    m_fastPatchColStride =
        internal::TensorIntDivisor<Index>(m_patch_col_stride);
    m_fastInputRowStride =
        internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
    m_fastInputColStride =
        internal::TensorIntDivisor<Index>(m_patch_col_inflate_strides);
    m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches);
    m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);
    m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
    m_fastDimZero = internal::TensorIntDivisor<Index>(patch_depth);
  }

  // Copies all cached geometry from base_mapper; shares the same underlying
  // input evaluator.
  EIGEN_DEVICE_FUNC
  TensorContractionInputMapper(const TensorContractionInputMapper& base_mapper)
      : m_impl(base_mapper.m_impl) {
    m_patch_cols = base_mapper.m_patch_cols;
    m_num_patches = base_mapper.m_num_patches;

    m_patch_row_stride = base_mapper.m_patch_row_stride;
    m_patch_col_stride = base_mapper.m_patch_col_stride;

    m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
    m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;

    m_colStride = base_mapper.m_colStride;

    m_rowInputStride = base_mapper.m_rowInputStride;
    m_colInputStride = base_mapper.m_colInputStride;
    m_patchInputStride = base_mapper.m_patchInputStride;

    m_inputRows = base_mapper.m_inputRows;
    m_inputCols = base_mapper.m_inputCols;

    m_outputRows = base_mapper.m_outputRows;
    m_row_strides = base_mapper.m_row_strides;
    m_col_strides = base_mapper.m_col_strides;

    m_in_row_strides = base_mapper.m_in_row_strides;
    m_in_col_strides = base_mapper.m_in_col_strides;

    m_rowPaddingTop = base_mapper.m_rowPaddingTop;
    m_colPaddingLeft = base_mapper.m_colPaddingLeft;

    m_fastPatchRowStride = base_mapper.m_fastPatchRowStride;
    m_fastPatchColStride = base_mapper.m_fastPatchColStride;
    m_fastInputRowStride = base_mapper.m_fastInputRowStride;
    m_fastInputColStride = base_mapper.m_fastInputColStride;
    m_fastNumPatches = base_mapper.m_fastNumPatches;
    m_fastColStride = base_mapper.m_fastColStride;
    m_fastOutputRows = base_mapper.m_fastOutputRows;
    m_fastDimZero = base_mapper.m_fastDimZero;
  }

  // If true, turns off some optimizations for loading packets since the image
  // patches are "non-standard" such as there are non-trivial strides or
  // inflations in the input.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
    return m_in_row_strides != 1 || m_in_col_strides != 1 ||
           m_patch_row_inflate_strides != 1 || m_patch_col_inflate_strides != 1;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
    return SubMapper(*this, i, j);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
    return LinearMapper(*this, i, j);
  }

  // Coefficient of the virtual matrix at (row, 0).
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const {
    Index rowIndex, colIndex, otherIndex;
    computeBaseIndices(0, rowIndex, colIndex, otherIndex);
    return loadCoeff(row, rowIndex, colIndex, otherIndex);
  }

  // Load the coefficient at the patchIndex location instead of the usual
  // m_rowIndex,
  // m_colIndex, m_otherIndex. This is currently only used by the gpu code.
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const {
    Index rowIndex, colIndex, otherIndex;
    computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex);
    return loadCoeff(row, rowIndex, colIndex, otherIndex);
  }

  // Packet of the virtual matrix starting at (row, 0).
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const {
    Index rowIndex, colIndex, otherIndex;
    computeBaseIndices(0, rowIndex, colIndex, otherIndex);
    return loadPacket(row, rowIndex, colIndex, otherIndex);
  }

  // Load the packet at the patchIndex location instead of the usual m_rowIndex,
  // m_colIndex, m_otherIndex. This is currently only used by the gpu code.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const {
    Index rowIndex, colIndex, otherIndex;
    computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex);
    return loadPacket(row, rowIndex, colIndex, otherIndex);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE const TensorEvaluator<ArgType, Device>& impl() const {
    return m_impl;
  }

  // Note: patch depth doubles as the input row stride, and patch rows as the
  // per-patch column stride (see constructor).
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_rowInputStride; }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchRows() const { return m_colStride; }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; }

 private:
  friend class TensorContractionSubMapper<
      Scalar, Index, Side,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>;

  // Load coefficient from a patch specified by the "within patch offset"
  // (patchId) and the precomputed indices of the first element of the patch.
  // Returns Scalar(0) for locations that fall into padding or between
  // inflation gaps.
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index rowIndex,
                                       Index colIndex, Index otherIndex) const {
    // Find the offset of the element wrt the location of the first element.
    const Index patchOffset = patchId / m_fastDimZero;

    const Index colOffset = patchOffset / m_fastColStride;
    const Index inputCol = colIndex + colOffset * m_in_col_strides;
    const Index origInputCol =
        (m_patch_col_inflate_strides == 1)
            ? inputCol
            : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);

    const Index rowOffset = patchOffset - colOffset * m_colStride;
    const Index inputRow = rowIndex + rowOffset * m_in_row_strides;
    const Index origInputRow =
        (m_patch_row_inflate_strides == 1)
            ? inputRow
            : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
    // Out of bounds, or an inflated ("hole") position that has no backing
    // input element.
    if (origInputCol < 0 || origInputRow < 0 || origInputCol >= m_inputCols ||
        origInputRow >= m_inputRows ||
        (inputCol != origInputCol * m_patch_col_inflate_strides) ||
        (inputRow != origInputRow * m_patch_row_inflate_strides)) {
      return Scalar(0);
    }
    const Index depth = patchId - patchOffset * patchDepth();
    const Index inputIndex = depth + origInputRow * m_rowInputStride +
                             origInputCol * m_colInputStride + otherIndex;
    return m_impl.coeff(inputIndex);
  }

  // This is the same as loadCoeff(...), but optimized for all `inflate_strides`
  // and `in_strides` equal to 1 (template specialization without templates).
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index rowIndex,
                                               Index colIndex,
                                               Index otherIndex) const {
    eigen_assert(!nonStandardPatches());

    // Find the offset of the element wrt the location of the first element.
    const Index patchOffset = patchId / m_fastDimZero;
    const Index colOffset = patchOffset / m_fastColStride;
    const Index rowOffset = patchOffset - colOffset * m_colStride;
    const Index inputCol = colIndex + colOffset;
    const Index inputRow = rowIndex + rowOffset;
    if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 ||
        inputRow >= m_inputRows) {
      return Scalar(0);
    }
    const Index depth = patchId - patchOffset * patchDepth();
    const Index inputIndex = depth + inputRow * m_rowInputStride +
                             inputCol * m_colInputStride + otherIndex;
    return m_impl.coeff(inputIndex);
  }

  // Load packet from a patch specified by the "within patch offset"
  // (patchId) and the precomputed indices of the first element of the patch.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index rowIndex,
                                        Index colIndex,
                                        Index otherIndex) const {
    const Index packetSize = internal::unpacket_traits<Packet>::size;
    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);

    if (nonStandardPatches()) {
      // Non-trivial strides/inflation: fall back to element-wise loads.
      return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
    }
    return loadPacketStandard(patchId, rowIndex, colIndex, otherIndex);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacketStandard(Index patchId, Index rowIndex,
                                                Index colIndex,
                                                Index otherIndex) const {
    const Index packetSize = internal::unpacket_traits<Packet>::size;
    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);

    eigen_assert(!nonStandardPatches());

    if ((patchDepth() % packetSize) == 0) {
      // A packet can never straddle a patch row: take the fastest path.
      return loadPacketFast(patchId, rowIndex, colIndex, otherIndex);
    } else {
      // Offsets and input calculation here are identical to
      // loadCoeffStandard(...), but repeated twice.

      // [0] - first element of the packet, [1] - last element.
      const Index patchOffsets[2] = {
          patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero};

      const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
                                   patchOffsets[1] / m_fastColStride};
      const Index inputCols[2] = {colIndex + colOffsets[0],
                                  colIndex + colOffsets[1]};
      if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
        // all zeros
        return internal::pset1<Packet>(Scalar(0));
      }

      if (inputCols[0] == inputCols[1]) {
        // Whole packet lies in a single patch column.
        const Index rowOffsets[2] = {
            patchOffsets[0] - colOffsets[0] * m_colStride,
            patchOffsets[1] - colOffsets[1] * m_colStride};
        eigen_assert(rowOffsets[0] <= rowOffsets[1]);
        const Index inputRows[2] = {rowIndex + rowOffsets[0],
                                    rowIndex + rowOffsets[1]};

        if (inputRows[0] >= m_inputRows || inputRows[1] < 0) {
          // all zeros
          return internal::pset1<Packet>(Scalar(0));
        }

        if (inputRows[0] >= 0 && inputRows[1] < m_inputRows) {
          // no padding
          const Index depth = patchId - patchOffsets[0] * patchDepth();
          const Index inputIndex = depth + inputRows[0] * m_rowInputStride +
                                   inputCols[0] * m_colInputStride + otherIndex;
          return m_impl.template packet<Unaligned>(inputIndex);
        }
      }
    }
    // Mixed padded/valid elements: load one coefficient at a time.
    return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
  }

  // Fastest packet path: requires patchDepth() % packetSize == 0, so every
  // packet falls entirely within a single (row, col) position of the patch.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index rowIndex,
                                            Index colIndex,
                                            Index otherIndex) const {
    const Index packetSize = internal::unpacket_traits<Packet>::size;
    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);

    eigen_assert(!nonStandardPatches());
    eigen_assert((patchDepth() % packetSize) == 0);
    // Find the offset of the element wrt the location of the first element.
    const Index patchOffset = patchId / m_fastDimZero;
    eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);

    const Index colOffset = patchOffset / m_fastColStride;
    const Index rowOffset = patchOffset - colOffset * m_colStride;
    const Index inputCol = colIndex + colOffset;
    const Index inputRow = rowIndex + rowOffset;
    if (inputCol < 0 || inputRow < 0 || inputCol >= m_inputCols ||
        inputRow >= m_inputRows) {
      // all zeros
      return internal::pset1<Packet>(Scalar(0));
    }
    // no padding
    const Index depth = patchId - patchOffset * patchDepth();
    const Index inputIndex = depth + inputRow * m_rowInputStride +
                             inputCol * m_colInputStride + otherIndex;
    return m_impl.template packet<Unaligned>(inputIndex);
  }

  // Assembles a packet coefficient-by-coefficient; each element may resolve
  // to zero padding independently.
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet packetWithPossibleZero(
      Index patchId, Index rowIndex, Index colIndex, Index otherIndex) const {
    const int packetSize = internal::unpacket_traits<Packet>::size;
    EIGEN_ALIGN_MAX
    typename internal::remove_const<Scalar>::type values[packetSize];
    for (int i = 0; i < packetSize; ++i) {
      values[i] = loadCoeff(patchId + i, rowIndex, colIndex, otherIndex);
    }
    Packet rslt = internal::pload<Packet>(values);
    return rslt;
  }

  // Decomposes a column of the virtual matrix (patchIndex) into the input-row
  // index, input-col index and batch/other offset of the patch's first
  // element, using the precomputed fast divisors.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices(
      Index patchIndex, Index& rowIndex, Index& colIndex,
      Index& otherIndex) const {
    const size_t NumInputDims = array_size<
        typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
    otherIndex = (NumInputDims == 3) ? 0 : patchIndex / m_fastNumPatches;
    const Index patch2DIndex = (NumInputDims == 3)
                                   ? patchIndex
                                   : (patchIndex - otherIndex * m_num_patches);
    otherIndex *= m_patchInputStride;
    colIndex = patch2DIndex / m_fastOutputRows;
    rowIndex = patch2DIndex - colIndex * m_outputRows;
    colIndex = colIndex * m_col_strides - m_colPaddingLeft;
    rowIndex = rowIndex * m_row_strides - m_rowPaddingTop;
  }

  Index m_patch_cols;   // number of columns in the patch
  Index m_num_patches;  // number of patches to extract.

  // Strides for navigating through the single patch.
  Index m_patch_row_stride;
  Index m_patch_col_stride;
  internal::TensorIntDivisor<Index> m_fastPatchRowStride;
  internal::TensorIntDivisor<Index> m_fastPatchColStride;

  Index m_patch_row_inflate_strides;  // the strides for row inflation in the
                                      // image patch
  Index m_patch_col_inflate_strides;  // the strides for col inflation in the
                                      // image patch
  // Fast representation of inflation strides.
  internal::TensorIntDivisor<Index> m_fastInputRowStride;
  internal::TensorIntDivisor<Index> m_fastInputColStride;

  // NOTE(review): m_otherStride is never assigned or read in this class —
  // presumably kept for layout parity with other mappers; confirm before
  // removing.
  Index m_otherStride;
  Index m_colStride;
  internal::TensorIntDivisor<Index> m_fastNumPatches;
  internal::TensorIntDivisor<Index> m_fastColStride;

  Index m_rowInputStride;    // row stride in the input tensor
  Index m_colInputStride;    // col stride in the input tensor
  Index m_patchInputStride;  // patch stride in the input tensor

  Index m_inputRows;  // Number of rows in the input tensor
  Index m_inputCols;  // Number of cols in the input tensor

  Index m_outputRows;  // Number of patch rows

  Index m_row_strides;  // User specified row stride
  Index m_col_strides;  // User specified col stride

  Index m_in_row_strides;  // User specified input row stride
  Index m_in_col_strides;  // User specified input col stride

  Index m_rowPaddingTop;   // Row padding
  Index m_colPaddingLeft;  // Column padding

  internal::TensorIntDivisor<Index> m_fastOutputRows;
  internal::TensorIntDivisor<Index> m_fastDimZero;

  const TensorEvaluator<ArgType, Device> m_impl;
};

template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
          typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, int Side, int packet_size,
          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionSubMapper<
    Scalar, Index, Side,
    TensorEvaluator<
        const TensorReshapingOp<NewDimension,
                                const TensorImagePatchOp<Rows, Cols, ArgType> >,
        Device>,
    nocontract_t, contract_t, packet_size, inner_dim_contiguous,
    inner_dim_reordered, Alignment> {
 public:
  typedef typename packet_traits<Scalar>::type Packet;
  typedef typename packet_traits<Scalar>::half HalfPacket;

  typedef TensorContractionInputMapper<
      Scalar, Index, Side,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      ParentMapper;

  typedef TensorContractionSubMapper<
      Scalar, Index, Side,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      Self;

  typedef Self LinearMapper;

  // Offsets this view by (vert_offset, horiz_offset) into the parent's
  // virtual matrix and precomputes the base indices of the patch at the
  // column offset.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
      const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
      : m_depth_offset(vert_offset),
        m_col_offset(horiz_offset),
        m_base_mapper(base_mapper) {
    m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex,
                                     m_otherIndex);
  }
  // Offsets relative to an existing sub-mapper: offsets accumulate, and base
  // indices are recomputed for the combined column offset.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
      const Self& base_mapper, Index vert_offset, Index horiz_offset)
      : m_depth_offset(vert_offset + base_mapper.m_depth_offset),
        m_col_offset(horiz_offset + base_mapper.m_col_offset),
        m_base_mapper(base_mapper.m_base_mapper) {
    m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex,
                                     m_otherIndex);
  }
  // Coefficient at row i of this sub-view (column fixed by m_col_offset).
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
    return m_base_mapper.loadCoeff(i + m_depth_offset, m_rowIndex, m_colIndex,
                                   m_otherIndex);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i,
                                                          Index j) const {
    return m_base_mapper(i + m_depth_offset, j + m_col_offset);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
    return m_base_mapper.loadPacket(i + m_depth_offset, m_rowIndex, m_colIndex,
                                    m_otherIndex);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i,
                                                          Index j) const {
    return m_base_mapper.template loadPacket<Alignment>(i + m_depth_offset,
                                                        j + m_col_offset);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar
  loadCoeffStandard(Index i) const {
    return m_base_mapper.loadCoeffStandard(i + m_depth_offset, m_rowIndex,
                                           m_colIndex, m_otherIndex);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const {
    return m_base_mapper.loadPacketFast(i + m_depth_offset, m_rowIndex,
                                        m_colIndex, m_otherIndex);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet
  loadPacketStandard(Index i) const {
    return m_base_mapper.loadPacketStandard(i + m_depth_offset, m_rowIndex,
                                            m_colIndex, m_otherIndex);
  }
  // Alignment is never assumed for this virtual matrix.
  template <typename Packet>
  EIGEN_DEVICE_FUNC bool aligned(Index) const {
    return false;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
    return m_base_mapper.nonStandardPatches();
  }

  // Max(Col|Row|Depth): compute the upper limit for the column, row and depth
  // index respectively that fits into the peeled_k elements starting at
  // m_depth_offset.

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const {
    const Index max_col =
        (m_depth_offset + (peeled_k == 0 ? 0 : peeled_k - 1)) /
        fastPatchColStride();
    return std::min<Index>(1 + max_col, patchCols());
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k,
                                   const Index col) const {
    const Index max_row = (m_depth_offset + (peeled_k == 0 ? 0 : peeled_k - 1) -
                           col * patchColStride()) /
                          fastPatchRowStride();
    return std::min<Index>(1 + max_row, patchRows());
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index maxDepth(const Index peeled_k, const Index col,
                                     Index row) const {
    const Index max_depth = m_depth_offset + peeled_k -  //
                            col * patchColStride() -     //
                            row * patchRowStride();
    return std::min<Index>(max_depth, patchDepth());
  }

  // MaxDepth uses only the remaining number of elements in the peeled_k.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements,
                                     const Index start_depth) const {
    return std::min<Index>(start_depth + num_elements, patchDepth());
  }

  // Every register matters in this code, so sometimes to prevent register
  // spilling, instead of the variable that you would expect to see, we use
  // another one, that is guaranteed to have the same value. E.g. patch depth is
  // always the same as input depth, and it's also the same as input row stride.
  // Bunch of other parameters have similar relations.

  typedef internal::TensorIntDivisor<Index> IndexDivisor;

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchDepth() const {
    return m_base_mapper.m_rowInputStride;
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchRows() const {
    return m_base_mapper.m_colStride;
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchCols() const {
    return m_base_mapper.m_patch_cols;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchRowStride() const {
    eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
                 "Patch depth must be equal to patch row stride.");
    return patchDepth();
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchColStride() const {
    return m_base_mapper.m_patch_col_stride;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const {
    eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
                 "Patch depth must be equal to patch row stride.");
    return m_base_mapper.m_fastDimZero;  // patch_depth
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const {
    return m_base_mapper.m_fastPatchColStride;
  }

  // Direct (no padding check) packet/coeff reads from the underlying input;
  // callers must have established via padRow/padCol that the location is
  // inside the input.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
                                             const Index baseIndex) const {
    const Index inputIndex = depth + baseIndex;
    return m_base_mapper.m_impl.template packet<Unaligned>(inputIndex);
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Scalar coeffNoPadding(const Index depth,
                                            const Index baseIndex) const {
    const Index inputIndex = depth + baseIndex;
    return m_base_mapper.m_impl.coeff(inputIndex);
  }

  // True iff the given patch-relative row/col falls into zero padding.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool padRow(const Index row) const {
    const Index r = m_rowIndex + row;
    return r < 0 || r >= m_base_mapper.m_inputRows;
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool padAnyRow(const Index first_row,
                                     const Index last_row) const {
    return m_rowIndex + first_row < 0 ||
           m_rowIndex + last_row >= m_base_mapper.m_inputRows;
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool padCol(const Index col) const {
    const Index c = m_colIndex + col;
    return c < 0 || c >= m_base_mapper.m_inputCols;
  }
  // Flat input-tensor index of (row, col) at depth 0 for this patch.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index baseIndex(const Index row, const Index col) const {
    const Index r = m_rowIndex + row;
    const Index c = m_colIndex + col;
    return r * m_base_mapper.m_rowInputStride +
           c * m_base_mapper.m_colInputStride + m_otherIndex;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index rowStride() const {
    return m_base_mapper.m_row_strides;
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index colStride() const {
    return m_base_mapper.m_col_strides;
  }

  // Decomposition of m_depth_offset into (row, col, depth) within a patch.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index rowOffset() const {
    const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
    const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
    return patchOffset - colOffset * m_base_mapper.m_colStride;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index colOffset() const {
    const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
    const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
    return colOffset;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index depthOffset() const {
    return m_depth_offset % patchDepth();
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper
  getLinearMapper(Index i, Index j) const {
    return LinearMapper(m_base_mapper, i + m_depth_offset, j + m_col_offset);
  }

 private:
  Index m_depth_offset;  // First row in the input matrix
  Index m_col_offset;    // First col in the input matrix

  // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base
  // indices for the first element in a patch specified by col_offset
  // (see computeBaseIndices(...) for details).
  Index m_rowIndex;
  Index m_colIndex;
  Index m_otherIndex;

  const ParentMapper m_base_mapper;  // Keeping a copy instead of a reference
                                     // performs better in benchmarks.
};

// Arrange a block of the right input matrix (in our case it's always a "virtual
// matrix" constructed from extracted image patches) in contiguous memory.
//
// Given column major input (A0 beside A1 in memory):
// A0 B0 C0 D0 E0 F0 G0 H0 ... Z0
// A1 B1 C1 D1 E1 F1 G1 H1 ... Z1
// A2 B2 C2 D2 E2 F2 G2 H2 ... Z2
// A3 B3 C3 D3 E3 F3 G3 H3 ... Z3
// A4 B4 C4 D4 E4 F4 G4 H4 ... Z4
// A5 B5 C5 D5 E5 F5 G5 H5 ... Z5
// A6 B6 C6 D6 E6 F6 G6 H6 ... Z6
// A7 B7 C7 D7 E7 F7 G7 H7 ... Z7
// A8 ...
// ...
//
// *) A, B, C, ... - patches extracted from the original input.
// *) A0, A1, A2 ... - values from the same patch at different offsets.
//
// The traversal (packed rhs memory) order (B0 besides A0 in memory):
// A0 B0 C0 D0 A1 B1 C1 D1 ...
// E0 F0 G0 H0 E1 F1 G1 H1 ...
// ...
// Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4)
//
// This traversal order must be the same as in default gemm_pack_rhs defined in
// GeneralBlockPanelKernel.h.
//
// *) nr - number of registers along the 'n' dimension.
// See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix
// Multiplication" paper.
template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
          typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, int packet_size,
          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
          int nr>
struct gemm_pack_rhs<
    Scalar, Index,
    TensorContractionSubMapper<
        Scalar, Index, Rhs,
        TensorEvaluator<
            const TensorReshapingOp<
                NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
            Device>,
        nocontract_t, contract_t, packet_size, inner_dim_contiguous,
        inner_dim_reordered, Alignment>,
    nr, ColMajor, false, false> {
  typedef TensorContractionSubMapper<
      Scalar, Index, Rhs,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      SubMapper;
  typedef SubMapper DataMapper;
  typedef typename packet_traits<Scalar>::type Packet;

  EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE)

  EIGEN_DEVICE_FUNC
  EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
                                    Index depth, Index cols, Index stride = 0,
                                    Index offset = 0) const {
    // Non-zero stride/offset packing modes are not supported by this mapper.
    eigen_assert(stride == 0);
    eigen_assert(offset == 0);

    // Columns are packed in groups of 4 (nr == 4); `peeled_k` is the prefix of
    // the depth dimension that can be processed with whole packets.
    const Index packet_cols4 = (cols / 4) * 4;
    const Index peeled_k = (depth / packet_size) * packet_size;
    const bool non_standard_patches = rhs.nonStandardPatches();

    for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
      // Sub-mappers for the four patches (columns) packed together.
      const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
      const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
      const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
      const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);

      Index k = 0;
      if ((packet_size % 4) == 0 && !non_standard_patches) {
        // FAST PATH:
        // Iterate over patch columns and rows, if we know that a single
        // packet does not span across multiple rows or columns.
        if ((rhs.patchDepth() % packet_size) == 0) {
          const Index start_col = rhs.colOffset();
          const Index max_col = rhs.maxCol(peeled_k);

          for (Index c = start_col; c < max_col; ++c) {
            eigen_assert(k <= peeled_k);

            const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
            const Index max_row = rhs.maxRow(peeled_k, c);

            const bool pad_col0 = dm0.padCol(c);
            const bool pad_col1 = dm1.padCol(c);
            const bool pad_col2 = dm2.padCol(c);
            const bool pad_col3 = dm3.padCol(c);

            // Check if we can squeeze reads along the `row` and `depth`
            // dimensions (two innermost dimensions). Checking the interval
            // endpoints is sufficient because padding is contiguous.
            if (!pad_col0 && !pad_col1 && !pad_col2 && !pad_col3 &&      //
                !dm0.padRow(start_row) && !dm0.padRow(max_row - 1) &&    //
                !dm1.padRow(start_row) && !dm1.padRow(max_row - 1) &&    //
                !dm2.padRow(start_row) && !dm2.padRow(max_row - 1) &&    //
                !dm3.padRow(start_row) && !dm3.padRow(max_row - 1)) {
              // Compute how many elements we can squeeze read.
              const Index start_depth =
                  (c == start_col) ? rhs.depthOffset() : 0;

              // Upper bound for the number of elements in the depth dimension
              // that we can squeeze read.
              const Index squeeze_length =
                  (max_row - start_row) * rhs.patchDepth() - start_depth;

              // Do not overshoot beyond the block size.
              const Index max_depth =
                  start_depth + std::min<Index>(peeled_k - k, squeeze_length);
              eigen_assert((max_depth - start_depth) % packet_size == 0);

              const Index idx0 = dm0.baseIndex(start_row, c);
              const Index idx1 = dm1.baseIndex(start_row, c);
              const Index idx2 = dm2.baseIndex(start_row, c);
              const Index idx3 = dm3.baseIndex(start_row, c);

              for (Index d = start_depth; d < max_depth; d += packet_size) {
                eigen_assert(k < peeled_k);
                // Transpose a 4 x packet_size tile so the packed block is in
                // the traversal order expected by the gemm kernel.
                PacketBlock<Packet, 4> kernel;
                kernel.packet[0] = rhs.packetNoPadding(d, idx0);
                kernel.packet[1] = rhs.packetNoPadding(d, idx1);
                kernel.packet[2] = rhs.packetNoPadding(d, idx2);
                kernel.packet[3] = rhs.packetNoPadding(d, idx3);
                ptranspose(kernel);
                pstoreu(block + 0 * packet_size, kernel.packet[0]);
                pstoreu(block + 1 * packet_size, kernel.packet[1]);
                pstoreu(block + 2 * packet_size, kernel.packet[2]);
                pstoreu(block + 3 * packet_size, kernel.packet[3]);
                block += 4 * packet_size;
                k += packet_size;
              }

              // Go to the next column.
              continue;
            }

            // If we can't squeeze reads, process rows one by one.
            for (Index r = start_row; r < max_row; ++r) {
              eigen_assert(k <= peeled_k);

              const bool pad0 = pad_col0 || dm0.padRow(r);
              const bool pad1 = pad_col1 || dm1.padRow(r);
              const bool pad2 = pad_col2 || dm2.padRow(r);
              const bool pad3 = pad_col3 || dm3.padRow(r);

              const Index idx0 = dm0.baseIndex(r, c);
              const Index idx1 = dm1.baseIndex(r, c);
              const Index idx2 = dm2.baseIndex(r, c);
              const Index idx3 = dm3.baseIndex(r, c);

              const Index start_depth = ((c == start_col) && (r == start_row))
                                            ? rhs.depthOffset()
                                            : 0;
              const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
              eigen_assert((max_depth - start_depth) % packet_size == 0);

              for (Index d = start_depth; d < max_depth; d += packet_size) {
                eigen_assert(k < peeled_k);
                // Padded positions contribute zeros to the packed block.
                PacketBlock<Packet, 4> kernel;
                kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
                                        : rhs.packetNoPadding(d, idx0);
                kernel.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
                                        : rhs.packetNoPadding(d, idx1);
                kernel.packet[2] = pad2 ? pset1<Packet>(Scalar(0))
                                        : rhs.packetNoPadding(d, idx2);
                kernel.packet[3] = pad3 ? pset1<Packet>(Scalar(0))
                                        : rhs.packetNoPadding(d, idx3);
                ptranspose(kernel);
                pstoreu(block + 0 * packet_size, kernel.packet[0]);
                pstoreu(block + 1 * packet_size, kernel.packet[1]);
                pstoreu(block + 2 * packet_size, kernel.packet[2]);
                pstoreu(block + 3 * packet_size, kernel.packet[3]);
                block += 4 * packet_size;
                k += packet_size;
              }
            }
          }

          // The loop above should fill peeled_k elements.
          eigen_assert(peeled_k == k);

        } else {
          // A packet may span multiple patch rows/columns: fall back to the
          // "standard" packet loads that handle those boundaries.
          for (; k < peeled_k; k += packet_size) {
            PacketBlock<Packet, 4> kernel;
            kernel.packet[0] = dm0.loadPacketStandard(k);
            kernel.packet[1] = dm1.loadPacketStandard(k);
            kernel.packet[2] = dm2.loadPacketStandard(k);
            kernel.packet[3] = dm3.loadPacketStandard(k);
            ptranspose(kernel);
            pstoreu(block + 0 * packet_size, kernel.packet[0]);
            pstoreu(block + 1 * packet_size, kernel.packet[1]);
            pstoreu(block + 2 * packet_size, kernel.packet[2]);
            pstoreu(block + 3 * packet_size, kernel.packet[3]);
            block += 4 * packet_size;
          }
        }
      }

      // Copy the remaining coefficients of the column block after the peeled_k.
      if (!rhs.nonStandardPatches()) {
        for (; k < depth; k++) {
          block[0] = dm0.loadCoeffStandard(k);
          block[1] = dm1.loadCoeffStandard(k);
          block[2] = dm2.loadCoeffStandard(k);
          block[3] = dm3.loadCoeffStandard(k);
          block += 4;
        }
      } else {
        for (; k < depth; k++) {
          block[0] = dm0(k);
          block[1] = dm1(k);
          block[2] = dm2(k);
          block[3] = dm3(k);
          block += 4;
        }
      }
    }

    // copy the remaining columns one at a time (nr==1)
    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2);
      for (Index k = 0; k < depth; k++) {
        *block = dm0(k);
        block += 1;
      }
    }
  }
};

// Template specialization for packet_size = 2. We must special-case packet
// blocks with nr > packet_size, e.g. PacketBlock<Packet2d, 4>.
template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
          typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, bool inner_dim_contiguous,
          bool inner_dim_reordered, int Alignment, int nr>
struct gemm_pack_rhs<
    Scalar, Index,
    TensorContractionSubMapper<
        Scalar, Index, Rhs,
        TensorEvaluator<
            const TensorReshapingOp<
                NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
            Device>,
        nocontract_t, contract_t, 2, inner_dim_contiguous, inner_dim_reordered,
        Alignment>,
    nr, ColMajor, false, false> {
  typedef TensorContractionSubMapper<
      Scalar, Index, Rhs,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, 2, inner_dim_contiguous, inner_dim_reordered,
      Alignment>
      SubMapper;
  typedef SubMapper DataMapper;
  typedef typename packet_traits<Scalar>::type Packet;

  EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE)

  EIGEN_DEVICE_FUNC
  EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
                                    Index depth, Index cols, Index stride = 0,
                                    Index offset = 0) const {
    // Non-zero stride/offset packing modes are not supported by this mapper.
    eigen_assert(stride == 0);
    eigen_assert(offset == 0);

    const int packet_size = 2;
    // Columns are packed in groups of 4 (nr == 4); `peeled_k` is the prefix of
    // the depth dimension that can be processed with whole packets.
    const Index packet_cols4 = (cols / 4) * 4;
    const Index peeled_k = (depth / packet_size) * packet_size;
    const bool non_standard_patches = rhs.nonStandardPatches();

    for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
      // Sub-mappers for the four patches (columns) packed together.
      const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
      const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
      const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
      const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);

      Index k = 0;
      if (!non_standard_patches) {
        // FAST PATH:
        // Iterate over patch columns and rows if we know that a single
        // packet does not span across multiple rows or columns.
        if ((rhs.patchDepth() % packet_size) == 0) {
          const Index start_col = rhs.colOffset();
          const Index max_col = rhs.maxCol(peeled_k);

          for (Index c = start_col; c < max_col; ++c) {
            eigen_assert(k <= peeled_k);

            const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
            const Index max_row = rhs.maxRow(peeled_k, c);

            const bool pad_col0 = dm0.padCol(c);
            const bool pad_col1 = dm1.padCol(c);
            const bool pad_col2 = dm2.padCol(c);
            const bool pad_col3 = dm3.padCol(c);

            // We can squeeze reads along the `row` and `depth` dimensions if
            // the row stride is `1`, which means that `row` and `depth`
            // dimensions are contiguous (two innermost dimensions).
            if (rhs.rowStride() == 1 &&                                //
                !pad_col0 && !pad_col1 && !pad_col2 && !pad_col3 &&    //
                !dm0.padRow(start_row) && !dm0.padRow(max_row - 1) &&  //
                !dm1.padRow(start_row) && !dm1.padRow(max_row - 1) &&  //
                !dm2.padRow(start_row) && !dm2.padRow(max_row - 1) &&  //
                !dm3.padRow(start_row) && !dm3.padRow(max_row - 1)) {
              // Compute how many elements we can squeeze read.
              const Index start_depth =
                  (c == start_col) ? rhs.depthOffset() : 0;

              // Upper bound for the number of elements in the depth dimension
              // that we can squeeze read.
              const Index squeeze_length =
                  (max_row - start_row) * rhs.patchDepth() - start_depth;

              // Do not overshoot beyond the block size.
              const Index max_depth =
                  start_depth + std::min<Index>(peeled_k - k, squeeze_length);
              eigen_assert((max_depth - start_depth) % packet_size == 0);

              const Index idx0 = dm0.baseIndex(start_row, c);
              const Index idx1 = dm1.baseIndex(start_row, c);
              const Index idx2 = dm2.baseIndex(start_row, c);
              const Index idx3 = dm3.baseIndex(start_row, c);

              for (Index d = start_depth; d < max_depth; d += packet_size) {
                // Two 2x2 transposes emulate the 4-wide transpose that the
                // generic specialization does with PacketBlock<Packet, 4>.
                PacketBlock<Packet, 2> kernel0;
                PacketBlock<Packet, 2> kernel1;
                kernel0.packet[0] = rhs.packetNoPadding(d, idx0);
                kernel0.packet[1] = rhs.packetNoPadding(d, idx1);
                kernel1.packet[0] = rhs.packetNoPadding(d, idx2);
                kernel1.packet[1] = rhs.packetNoPadding(d, idx3);
                ptranspose(kernel0);
                ptranspose(kernel1);
                pstoreu(block + 0 * packet_size, kernel0.packet[0]);
                pstoreu(block + 1 * packet_size, kernel1.packet[0]);
                pstoreu(block + 2 * packet_size, kernel0.packet[1]);
                pstoreu(block + 3 * packet_size, kernel1.packet[1]);
                block += 4 * packet_size;
                k += packet_size;
              }

              // Go to the next column.
              continue;
            }

            // If we can't squeeze reads, process rows one by one.
            for (Index r = start_row; r < max_row; ++r) {
              eigen_assert(k <= peeled_k);

              const bool pad0 = pad_col0 || dm0.padRow(r);
              const bool pad1 = pad_col1 || dm1.padRow(r);
              const bool pad2 = pad_col2 || dm2.padRow(r);
              const bool pad3 = pad_col3 || dm3.padRow(r);

              const Index idx0 = dm0.baseIndex(r, c);
              const Index idx1 = dm1.baseIndex(r, c);
              const Index idx2 = dm2.baseIndex(r, c);
              const Index idx3 = dm3.baseIndex(r, c);

              const Index start_depth = ((c == start_col) && (r == start_row))
                                            ? rhs.depthOffset()
                                            : 0;
              const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
              eigen_assert((max_depth - start_depth) % packet_size == 0);

              for (Index d = start_depth; d < max_depth; d += packet_size) {
                eigen_assert(k < peeled_k);
                // Padded positions contribute zeros to the packed block.
                PacketBlock<Packet, 2> kernel0;
                PacketBlock<Packet, 2> kernel1;
                kernel0.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
                                         : rhs.packetNoPadding(d, idx0);
                kernel0.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
                                         : rhs.packetNoPadding(d, idx1);
                kernel1.packet[0] = pad2 ? pset1<Packet>(Scalar(0))
                                         : rhs.packetNoPadding(d, idx2);
                kernel1.packet[1] = pad3 ? pset1<Packet>(Scalar(0))
                                         : rhs.packetNoPadding(d, idx3);
                ptranspose(kernel0);
                ptranspose(kernel1);
                pstoreu(block + 0 * packet_size, kernel0.packet[0]);
                pstoreu(block + 1 * packet_size, kernel1.packet[0]);
                pstoreu(block + 2 * packet_size, kernel0.packet[1]);
                pstoreu(block + 3 * packet_size, kernel1.packet[1]);
                block += 4 * packet_size;
                k += packet_size;
              }
            }
          }

          // The loop above should fill peeled_k elements.
          eigen_assert(peeled_k == k);

        } else {
          // Packet can span multiple rows or columns, so we have to go
          // though the slower "standard" path.
          for (; k < peeled_k; k += packet_size) {
            PacketBlock<Packet, 2> kernel0;
            PacketBlock<Packet, 2> kernel1;
            kernel0.packet[0] = dm0.loadPacketStandard(k);
            kernel0.packet[1] = dm1.loadPacketStandard(k);
            kernel1.packet[0] = dm2.loadPacketStandard(k);
            kernel1.packet[1] = dm3.loadPacketStandard(k);
            ptranspose(kernel0);
            ptranspose(kernel1);
            pstoreu(block + 0 * packet_size, kernel0.packet[0]);
            pstoreu(block + 1 * packet_size, kernel1.packet[0]);
            pstoreu(block + 2 * packet_size, kernel0.packet[1]);
            pstoreu(block + 3 * packet_size, kernel1.packet[1]);
            block += 4 * packet_size;
          }
        }
      }

      // Copy the remaining coefficients of the column block after the peeled_k.
      if (!non_standard_patches) {
        for (; k < depth; k++) {
          block[0] = dm0.loadCoeffStandard(k);
          block[1] = dm1.loadCoeffStandard(k);
          block[2] = dm2.loadCoeffStandard(k);
          block[3] = dm3.loadCoeffStandard(k);
          block += 4;
        }
      } else {
        for (; k < depth; k++) {
          block[0] = dm0(k);
          block[1] = dm1(k);
          block[2] = dm2(k);
          block[3] = dm3(k);
          block += 4;
        }
      }
    }

    // Copy the remaining columns one at a time (nr==1).
    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2);
      for (Index k = 0; k < depth; k++) {
        *block = dm0(k);
        block += 1;
      }
    }
  }
};

// Special case for non-vectorized types such as float16.
1226 template <typename NewDimension, Index Rows, Index Cols, typename ArgType, 1227 typename Device, typename Scalar, typename Index, 1228 typename nocontract_t, typename contract_t, bool inner_dim_contiguous, 1229 bool inner_dim_reordered, int Alignment, int nr> 1230 struct gemm_pack_rhs< 1231 Scalar, Index, 1232 TensorContractionSubMapper< 1233 Scalar, Index, Rhs, 1234 TensorEvaluator< 1235 const TensorReshapingOp< 1236 NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >, 1237 Device>, 1238 nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, 1239 Alignment>, 1240 nr, ColMajor, false, false> { 1241 typedef TensorContractionSubMapper< 1242 Scalar, Index, Rhs, 1243 TensorEvaluator< 1244 const TensorReshapingOp< 1245 NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >, 1246 Device>, 1247 nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, 1248 Alignment> 1249 SubMapper; 1250 typedef SubMapper DataMapper; 1251 1252 EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE) 1253 1254 EIGEN_DEVICE_FUNC 1255 EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs, 1256 Index depth, Index cols, Index stride = 0, 1257 Index offset = 0) const { 1258 eigen_assert(stride == 0); 1259 eigen_assert(offset == 0); 1260 1261 const Index packet_cols4 = (cols / 4) * 4; 1262 1263 for (Index j2 = 0; j2 < packet_cols4; j2 += 4) { 1264 const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0); 1265 const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1); 1266 const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2); 1267 const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3); 1268 1269 if (!rhs.nonStandardPatches()) { 1270 for (Index k = 0; k < depth; k++) { 1271 block[0] = dm0.loadCoeffStandard(k); 1272 block[1] = dm1.loadCoeffStandard(k); 1273 block[2] = dm2.loadCoeffStandard(k); 1274 block[3] = dm3.loadCoeffStandard(k); 1275 block += 4; 1276 } 1277 } else { 1278 for (Index k = 0; k < depth; k++) { 1279 block[0] = 
dm0(k); 1280 block[1] = dm1(k); 1281 block[2] = dm2(k); 1282 block[3] = dm3(k); 1283 block += 4; 1284 } 1285 } 1286 } 1287 1288 // Copy the remaining columns one at a time (nr==1). 1289 for (Index j2 = packet_cols4; j2 < cols; ++j2) { 1290 const SubMapper dm0 = rhs.getLinearMapper(0, j2); 1291 for (Index k = 0; k < depth; k++) { 1292 *block = dm0(k); 1293 block += 1; 1294 } 1295 } 1296 } 1297 }; 1298 } // end namespace internal 1299 1300 /** SpatialConvolution 1301 * \ingroup CXX11_NeuralNetworks_Module 1302 * 1303 * \brief Applies a 2D convolution over a multichannel input image. 1304 * 1305 * The input parameter is expected to be a tensor with a rank of 3 or more 1306 * (channels, height, width, and optionally others) 1307 * The kernel parameter is expected to be a 4D tensor (filters, channels, 1308 * kernel_height, kernel_width) 1309 * The input and the kernel must both be in col-major layout. The result will 1310 * also be in col-major layout. 1311 * 1312 * If col_in_stride, row_in_stride > 1, then applies convolution with holes 1313 * (aka atrous convolution), sampling every col_in_stride, row_in_stride input 1314 * pixels. 1315 * 1316 * The result can be assigned to a tensor of rank equal to the rank of the 1317 * input. The dimensions of the result will be filters, height, width (and 1318 * others if applicable). 1319 * 1320 * It is possible to swap the order of the width and height dimensions provided 1321 * that the same order is used in the input, the kernel, and the output. 1322 * 1323 * It is also possible to add an output kernel to the contraction, output 1324 * kernel is called by Eigen when it "finalizes" the block of an output tensor. 
 *
 */
template <typename Input, typename Kernel,
          typename OutputKernel = const NoOpOutputKernel>
EIGEN_DEVICE_FUNC
    EIGEN_ALWAYS_INLINE static const typename internal::conditional<
        internal::traits<Input>::Layout == ColMajor,
        TensorReshapingOp<
            const DSizes<typename internal::traits<Input>::Index,
                         internal::traits<Input>::NumDimensions>,
            const TensorContractionOp<
                const array<IndexPair<typename internal::traits<Input>::Index>,
                            1>,
                const TensorReshapingOp<
                    const DSizes<typename internal::traits<Input>::Index, 2>,
                    const Kernel>,
                const TensorReshapingOp<
                    const DSizes<typename internal::traits<Input>::Index, 2>,
                    const TensorImagePatchOp<Dynamic, Dynamic, const Input> >,
                const OutputKernel> >,
        TensorReshapingOp<
            const DSizes<typename internal::traits<Input>::Index,
                         internal::traits<Input>::NumDimensions>,
            const TensorContractionOp<
                const array<IndexPair<typename internal::traits<Input>::Index>,
                            1>,
                const TensorReshapingOp<
                    const DSizes<typename internal::traits<Input>::Index, 2>,
                    const TensorImagePatchOp<Dynamic, Dynamic, const Input> >,
                const TensorReshapingOp<
                    const DSizes<typename internal::traits<Input>::Index, 2>,
                    const Kernel>,
                const OutputKernel> > >::type
    SpatialConvolution(const Input& input, const Kernel& kernel,
                       const Index row_stride = 1, const Index col_stride = 1,
                       const PaddingType padding_type = PADDING_SAME,
                       const Index row_in_stride = 1,
                       const Index col_in_stride = 1,
                       const OutputKernel& output_kernel = OutputKernel()) {
  typedef typename internal::traits<Input>::Index TensorIndex;
  // TensorRefs are used only to query dimensions; they do not copy the data.
  TensorRef<Tensor<typename internal::traits<Input>::Scalar,
                   internal::traits<Input>::NumDimensions,
                   internal::traits<Input>::Layout, TensorIndex> >
      in(input);
  TensorRef<Tensor<typename internal::traits<Kernel>::Scalar,
                   internal::traits<Kernel>::NumDimensions,
                   internal::traits<Kernel>::Layout, TensorIndex> >
      kern(kernel);

  EIGEN_STATIC_ASSERT(
      internal::traits<Input>::Layout == internal::traits<Kernel>::Layout,
      YOU_MADE_A_PROGRAMMING_MISTAKE)
  const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);

  const int NumDims = internal::traits<Input>::NumDimensions;

  // Number of filters to apply. This is the same as the output depth of the
  // result
  const TensorIndex kernelFilters =
      isColMajor ? kern.dimensions()[0] : kern.dimensions()[3];
  // Number of channels. This is the same as the input depth.
  const TensorIndex kernelChannels =
      isColMajor ? kern.dimensions()[1] : kern.dimensions()[2];
  const TensorIndex kernelRows =
      isColMajor ? kern.dimensions()[2] : kern.dimensions()[1];
  const TensorIndex kernelCols =
      isColMajor ? kern.dimensions()[3] : kern.dimensions()[0];

  // Effective kernel extent once the in-stride (atrous/dilation) holes are
  // accounted for; equals kernelRows/kernelCols when the in-strides are 1.
  const Index kernelRowsEff =
      kernelRows + (kernelRows - 1) * (row_in_stride - 1);
  const Index kernelColsEff =
      kernelCols + (kernelCols - 1) * (col_in_stride - 1);

  // Contract over the patch-values dimension (dim 1 of the reshaped kernel
  // against dim 0 of the reshaped patches).
  array<IndexPair<TensorIndex>, 1> contract_dims;
  contract_dims[0] = IndexPair<TensorIndex>(1, 0);

  const TensorIndex InputRows =
      isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
  const TensorIndex InputCols =
      isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);

  TensorIndex out_height;
  TensorIndex out_width;
  switch (padding_type) {
    case PADDING_VALID:
      out_height = numext::ceil((InputRows - kernelRowsEff + 1.f) /
                                static_cast<float>(row_stride));
      out_width = numext::ceil((InputCols - kernelColsEff + 1.f) /
                               static_cast<float>(col_stride));
      break;
    case PADDING_SAME:
      out_height = numext::ceil(InputRows / static_cast<float>(row_stride));
      out_width = numext::ceil(InputCols / static_cast<float>(col_stride));
      break;
    default:
      // Initialize unused variables to avoid a compiler warning
      out_height = 0;
      out_width = 0;
      eigen_assert(false && "unexpected padding");
  }

  // Molds the output of the patch extraction code into a 2d tensor:
  // - the first dimension (dims[0]): the patch values to be multiplied with the
  // kernels
  // - the second dimension (dims[1]): everything else
  DSizes<TensorIndex, 2> pre_contract_dims;
  if (isColMajor) {
    pre_contract_dims[0] = kernelChannels * kernelRows * kernelCols;
    pre_contract_dims[1] = out_height * out_width;
    for (int i = 3; i < NumDims; ++i) {
      pre_contract_dims[1] *= in.dimension(i);
    }
  } else {
    pre_contract_dims[1] = kernelChannels * kernelRows * kernelCols;
    pre_contract_dims[0] = out_height * out_width;
    for (int i = 0; i < NumDims - 3; ++i) {
      pre_contract_dims[0] *= in.dimension(i);
    }
  }

  // Molds the output of the contraction into the shape expected by the user
  // (assuming this is ColMajor):
  // - 1st dim: kernel filters
  // - 2nd dim: output height
  // - 3rd dim: output width
  // - 4th dim and beyond: everything else including batch size
  DSizes<TensorIndex, NumDims> post_contract_dims;
  if (isColMajor) {
    post_contract_dims[0] = kernelFilters;
    post_contract_dims[1] = out_height;
    post_contract_dims[2] = out_width;
    for (int i = 3; i < NumDims; ++i) {
      post_contract_dims[i] = in.dimension(i);
    }
  } else {
    post_contract_dims[NumDims - 1] = kernelFilters;
    post_contract_dims[NumDims - 2] = out_height;
    post_contract_dims[NumDims - 3] = out_width;
    for (int i = 0; i < NumDims - 3; ++i) {
      post_contract_dims[i] = in.dimension(i);
    }
  }

  // Flatten the kernel to a 2d matrix: filters x (channels * rows * cols),
  // transposed for the RowMajor case.
  DSizes<TensorIndex, 2> kernel_dims;
  if (isColMajor) {
    kernel_dims[0] = kernelFilters;
    kernel_dims[1] = kernelChannels * kernelRows * kernelCols;
  } else {
    kernel_dims[0] = kernelChannels * kernelRows * kernelCols;
    kernel_dims[1] = kernelFilters;
  }
  // The convolution is expressed as a contraction (matrix multiplication)
  // between the flattened kernel and the matrix of extracted image patches;
  // the contraction operand order is swapped for the RowMajor case.
  return choose(
      Cond<internal::traits<Input>::Layout == ColMajor>(),
      kernel.reshape(kernel_dims)
          .contract(input
                        .extract_image_patches(
                            kernelRows, kernelCols, row_stride, col_stride,
                            row_in_stride, col_in_stride, padding_type)
                        .reshape(pre_contract_dims),
                    contract_dims, output_kernel)
          .reshape(post_contract_dims),
      input
          .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride,
                                 row_in_stride, col_in_stride, padding_type)
          .reshape(pre_contract_dims)
          .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel)
          .reshape(post_contract_dims));
}

}  // end namespace Eigen

#endif  // TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_INL_H_