/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_INL_H_
#define TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_INL_H_

#include "tensorflow/core/kernels/eigen_convolution_helpers.h"

// Note this header is used in both TF and TFLite.
namespace Eigen {

namespace internal {

// WARNING: Most of the code here implicitly assumes that the matrix is in
// ColMajor layout. This is guaranteed by the tensor contraction (see
// TensorContraction.h).
//
// Inside Eigen a tensor contraction is represented by a matrix multiplication.
// We don't want to actually extract image patches and reshape the result into
// a matrix (this involves allocating huge extra memory), so the patch
// extraction and reshape operations are implicit.
//
// TensorContractionInputMapper takes a matrix index and returns the coefficient
// (or the packet) of the "virtual tensor", that would be at that index if we
// were to actually reshape the result of patch extraction.
//
// TensorContractionSubMapper provides a similar view into the "virtual matrix"
// at the given vertical and horizontal offsets.
//
// "Virtual matrix" dimensions:
//   *0: kernelChannels * kernelRows * kernelCols;
//    1: out_height * out_width * OTHERS (e.g. batches, etc...)
//
// *) extracted patches are contiguous in memory (innermost dimension assuming
//    col major layout)
//
// With these dimensions:
//   row - offset within a single patch (in code: patchId)
//   col - index of the extracted patch (in code: patchIndex)
//         patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions)
//
// TODO(ezhulenev): Consolidate this part of the code with the image patch
// extraction code since they are both very similar.
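
// For illustration, a hedged worked example of the index mapping above (the
// concrete numbers are assumptions chosen only for this example; they are not
// used anywhere in this file). Suppose patch_depth = 3 (kernelChannels),
// patch_rows = 4 (kernelRows) and patch_cols = 5 (kernelCols). Then a
// "virtual matrix" row index decomposes exactly the way loadCoeff() below
// decomposes patchId:
//
//   row = 22
//   patchOffset = row / patch_depth            = 22 / 3 = 7
//   depth       = row - patchOffset * 3        = 1
//   colOffset   = patchOffset / patch_rows     = 7 / 4  = 1
//   rowOffset   = patchOffset - colOffset * 4  = 3
//
// i.e. row 22 addresses channel 1 of the element at (row 3, col 1) inside the
// patch, while the column index simply selects which patch (and which batch
// entry) the element comes from.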

template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
          typename Device, typename Scalar_, typename Index,
          typename nocontract_t, typename contract_t, int Side, int packet_size,
          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionInputMapper<
    Scalar_, Index, Side,
    TensorEvaluator<
        const TensorReshapingOp<NewDimension,
                                const TensorImagePatchOp<Rows, Cols, ArgType> >,
        Device>,
    nocontract_t, contract_t, packet_size, inner_dim_contiguous,
    inner_dim_reordered, Alignment> {
 public:
  typedef Scalar_ Scalar;

  typedef TensorContractionInputMapper<
      Scalar, Index, Side,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      Self;

  typedef TensorContractionSubMapper<
      Scalar, Index, Side,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      SubMapper;

  typedef SubMapper VectorMapper;
  typedef SubMapper LinearMapper;
  typedef typename packet_traits<Scalar>::type Packet;

  typedef TensorEvaluator<ArgType, Device> TensorEvaluatorT;

  EIGEN_DEVICE_FUNC
  TensorContractionInputMapper(
      const TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>& tensor,
      const nocontract_t&, const nocontract_t&, const contract_t&,
      const contract_t&)
      : m_impl(tensor.impl().impl()) {
    Index patch_rows;
    Index patch_depth;
    if (internal::traits<ArgType>::Layout == ColMajor) {
      patch_depth = tensor.impl().dimensions()[0];
      patch_rows = tensor.impl().dimensions()[1];
      m_patch_cols = tensor.impl().dimensions()[2];
      m_num_patches = tensor.impl().dimensions()[3];
    } else {
      const size_t NumDims = tensor.impl().dimensions().size();
      patch_depth = tensor.impl().dimensions()[NumDims - 1];
      patch_rows = tensor.impl().dimensions()[NumDims - 2];
      m_patch_cols = tensor.impl().dimensions()[NumDims - 3];
      m_num_patches = tensor.impl().dimensions()[NumDims - 4];
    }

    // Strides for navigating through the single patch.
    m_patch_row_stride = patch_depth;
    m_patch_col_stride = patch_rows * m_patch_row_stride;

    m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
    m_patch_col_inflate_strides = tensor.impl().colInflateStride();

    m_colStride = patch_rows;

    m_outputRows = tensor.impl().outputRows();
    m_outputCols = tensor.impl().outputCols();
    m_row_strides = tensor.impl().userRowStride();
    m_col_strides = tensor.impl().userColStride();

    m_in_row_strides = tensor.impl().userInRowStride();
    m_in_col_strides = tensor.impl().userInColStride();

    if (internal::traits<ArgType>::Layout == ColMajor) {
      m_inputRows = tensor.impl().impl().dimensions()[1];
      m_inputCols = tensor.impl().impl().dimensions()[2];
    } else {
      const int NumDims = tensor.impl().impl().dimensions().size();
      m_inputRows = tensor.impl().impl().dimensions()[NumDims - 2];
      m_inputCols = tensor.impl().impl().dimensions()[NumDims - 3];
    }

    m_rowInputStride = patch_depth;
    m_colInputStride = patch_depth * m_inputRows;
    m_patchInputStride = patch_depth * m_inputRows * m_inputCols;

    m_rowPaddingTop = tensor.impl().rowPaddingTop();
    m_colPaddingLeft = tensor.impl().colPaddingLeft();

    m_fastPatchRowStride =
        internal::TensorIntDivisor<Index>(m_patch_row_stride);
    m_fastPatchColStride =
        internal::TensorIntDivisor<Index>(m_patch_col_stride);
    m_fastInputRowStride =
        internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
    m_fastInputColStride =
        internal::TensorIntDivisor<Index>(m_patch_col_inflate_strides);
    m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches);
    m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);
    m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
    m_fastDimZero = internal::TensorIntDivisor<Index>(patch_depth);
  }

  EIGEN_DEVICE_FUNC
  TensorContractionInputMapper(const TensorContractionInputMapper& base_mapper)
      : m_impl(base_mapper.m_impl) {
    m_patch_cols = base_mapper.m_patch_cols;
    m_num_patches = base_mapper.m_num_patches;

    m_patch_row_stride = base_mapper.m_patch_row_stride;
    m_patch_col_stride = base_mapper.m_patch_col_stride;

    m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
    m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;

    m_colStride = base_mapper.m_colStride;

    m_rowInputStride = base_mapper.m_rowInputStride;
    m_colInputStride = base_mapper.m_colInputStride;
    m_patchInputStride = base_mapper.m_patchInputStride;

    m_inputRows = base_mapper.m_inputRows;
    m_inputCols = base_mapper.m_inputCols;

    m_outputRows = base_mapper.m_outputRows;
    m_outputCols = base_mapper.m_outputCols;
    m_row_strides = base_mapper.m_row_strides;
    m_col_strides = base_mapper.m_col_strides;

    m_in_row_strides = base_mapper.m_in_row_strides;
    m_in_col_strides = base_mapper.m_in_col_strides;

    m_rowPaddingTop = base_mapper.m_rowPaddingTop;
    m_colPaddingLeft = base_mapper.m_colPaddingLeft;

    m_fastPatchRowStride = base_mapper.m_fastPatchRowStride;
    m_fastPatchColStride = base_mapper.m_fastPatchColStride;
    m_fastInputRowStride = base_mapper.m_fastInputRowStride;
    m_fastInputColStride = base_mapper.m_fastInputColStride;
    m_fastNumPatches = base_mapper.m_fastNumPatches;
    m_fastColStride = base_mapper.m_fastColStride;
    m_fastOutputRows = base_mapper.m_fastOutputRows;
    m_fastDimZero = base_mapper.m_fastDimZero;
  }

  // If true, turns off some optimizations for loading packets since the image
  // patches are "non-standard", e.g. there are non-trivial strides or
  // inflations in the input.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
    return m_in_row_strides != 1 || m_in_col_strides != 1 ||
           m_patch_row_inflate_strides != 1 || m_patch_col_inflate_strides != 1;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
    return SubMapper(*this, i, j);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
    return LinearMapper(*this, i, j);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const {
    Index rowIndex, colIndex, otherIndex;
    computeBaseIndices(0, rowIndex, colIndex, otherIndex);
    return loadCoeff(row, rowIndex, colIndex, otherIndex);
  }

  // Load the coefficient at the patchIndex location instead of the usual
  // m_rowIndex, m_colIndex, m_otherIndex. This is currently only used by the
  // gpu code.
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const {
    Index rowIndex, colIndex, otherIndex;
    computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex);
    return loadCoeff(row, rowIndex, colIndex, otherIndex);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const {
    Index rowIndex, colIndex, otherIndex;
    computeBaseIndices(0, rowIndex, colIndex, otherIndex);
    return loadPacket(row, rowIndex, colIndex, otherIndex);
  }

  // Load the packet at the patchIndex location instead of the usual
  // m_rowIndex, m_colIndex, m_otherIndex. This is currently only used by the
  // gpu code.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const {
    Index rowIndex, colIndex, otherIndex;
    computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex);
    return loadPacket(row, rowIndex, colIndex, otherIndex);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE const TensorEvaluator<ArgType, Device>& impl() const {
    return m_impl;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_rowInputStride; }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchRows() const { return m_colStride; }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; }

 private:
  friend class TensorContractionSubMapper<
      Scalar, Index, Side,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>;

  // Load coefficient from a patch specified by the "within patch offset"
  // (patchId) and the precomputed indices of the first element of the patch.
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index rowIndex,
                                       Index colIndex,
                                       Index otherIndex) const {
    // Find the offset of the element wrt the location of the first element.
    const Index patchOffset = patchId / m_fastDimZero;

    const Index colOffset = patchOffset / m_fastColStride;
    const Index inputCol = colIndex + colOffset * m_in_col_strides;
    const Index origInputCol =
        (m_patch_col_inflate_strides == 1)
            ? inputCol
            : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);

    const Index rowOffset = patchOffset - colOffset * m_colStride;
    const Index inputRow = rowIndex + rowOffset * m_in_row_strides;
    const Index origInputRow =
        (m_patch_row_inflate_strides == 1)
            ? inputRow
            : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
    if (origInputCol < 0 || origInputRow < 0 || origInputCol >= m_inputCols ||
        origInputRow >= m_inputRows ||
        (inputCol != origInputCol * m_patch_col_inflate_strides) ||
        (inputRow != origInputRow * m_patch_row_inflate_strides)) {
      return Scalar(0);
    }
    const Index depth = patchId - patchOffset * patchDepth();
    const Index inputIndex = depth + origInputRow * m_rowInputStride +
                             origInputCol * m_colInputStride + otherIndex;
    return m_impl.coeff(inputIndex);
  }

  // This is the same as loadCoeff(...), but optimized for all `inflate_strides`
  // and `in_strides` equal to 1 (template specialization without templates).
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index rowIndex,
                                               Index colIndex,
                                               Index otherIndex) const {
    eigen_assert(!nonStandardPatches());

    // Find the offset of the element wrt the location of the first element.
    const Index patchOffset = patchId / m_fastDimZero;
    const Index colOffset = patchOffset / m_fastColStride;
    const Index rowOffset = patchOffset - colOffset * m_colStride;
    const Index inputCol = colIndex + colOffset;
    const Index inputRow = rowIndex + rowOffset;
    if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 ||
        inputRow >= m_inputRows) {
      return Scalar(0);
    }
    const Index depth = patchId - patchOffset * patchDepth();
    const Index inputIndex = depth + inputRow * m_rowInputStride +
                             inputCol * m_colInputStride + otherIndex;
    return m_impl.coeff(inputIndex);
  }

  // Load packet from a patch specified by the "within patch offset"
  // (patchId) and the precomputed indices of the first element of the patch.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index rowIndex,
                                        Index colIndex,
                                        Index otherIndex) const {
    const Index packetSize = internal::unpacket_traits<Packet>::size;
    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);

    if (nonStandardPatches()) {
      return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
    }
    typedef decltype(m_impl) TensorEvaluatorT;
    return loadPacketStandard<Packet, TensorEvaluatorT>(patchId, rowIndex,
                                                        colIndex, otherIndex);
  }

  // Helper function to load a 'partial' packet - this is the single-column
  // part of a packet that is split across two columns. In the 'partial'
  // packet, the elements corresponding to the column (specified through
  // colOffset) are loaded and the rest of the elements are zero-filled into
  // the 'partial' packet. This function is called from
  // loadPacketStandardFromTwoColumns(). This code path is exercised only when
  // the packet type supports masked load and when the partial packet load is
  // available in the TensorEvaluator.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPartialPacketStandard(
      Index rowIndex, Index colIndex, Index otherIndex, Index patchId,
      const Index span[], const Index patchOffsets[], Index colOffset) const {
    const Index inputCol = colIndex + colOffset;
    const Index rowOffsets[2] = {patchOffsets[0] - colOffset * m_colStride,
                                 patchOffsets[1] - colOffset * m_colStride};
    const Index inputRows[2] = {rowIndex + rowOffsets[0],
                                rowIndex + rowOffsets[1]};

    if (inputRows[0] >= m_inputRows || inputRows[1] < 0 ||
        inputCol >= m_inputCols || inputCol < 0) {
      // Partial packet is all zeros.
      return internal::pset1<Packet>(Scalar(0));
    } else if (inputRows[0] >= 0 && inputRows[1] < m_inputRows) {
      // From inputIndex-span[0], we need to load elements starting from index
      // span[0] all the way up to (and including) span[1].
      const Index depth = patchId - patchOffsets[0] * patchDepth();
      const Index inputIndex = depth + inputRows[0] * m_rowInputStride +
                               inputCol * m_colInputStride + otherIndex;
      return m_impl.template partialPacket<Packet>(
          inputIndex - span[0], mask<Packet>(span[0], span[1] + 1));
    } else {
      // Using slow path for this partial packet.
      // We need to load elements starting from index span[0] all the way up to
      // (and including) span[1]. We split this load into 3 parts:
      //   0 : span[0] - 1              - Zeros will be loaded for these indices
      //   span[0] : span[1]            - Elements will be loaded here
      //   span[1] + 1 : packetSize - 1 - Zeros will be loaded for these indices
      const Index packetSize = internal::unpacket_traits<Packet>::size;
      EIGEN_ALIGN_MAX
      typename internal::remove_const<Scalar>::type values[packetSize];
      for (int i = 0; i < span[0]; ++i) values[i] = Scalar(0);
      for (int i = span[0]; i < span[1] + 1; ++i)
        values[i] =
            loadCoeff(patchId - span[0] + i, rowIndex, colIndex, otherIndex);
      for (int i = span[1] + 1; i < packetSize; ++i) values[i] = Scalar(0);
      return internal::pload<Packet>(values);
    }
  }

  // Helper function to load a packet that is split across two columns.
  // If required, this function is called from loadPacketStandard() when the
  // packet type supports masked load and when the partial packet load is
  // available in the TensorEvaluator.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromTwoColumns(
      Index patchId, Index rowIndex, Index colIndex, Index otherIndex,
      const Index patchOffsets[], const Index colOffsets[]) const {
    eigen_assert(colOffsets[1] == colOffsets[0] + 1);
    const Index packetSize = internal::unpacket_traits<Packet>::size;

    // The packet to load will be split into 2 parts where each part spans a
    // single column. First determine where to split.
    const Index patchIdSplit =
        ((colOffsets[1] * m_colStride) * m_rowInputStride) - 1;
    const Index patchOffsetSplit = patchIdSplit / m_fastDimZero;

    // patchIds[i]:          patchId corresponding to partial packet i
    // spans[i]:             Start and end indices corresponding to the
    //                       elements to be loaded for partial packet i
    // patchOffsets2Cols[i]: patchOffsets corresponding to partial packet i
    const Index patchIds[2] = {patchId, patchIdSplit + 1};
    const Index spans[2][2] = {{0, patchIdSplit - patchId},
                               {patchIdSplit - patchId + 1, packetSize - 1}};
    const Index patchOffsets2Cols[2][2] = {
        {patchOffsets[0], patchOffsetSplit},
        {patchOffsetSplit + 1, patchOffsets[1]}};

    // Load partial packets and do a bit-wise OR to generate the required
    // packet.
    return internal::por<Packet>(
        loadPartialPacketStandard(rowIndex, colIndex, otherIndex, patchIds[0],
                                  spans[0], patchOffsets2Cols[0],
                                  colOffsets[0]),
        loadPartialPacketStandard(rowIndex, colIndex, otherIndex, patchIds[1],
                                  spans[1], patchOffsets2Cols[1],
                                  colOffsets[1]));
  }

  // Helper function to load a packet that is present in a single column.
  // If required, this function is called from loadPacketStandard().
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromSingleColumn(
      Index patchId, Index rowIndex, Index colIndex, Index otherIndex,
      const Index patchOffsets[], const Index colOffsets[],
      const Index inputCols[]) const {
    eigen_assert(colOffsets[0] == colOffsets[1]);
    const Index rowOffsets[2] = {patchOffsets[0] - colOffsets[0] * m_colStride,
                                 patchOffsets[1] - colOffsets[1] * m_colStride};
    eigen_assert(rowOffsets[0] <= rowOffsets[1]);
    const Index inputRows[2] = {rowIndex + rowOffsets[0],
                                rowIndex + rowOffsets[1]};

    if (inputRows[0] >= m_inputRows || inputRows[1] < 0) {
      // all zeros
      return internal::pset1<Packet>(Scalar(0));
    }

    if (inputRows[0] >= 0 && inputRows[1] < m_inputRows) {
      // no padding
      const Index depth = patchId - patchOffsets[0] * patchDepth();
      const Index inputIndex = depth + inputRows[0] * m_rowInputStride +
                               inputCols[0] * m_colInputStride + otherIndex;
      return m_impl.template packet<Unaligned>(inputIndex);
    }
    return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
  }

  // Load standard packet from a patch specified by the "within patch offset"
  // (patchId) and the precomputed indices of the first element of the patch.
  // This function will be called if partial packet loading is not available
  // for the TensorEvaluator or if the packet type does not support masked
  // load.
  template <typename PacketT, typename TensorEvaluatorT>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
      !TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value,
      PacketT>::type
  loadPacketStandard(Index patchId, Index rowIndex, Index colIndex,
                     Index otherIndex) const {
    const Index packetSize = internal::unpacket_traits<Packet>::size;
    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);

    eigen_assert(!nonStandardPatches());

    if ((patchDepth() % packetSize) == 0) {
      return loadPacketFast(patchId, rowIndex, colIndex, otherIndex);
    }

    // Offsets and input calculation here are identical to
    // loadCoeffStandard(...), but repeated twice.
    const Index patchOffsets[2] = {patchId / m_fastDimZero,
                                   (patchId + packetSize - 1) / m_fastDimZero};
    const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
                                 patchOffsets[1] / m_fastColStride};
    const Index inputCols[2] = {colIndex + colOffsets[0],
                                colIndex + colOffsets[1]};

    if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
      // all zeros
      return internal::pset1<Packet>(Scalar(0));
    }
    if (inputCols[0] == inputCols[1]) {
      return loadPacketStandardFromSingleColumn(patchId, rowIndex, colIndex,
                                                otherIndex, patchOffsets,
                                                colOffsets, inputCols);
    }
    return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
  }

  // Load standard packet from a patch specified by the "within patch offset"
  // (patchId) and the precomputed indices of the first element of the patch.
  // This function will be called if partial packet loading is available for
  // the TensorEvaluator and if the packet type supports masked load.
  // The only difference between this and the other case is that if the packet
  // to load is split across two columns, then in this case, instead of going
  // to the slow (element-by-element) load, we load two packets - each
  // containing elements from one of the columns (the rest of the elements of
  // the packets are zeroes), and then combine these two packets to generate
  // the required packet. The idea is to enable fast load (if possible) of
  // these 'partial' packets.
  template <typename PacketT, typename TensorEvaluatorT>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
      TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value,
      PacketT>::type
  loadPacketStandard(Index patchId, Index rowIndex, Index colIndex,
                     Index otherIndex) const {
    const Index packetSize = internal::unpacket_traits<PacketT>::size;
    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);

    eigen_assert(!nonStandardPatches());

    if ((patchDepth() % packetSize) == 0) {
      return loadPacketFast(patchId, rowIndex, colIndex, otherIndex);
    }

    // Offsets and input calculation here are identical to
    // loadCoeffStandard(...), but repeated twice.
    const Index patchOffsets[2] = {patchId / m_fastDimZero,
                                   (patchId + packetSize - 1) / m_fastDimZero};
    const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
                                 patchOffsets[1] / m_fastColStride};
    const Index inputCols[2] = {colIndex + colOffsets[0],
                                colIndex + colOffsets[1]};

    if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
      // all zeros
      return internal::pset1<PacketT>(Scalar(0));
    }
    if (inputCols[0] == inputCols[1]) {
      return loadPacketStandardFromSingleColumn(patchId, rowIndex, colIndex,
                                                otherIndex, patchOffsets,
                                                colOffsets, inputCols);
    }
    if (inputCols[1] == inputCols[0] + 1) {
      return loadPacketStandardFromTwoColumns(
          patchId, rowIndex, colIndex, otherIndex, patchOffsets, colOffsets);
    }
    return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index rowIndex,
                                            Index colIndex,
                                            Index otherIndex) const {
    const Index packetSize = internal::unpacket_traits<Packet>::size;
    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);

    eigen_assert(!nonStandardPatches());
    eigen_assert((patchDepth() % packetSize) == 0);
    // Find the offset of the element wrt the location of the first element.
    const Index patchOffset = patchId / m_fastDimZero;
    eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);

    const Index colOffset = patchOffset / m_fastColStride;
    const Index rowOffset = patchOffset - colOffset * m_colStride;
    const Index inputCol = colIndex + colOffset;
    const Index inputRow = rowIndex + rowOffset;
    if (inputCol < 0 || inputRow < 0 || inputCol >= m_inputCols ||
        inputRow >= m_inputRows) {
      // all zeros
      return internal::pset1<Packet>(Scalar(0));
    }
    // no padding
    const Index depth = patchId - patchOffset * patchDepth();
    const Index inputIndex = depth + inputRow * m_rowInputStride +
                             inputCol * m_colInputStride + otherIndex;
    return m_impl.template packet<Unaligned>(inputIndex);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet packetWithPossibleZero(
      Index patchId, Index rowIndex, Index colIndex, Index otherIndex) const {
    const int packetSize = internal::unpacket_traits<Packet>::size;
    EIGEN_ALIGN_MAX
    typename internal::remove_const<Scalar>::type values[packetSize];
    for (int i = 0; i < packetSize; ++i) {
      values[i] = loadCoeff(patchId + i, rowIndex, colIndex, otherIndex);
    }
    Packet rslt = internal::pload<Packet>(values);
    return rslt;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices(
      Index patchIndex, Index& rowIndex, Index& colIndex,
      Index& otherIndex) const {
    const size_t NumInputDims = array_size<
        typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
    otherIndex = (NumInputDims == 3) ? 0 : patchIndex / m_fastNumPatches;
    const Index patch2DIndex = (NumInputDims == 3)
                                   ? patchIndex
                                   : (patchIndex - otherIndex * m_num_patches);
    otherIndex *= m_patchInputStride;
    colIndex = patch2DIndex / m_fastOutputRows;
    rowIndex = patch2DIndex - colIndex * m_outputRows;
    colIndex = colIndex * m_col_strides - m_colPaddingLeft;
    rowIndex = rowIndex * m_row_strides - m_rowPaddingTop;
  }

  Index m_patch_cols;   // number of columns in the patch
  Index m_num_patches;  // number of patches to extract.

  // Strides for navigating through the single patch.
  Index m_patch_row_stride;
  Index m_patch_col_stride;
  internal::TensorIntDivisor<Index> m_fastPatchRowStride;
  internal::TensorIntDivisor<Index> m_fastPatchColStride;

  Index m_patch_row_inflate_strides;  // the strides for row inflation in the
                                      // image patch
  Index m_patch_col_inflate_strides;  // the strides for col inflation in the
                                      // image patch
  // Fast representation of inflation strides.
  internal::TensorIntDivisor<Index> m_fastInputRowStride;
  internal::TensorIntDivisor<Index> m_fastInputColStride;

  Index m_otherStride;
  Index m_colStride;
  internal::TensorIntDivisor<Index> m_fastNumPatches;
  internal::TensorIntDivisor<Index> m_fastColStride;

  Index m_rowInputStride;    // row stride in the input tensor
  Index m_colInputStride;    // col stride in the input tensor
  Index m_patchInputStride;  // patch stride in the input tensor

  Index m_inputRows;  // Number of rows in the input tensor
  Index m_inputCols;  // Number of cols in the input tensor

  Index m_outputRows;  // Number of convolution output rows
  Index m_outputCols;  // Number of convolution output columns

  Index m_row_strides;  // User specified row stride
  Index m_col_strides;  // User specified col stride

  Index m_in_row_strides;  // User specified input row stride
  Index m_in_col_strides;  // User specified input col stride

  Index m_rowPaddingTop;   // Row padding
  Index m_colPaddingLeft;  // Column padding

  internal::TensorIntDivisor<Index> m_fastOutputRows;
  internal::TensorIntDivisor<Index> m_fastDimZero;

  const TensorEvaluator<ArgType, Device> m_impl;
};

template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
          typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, int Side, int packet_size,
          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionSubMapper<
    Scalar, Index, Side,
    TensorEvaluator<
        const TensorReshapingOp<NewDimension,
                                const TensorImagePatchOp<Rows, Cols, ArgType> >,
        Device>,
    nocontract_t, contract_t, packet_size, inner_dim_contiguous,
    inner_dim_reordered, Alignment> {
 public:
  typedef typename packet_traits<Scalar>::type Packet;
  typedef typename packet_traits<Scalar>::half HalfPacket;

  typedef TensorContractionInputMapper<
      Scalar, Index, Side,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      ParentMapper;

  typedef TensorContractionSubMapper<
      Scalar, Index, Side,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      Self;

  typedef Self LinearMapper;
  typedef typename ParentMapper::TensorEvaluatorT TensorEvaluatorT;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
      const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
      : m_depth_offset(vert_offset),
        m_col_offset(horiz_offset),
        m_base_mapper(base_mapper) {
    m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex,
                                     m_otherIndex);
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
      const Self& base_mapper, Index vert_offset, Index horiz_offset)
      : m_depth_offset(vert_offset + base_mapper.m_depth_offset),
        m_col_offset(horiz_offset + base_mapper.m_col_offset),
        m_base_mapper(base_mapper.m_base_mapper) {
    m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex,
                                     m_otherIndex);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
    return m_base_mapper.loadCoeff(i + m_depth_offset, m_rowIndex, m_colIndex,
                                   m_otherIndex);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i,
                                                          Index j) const {
    return m_base_mapper(i + m_depth_offset, j + m_col_offset);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
    return m_base_mapper.loadPacket(i + m_depth_offset, m_rowIndex, m_colIndex,
                                    m_otherIndex);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i,
                                                          Index j) const {
    return m_base_mapper.template loadPacket<Alignment>(i + m_depth_offset,
                                                        j + m_col_offset);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar
  loadCoeffStandard(Index i) const {
    return m_base_mapper.loadCoeffStandard(i + m_depth_offset, m_rowIndex,
                                           m_colIndex, m_otherIndex);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const {
    return m_base_mapper.loadPacketFast(i + m_depth_offset, m_rowIndex,
                                        m_colIndex, m_otherIndex);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet
  loadPacketStandard(Index i) const {
    typedef decltype(m_base_mapper.m_impl) TensorEvaluatorT;
    return m_base_mapper.template loadPacketStandard<Packet, TensorEvaluatorT>(
        i + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex);
  }
  template <typename Packet>
  EIGEN_DEVICE_FUNC bool aligned(Index) const {
    return false;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
    return m_base_mapper.nonStandardPatches();
  }

  // Max(Col|Row|Depth): compute the upper limit for the column, row and depth
  // index respectively that fits into the peeled_k elements starting at
  // m_depth_offset.

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const {
    const Index max_col =
        (m_depth_offset + (peeled_k == 0 ? 0 : peeled_k - 1)) /
        fastPatchColStride();
    return std::min<Index>(1 + max_col, patchCols());
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k,
                                   const Index col) const {
    const Index max_row = (m_depth_offset + (peeled_k == 0 ? 0 : peeled_k - 1) -
                           col * patchColStride()) /
                          fastPatchRowStride();
    return std::min<Index>(1 + max_row, patchRows());
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index maxDepth(const Index peeled_k, const Index col,
                                     Index row) const {
    const Index max_depth = m_depth_offset + peeled_k -  //
                            col * patchColStride() -     //
                            row * patchRowStride();
    return std::min<Index>(max_depth, patchDepth());
  }

  // MaxDepth uses only the remaining number of elements in the peeled_k.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements,
                                     const Index start_depth) const {
    return std::min<Index>(start_depth + num_elements, patchDepth());
  }

  // Every register matters in this code, so sometimes, to prevent register
  // spilling, instead of the variable that you would expect to see we use
  // another one that is guaranteed to have the same value. E.g. patch depth
  // is always the same as input depth, and it's also the same as the input
  // row stride. A bunch of other parameters have similar relations.

  typedef internal::TensorIntDivisor<Index> IndexDivisor;

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchDepth() const {
    return m_base_mapper.m_rowInputStride;
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchRows() const {
    return m_base_mapper.m_colStride;
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchCols() const {
    return m_base_mapper.m_patch_cols;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchRowStride() const {
    eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
                 "Patch depth must be equal to patch row stride.");
    return patchDepth();
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchColStride() const {
    return m_base_mapper.m_patch_col_stride;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const {
    eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
                 "Patch depth must be equal to patch row stride.");
    return m_base_mapper.m_fastDimZero;  // patch_depth
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const {
    return m_base_mapper.m_fastPatchColStride;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
                                             const Index baseIndex) const {
    const Index inputIndex = depth + baseIndex;
    return m_base_mapper.m_impl.template packet<Unaligned>(inputIndex);
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Scalar coeffNoPadding(const Index depth,
                                            const Index baseIndex) const {
    const Index inputIndex = depth + baseIndex;
    return m_base_mapper.m_impl.coeff(inputIndex);
  }
  template <typename PacketT = Packet>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
      TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value,
      PacketT>::type
  partialPacketNoPadding(const Index depth, const Index baseIndex,
                         Index num_coeffs) const {
    const Index inputIndex = depth + baseIndex;
    return m_base_mapper.m_impl.template partialPacket<PacketT>(
        inputIndex, mask<PacketT>(0, num_coeffs));
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool hasPadding() const {
    // TODO(ezhulenev): It does seem that for an inflated filter it's still
    // possible to guarantee "no padding or skipping" for non-standard packing.
    if (nonStandardPatches()) return true;

    // Non-zero padding before.
    if (m_base_mapper.m_rowPaddingTop > 0) return true;
    if (m_base_mapper.m_colPaddingLeft > 0) return true;

    // Non-zero padding after in rows.
    const Index last_row =
        (m_base_mapper.m_outputRows - 1) * m_base_mapper.m_row_strides;
    if (last_row + (patchRows() - 1) >= m_base_mapper.m_inputRows) return true;

    // Non-zero padding after in cols.
    const Index last_col =
        (m_base_mapper.m_outputCols - 1) * m_base_mapper.m_col_strides;
    if (last_col + (patchCols() - 1) >= m_base_mapper.m_inputCols) return true;

    return false;
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool padRow(const Index row) const {
    const Index r = m_rowIndex + row;
    return r < 0 || r >= m_base_mapper.m_inputRows;
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool padAnyRow(const Index first_row,
                                     const Index last_row) const {
    return m_rowIndex + first_row < 0 ||
           m_rowIndex + last_row >= m_base_mapper.m_inputRows;
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool padOrSkipRow(const Index row,
                                        Index* orig_row) const {
    eigen_assert(nonStandardPatches());

    const Index input_row = m_rowIndex + row * m_base_mapper.m_in_row_strides;
    *orig_row = (m_base_mapper.m_patch_row_inflate_strides == 1)
                    ? input_row
                    : ((input_row >= 0)
                           ? (input_row / m_base_mapper.m_fastInputRowStride)
                           : 0);

    return (*orig_row < 0 || *orig_row >= m_base_mapper.m_inputRows) ||
           (input_row != *orig_row * m_base_mapper.m_patch_row_inflate_strides);
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool padCol(const Index col) const {
    const Index c = m_colIndex + col;
    return c < 0 || c >= m_base_mapper.m_inputCols;
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool padOrSkipCol(const Index col,
                                        Index* orig_col) const {
    eigen_assert(nonStandardPatches());

    const Index input_col = m_colIndex + col * m_base_mapper.m_in_col_strides;
    *orig_col = (m_base_mapper.m_patch_col_inflate_strides == 1)
                    ? input_col
                    : ((input_col >= 0)
                           ? (input_col / m_base_mapper.m_fastInputColStride)
                           : 0);

    return (*orig_col < 0 || *orig_col >= m_base_mapper.m_inputCols) ||
           (input_col != *orig_col * m_base_mapper.m_patch_col_inflate_strides);
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index baseIndex(const Index row, const Index col) const {
    const Index r = m_rowIndex + row;
    const Index c = m_colIndex + col;
    return r * m_base_mapper.m_rowInputStride +
           c * m_base_mapper.m_colInputStride + m_otherIndex;
  }
  // Compute a base index when the original input row and column were
  // precomputed using padOrSkipRow and padOrSkipCol. Used only for
  // non-standard patches.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index origBaseIndex(const Index orig_row,
                                          const Index orig_col) const {
    return orig_row * m_base_mapper.m_rowInputStride +
           orig_col * m_base_mapper.m_colInputStride + m_otherIndex;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index rowStride() const {
    return m_base_mapper.m_row_strides;
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index colStride() const {
    return m_base_mapper.m_col_strides;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index rowOffset() const {
    const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
    const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
    return patchOffset - colOffset * m_base_mapper.m_colStride;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index colOffset() const {
    const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
    const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
    return colOffset;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index depthOffset() const {
    return m_depth_offset % patchDepth();
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper
  getLinearMapper(Index i, Index j) const {
    return LinearMapper(m_base_mapper, i + m_depth_offset, j + m_col_offset);
  }

 private:
  Index m_depth_offset;  // First row in the input matrix
  Index m_col_offset;    // First col in the input matrix

  // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base
  // indices for the first element in a patch specified by col_offset
  // (see computeBaseIndices(...) for details).
  Index m_rowIndex;
  Index m_colIndex;
  Index m_otherIndex;

  const ParentMapper m_base_mapper;  // Keeping a copy instead of a reference
                                     // performs better in benchmarks.
};

// Arrange a block of the right input matrix (in our case it's always a
// "virtual matrix" constructed from extracted image patches) in contiguous
// memory.
//
// Given column major input (A0 beside A1 in memory):
// A0 B0 C0 D0 E0 F0 G0 H0 ... Z0
// A1 B1 C1 D1 E1 F1 G1 H1 ... Z1
// A2 B2 C2 D2 E2 F2 G2 H2 ... Z2
// A3 B3 C3 D3 E3 F3 G3 H3 ... Z3
// A4 B4 C4 D4 E4 F4 G4 H4 ... Z4
// A5 B5 C5 D5 E5 F5 G5 H5 ... Z5
// A6 B6 C6 D6 E6 F6 G6 H6 ... Z6
// A7 B7 C7 D7 E7 F7 G7 H7 ... Z7
// A8 ...
// ...
//
// *) A, B, C, ... - patches extracted from the original input.
// *) A0, A1, A2 ... - values from the same patch at different offsets.
//
// The traversal (packed rhs memory) order (B0 beside A0 in memory):
// A0 B0 C0 D0 A1 B1 C1 D1 ...
// E0 F0 G0 H0 E1 F1 G1 H1 ...
// ...
// Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4)
//
// This traversal order must be the same as in default gemm_pack_rhs defined in
// GeneralBlockPanelKernel.h.
//
// *) nr - number of registers along the 'n' dimension.
//    See GeneralBlockPanelKernel.h and the "Anatomy of High-Performance Matrix
//    Multiplication" paper.
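//
// As a hedged illustration of the packing step implemented below (names follow
// the code; the 4x4 shape assumes packet_size == 4 for the example, and nr is
// asserted to be 4): one iteration of the inner depth loop loads one packet
// per column,
//
//   kernel.packet[0] = [A0 A1 A2 A3]   (column j2 + 0)
//   kernel.packet[1] = [B0 B1 B2 B3]   (column j2 + 1)
//   kernel.packet[2] = [C0 C1 C2 C3]   (column j2 + 2)
//   kernel.packet[3] = [D0 D1 D2 D3]   (column j2 + 3)
//
// and ptranspose(kernel) turns this block into
//
//   [A0 B0 C0 D0], [A1 B1 C1 D1], [A2 B2 C2 D2], [A3 B3 C3 D3],
//
// which is exactly the "A0 B0 C0 D0 A1 B1 C1 D1 ..." traversal order shown
// above once the four packets are stored back to back into `block`.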
1039 template <typename NewDimension, Index Rows, Index Cols, typename ArgType, 1040 typename Device, typename Scalar, typename Index, 1041 typename nocontract_t, typename contract_t, int packet_size, 1042 bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, 1043 int nr> 1044 struct gemm_pack_rhs< 1045 Scalar, Index, 1046 TensorContractionSubMapper< 1047 Scalar, Index, Rhs, 1048 TensorEvaluator< 1049 const TensorReshapingOp< 1050 NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >, 1051 Device>, 1052 nocontract_t, contract_t, packet_size, inner_dim_contiguous, 1053 inner_dim_reordered, Alignment>, 1054 nr, ColMajor, false, false> { 1055 typedef TensorContractionSubMapper< 1056 Scalar, Index, Rhs, 1057 TensorEvaluator< 1058 const TensorReshapingOp< 1059 NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >, 1060 Device>, 1061 nocontract_t, contract_t, packet_size, inner_dim_contiguous, 1062 inner_dim_reordered, Alignment> 1063 SubMapper; 1064 typedef SubMapper DataMapper; 1065 typedef typename packet_traits<Scalar>::type Packet; 1066 1067 EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE) 1068 1069 EIGEN_DEVICE_FUNC 1070 EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs, 1071 Index depth, Index cols, Index stride = 0, 1072 Index offset = 0) const { 1073 eigen_assert(stride == 0); 1074 eigen_assert(offset == 0); 1075 1076 const Index packet_cols4 = (cols / 4) * 4; 1077 const Index peeled_k = (depth / packet_size) * packet_size; 1078 const bool non_standard_patches = rhs.nonStandardPatches(); 1079 1080 for (Index j2 = 0; j2 < packet_cols4; j2 += 4) { 1081 const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0); 1082 const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1); 1083 const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2); 1084 const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3); 1085 1086 Index k = 0; 1087 if ((packet_size % 4) == 0 && !non_standard_patches) { 1088 // FAST PATH: 1089 // Iterate over patch columns and rows, if we know that a single 1090 // packet do not span across multiple rows or columns. 1091 if ((rhs.patchDepth() % packet_size) == 0) { 1092 const Index start_col = rhs.colOffset(); 1093 const Index max_col = rhs.maxCol(peeled_k); 1094 1095 for (Index c = start_col; c < max_col; ++c) { 1096 eigen_assert(k <= peeled_k); 1097 1098 const Index start_row = (c == start_col) ? rhs.rowOffset() : 0; 1099 const Index max_row = rhs.maxRow(peeled_k, c); 1100 1101 const bool pad_col0 = dm0.padCol(c); 1102 const bool pad_col1 = dm1.padCol(c); 1103 const bool pad_col2 = dm2.padCol(c); 1104 const bool pad_col3 = dm3.padCol(c); 1105 1106 // Check if we can squeeze reads along the `row` and `depth` 1107 // dimensions (two innermost dimensions). 1108 if (!pad_col0 && !pad_col1 && !pad_col2 && !pad_col3 && // 1109 !dm0.padRow(start_row) && !dm0.padRow(max_row - 1) && // 1110 !dm1.padRow(start_row) && !dm1.padRow(max_row - 1) && // 1111 !dm2.padRow(start_row) && !dm2.padRow(max_row - 1) && // 1112 !dm3.padRow(start_row) && !dm3.padRow(max_row - 1)) { 1113 // Compute how many elements we can squeeze read. 1114 const Index start_depth = 1115 (c == start_col) ? rhs.depthOffset() : 0; 1116 1117 // Upper bound for the number of elements in the depth dimension 1118 // that we can squeeze read. 1119 const Index squeeze_length = 1120 (max_row - start_row) * rhs.patchDepth() - start_depth; 1121 1122 // Do not overshoot beyond the block size. 
1123 const Index max_depth = 1124 start_depth + std::min<Index>(peeled_k - k, squeeze_length); 1125 eigen_assert((max_depth - start_depth) % packet_size == 0); 1126 1127 const Index idx0 = dm0.baseIndex(start_row, c); 1128 const Index idx1 = dm1.baseIndex(start_row, c); 1129 const Index idx2 = dm2.baseIndex(start_row, c); 1130 const Index idx3 = dm3.baseIndex(start_row, c); 1131 1132 for (Index d = start_depth; d < max_depth; d += packet_size) { 1133 eigen_assert(k < peeled_k); 1134 PacketBlock<Packet, 4> kernel; 1135 kernel.packet[0] = rhs.packetNoPadding(d, idx0); 1136 kernel.packet[1] = rhs.packetNoPadding(d, idx1); 1137 kernel.packet[2] = rhs.packetNoPadding(d, idx2); 1138 kernel.packet[3] = rhs.packetNoPadding(d, idx3); 1139 ptranspose(kernel); 1140 pstoreu(block + 0 * packet_size, kernel.packet[0]); 1141 pstoreu(block + 1 * packet_size, kernel.packet[1]); 1142 pstoreu(block + 2 * packet_size, kernel.packet[2]); 1143 pstoreu(block + 3 * packet_size, kernel.packet[3]); 1144 block += 4 * packet_size; 1145 k += packet_size; 1146 } 1147 1148 // Go to the next column. 1149 continue; 1150 } 1151 1152 // If we can't squeeze reads, process rows one by one. 1153 for (Index r = start_row; r < max_row; ++r) { 1154 eigen_assert(k <= peeled_k); 1155 1156 const bool pad0 = pad_col0 || dm0.padRow(r); 1157 const bool pad1 = pad_col1 || dm1.padRow(r); 1158 const bool pad2 = pad_col2 || dm2.padRow(r); 1159 const bool pad3 = pad_col3 || dm3.padRow(r); 1160 1161 const Index idx0 = dm0.baseIndex(r, c); 1162 const Index idx1 = dm1.baseIndex(r, c); 1163 const Index idx2 = dm2.baseIndex(r, c); 1164 const Index idx3 = dm3.baseIndex(r, c); 1165 1166 const Index start_depth = ((c == start_col) && (r == start_row)) 1167 ? rhs.depthOffset() 1168 : 0; 1169 const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth); 1170 eigen_assert((max_depth - start_depth) % packet_size == 0); 1171 1172 for (Index d = start_depth; d < max_depth; d += packet_size) { 1173 eigen_assert(k < peeled_k); 1174 PacketBlock<Packet, 4> kernel; 1175 kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0)) 1176 : rhs.packetNoPadding(d, idx0); 1177 kernel.packet[1] = pad1 ? pset1<Packet>(Scalar(0)) 1178 : rhs.packetNoPadding(d, idx1); 1179 kernel.packet[2] = pad2 ? pset1<Packet>(Scalar(0)) 1180 : rhs.packetNoPadding(d, idx2); 1181 kernel.packet[3] = pad3 ? pset1<Packet>(Scalar(0)) 1182 : rhs.packetNoPadding(d, idx3); 1183 ptranspose(kernel); 1184 pstoreu(block + 0 * packet_size, kernel.packet[0]); 1185 pstoreu(block + 1 * packet_size, kernel.packet[1]); 1186 pstoreu(block + 2 * packet_size, kernel.packet[2]); 1187 pstoreu(block + 3 * packet_size, kernel.packet[3]); 1188 block += 4 * packet_size; 1189 k += packet_size; 1190 } 1191 } 1192 } 1193 1194 // The loop above should fill peeled_k elements. 1195 eigen_assert(peeled_k == k); 1196 1197 } else { 1198 for (; k < peeled_k; k += packet_size) { 1199 PacketBlock<Packet, 4> kernel; 1200 kernel.packet[0] = dm0.loadPacketStandard(k); 1201 kernel.packet[1] = dm1.loadPacketStandard(k); 1202 kernel.packet[2] = dm2.loadPacketStandard(k); 1203 kernel.packet[3] = dm3.loadPacketStandard(k); 1204 ptranspose(kernel); 1205 pstoreu(block + 0 * packet_size, kernel.packet[0]); 1206 pstoreu(block + 1 * packet_size, kernel.packet[1]); 1207 pstoreu(block + 2 * packet_size, kernel.packet[2]); 1208 pstoreu(block + 3 * packet_size, kernel.packet[3]); 1209 block += 4 * packet_size; 1210 } 1211 } 1212 } 1213 1214 // Copy the remaining coefficients of the column block after the peeled_k. 
1215 if (!rhs.nonStandardPatches()) { 1216 for (; k < depth; k++) { 1217 block[0] = dm0.loadCoeffStandard(k); 1218 block[1] = dm1.loadCoeffStandard(k); 1219 block[2] = dm2.loadCoeffStandard(k); 1220 block[3] = dm3.loadCoeffStandard(k); 1221 block += 4; 1222 } 1223 } else { 1224 for (; k < depth; k++) { 1225 block[0] = dm0(k); 1226 block[1] = dm1(k); 1227 block[2] = dm2(k); 1228 block[3] = dm3(k); 1229 block += 4; 1230 } 1231 } 1232 } 1233 1234 // copy the remaining columns one at a time (nr==1) 1235 for (Index j2 = packet_cols4; j2 < cols; ++j2) { 1236 const SubMapper dm0 = rhs.getLinearMapper(0, j2); 1237 for (Index k = 0; k < depth; k++) { 1238 *block = dm0(k); 1239 block += 1; 1240 } 1241 } 1242 } 1243 }; 1244 1245 // Template specialization for packet_size = 2. We must special-case packet 1246 // blocks with nr > packet_size, e.g. PacketBlock<Packet2d, 4>. 1247 template <typename NewDimension, Index Rows, Index Cols, typename ArgType, 1248 typename Device, typename Scalar, typename Index, 1249 typename nocontract_t, typename contract_t, bool inner_dim_contiguous, 1250 bool inner_dim_reordered, int Alignment, int nr> 1251 struct gemm_pack_rhs< 1252 Scalar, Index, 1253 TensorContractionSubMapper< 1254 Scalar, Index, Rhs, 1255 TensorEvaluator< 1256 const TensorReshapingOp< 1257 NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >, 1258 Device>, 1259 nocontract_t, contract_t, 2, inner_dim_contiguous, inner_dim_reordered, 1260 Alignment>, 1261 nr, ColMajor, false, false> { 1262 typedef TensorContractionSubMapper< 1263 Scalar, Index, Rhs, 1264 TensorEvaluator< 1265 const TensorReshapingOp< 1266 NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >, 1267 Device>, 1268 nocontract_t, contract_t, 2, inner_dim_contiguous, inner_dim_reordered, 1269 Alignment> 1270 SubMapper; 1271 typedef SubMapper DataMapper; 1272 typedef typename packet_traits<Scalar>::type Packet; 1273 1274 EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE) 1275 1276 EIGEN_DEVICE_FUNC 1277 EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs, 1278 Index depth, Index cols, Index stride = 0, 1279 Index offset = 0) const { 1280 eigen_assert(stride == 0); 1281 eigen_assert(offset == 0); 1282 1283 const int packet_size = 2; 1284 const Index packet_cols4 = (cols / 4) * 4; 1285 const Index peeled_k = (depth / packet_size) * packet_size; 1286 const bool non_standard_patches = rhs.nonStandardPatches(); 1287 1288 for (Index j2 = 0; j2 < packet_cols4; j2 += 4) { 1289 const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0); 1290 const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1); 1291 const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2); 1292 const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3); 1293 1294 Index k = 0; 1295 if (!non_standard_patches) { 1296 // FAST PATH: 1297 // Iterate over patch columns and rows if we know that a single 1298 // packet do not span across multiple rows or columns. 1299 if ((rhs.patchDepth() % packet_size) == 0) { 1300 const Index start_col = rhs.colOffset(); 1301 const Index max_col = rhs.maxCol(peeled_k); 1302 1303 for (Index c = start_col; c < max_col; ++c) { 1304 eigen_assert(k <= peeled_k); 1305 1306 const Index start_row = (c == start_col) ? 
rhs.rowOffset() : 0; 1307 const Index max_row = rhs.maxRow(peeled_k, c); 1308 1309 const bool pad_col0 = dm0.padCol(c); 1310 const bool pad_col1 = dm1.padCol(c); 1311 const bool pad_col2 = dm2.padCol(c); 1312 const bool pad_col3 = dm3.padCol(c); 1313 1314 // We can squeeze reads along the `row` and `depth` dimensions if 1315 // the row stride is `1`, which means that `row` and `depth` 1316 // dimensions are contiguous (two innermost dimensions). 1317 if (rhs.rowStride() == 1 && // 1318 !pad_col0 && !pad_col1 && !pad_col2 && !pad_col3 && // 1319 !dm0.padRow(start_row) && !dm0.padRow(max_row - 1) && // 1320 !dm1.padRow(start_row) && !dm1.padRow(max_row - 1) && // 1321 !dm2.padRow(start_row) && !dm2.padRow(max_row - 1) && // 1322 !dm3.padRow(start_row) && !dm3.padRow(max_row - 1)) { 1323 // Compute how many elements we can squeeze read. 1324 const Index start_depth = 1325 (c == start_col) ? rhs.depthOffset() : 0; 1326 1327 // Upper bound for the number of elements in the depth dimension 1328 // that we can squeeze read. 1329 const Index squeeze_length = 1330 (max_row - start_row) * rhs.patchDepth() - start_depth; 1331 1332 // Do not overshoot beyond the block size. 1333 const Index max_depth = 1334 start_depth + std::min<Index>(peeled_k - k, squeeze_length); 1335 eigen_assert((max_depth - start_depth) % packet_size == 0); 1336 1337 const Index idx0 = dm0.baseIndex(start_row, c); 1338 const Index idx1 = dm1.baseIndex(start_row, c); 1339 const Index idx2 = dm2.baseIndex(start_row, c); 1340 const Index idx3 = dm3.baseIndex(start_row, c); 1341 1342 for (Index d = start_depth; d < max_depth; d += packet_size) { 1343 PacketBlock<Packet, 2> kernel0; 1344 PacketBlock<Packet, 2> kernel1; 1345 kernel0.packet[0] = rhs.packetNoPadding(d, idx0); 1346 kernel0.packet[1] = rhs.packetNoPadding(d, idx1); 1347 kernel1.packet[0] = rhs.packetNoPadding(d, idx2); 1348 kernel1.packet[1] = rhs.packetNoPadding(d, idx3); 1349 ptranspose(kernel0); 1350 ptranspose(kernel1); 1351 pstoreu(block + 0 * packet_size, kernel0.packet[0]); 1352 pstoreu(block + 1 * packet_size, kernel1.packet[0]); 1353 pstoreu(block + 2 * packet_size, kernel0.packet[1]); 1354 pstoreu(block + 3 * packet_size, kernel1.packet[1]); 1355 block += 4 * packet_size; 1356 k += packet_size; 1357 } 1358 1359 // Go to the next column. 1360 continue; 1361 } 1362 1363 // If we can't squeeze reads, process rows one by one. 1364 for (Index r = start_row; r < max_row; ++r) { 1365 eigen_assert(k <= peeled_k); 1366 1367 const bool pad0 = pad_col0 || dm0.padRow(r); 1368 const bool pad1 = pad_col1 || dm1.padRow(r); 1369 const bool pad2 = pad_col2 || dm2.padRow(r); 1370 const bool pad3 = pad_col3 || dm3.padRow(r); 1371 1372 const Index idx0 = dm0.baseIndex(r, c); 1373 const Index idx1 = dm1.baseIndex(r, c); 1374 const Index idx2 = dm2.baseIndex(r, c); 1375 const Index idx3 = dm3.baseIndex(r, c); 1376 1377 const Index start_depth = ((c == start_col) && (r == start_row)) 1378 ? rhs.depthOffset() 1379 : 0; 1380 const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth); 1381 eigen_assert((max_depth - start_depth) % packet_size == 0); 1382 1383 for (Index d = start_depth; d < max_depth; d += packet_size) { 1384 eigen_assert(k < peeled_k); 1385 PacketBlock<Packet, 2> kernel0; 1386 PacketBlock<Packet, 2> kernel1; 1387 kernel0.packet[0] = pad0 ? pset1<Packet>(Scalar(0)) 1388 : rhs.packetNoPadding(d, idx0); 1389 kernel0.packet[1] = pad1 ? pset1<Packet>(Scalar(0)) 1390 : rhs.packetNoPadding(d, idx1); 1391 kernel1.packet[0] = pad2 ? 
              for (Index d = start_depth; d < max_depth; d += packet_size) {
                eigen_assert(k < peeled_k);
                PacketBlock<Packet, 2> kernel0;
                PacketBlock<Packet, 2> kernel1;
                kernel0.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
                                         : rhs.packetNoPadding(d, idx0);
                kernel0.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
                                         : rhs.packetNoPadding(d, idx1);
                kernel1.packet[0] = pad2 ? pset1<Packet>(Scalar(0))
                                         : rhs.packetNoPadding(d, idx2);
                kernel1.packet[1] = pad3 ? pset1<Packet>(Scalar(0))
                                         : rhs.packetNoPadding(d, idx3);
                ptranspose(kernel0);
                ptranspose(kernel1);
                pstoreu(block + 0 * packet_size, kernel0.packet[0]);
                pstoreu(block + 1 * packet_size, kernel1.packet[0]);
                pstoreu(block + 2 * packet_size, kernel0.packet[1]);
                pstoreu(block + 3 * packet_size, kernel1.packet[1]);
                block += 4 * packet_size;
                k += packet_size;
              }
            }
          }

          // The loop above should fill peeled_k elements.
          eigen_assert(peeled_k == k);

        } else {
          // Packet can span multiple rows or columns, so we have to go
          // through the slower "standard" path.
          for (; k < peeled_k; k += packet_size) {
            PacketBlock<Packet, 2> kernel0;
            PacketBlock<Packet, 2> kernel1;
            kernel0.packet[0] = dm0.loadPacketStandard(k);
            kernel0.packet[1] = dm1.loadPacketStandard(k);
            kernel1.packet[0] = dm2.loadPacketStandard(k);
            kernel1.packet[1] = dm3.loadPacketStandard(k);
            ptranspose(kernel0);
            ptranspose(kernel1);
            pstoreu(block + 0 * packet_size, kernel0.packet[0]);
            pstoreu(block + 1 * packet_size, kernel1.packet[0]);
            pstoreu(block + 2 * packet_size, kernel0.packet[1]);
            pstoreu(block + 3 * packet_size, kernel1.packet[1]);
            block += 4 * packet_size;
          }
        }
      }

      // Copy the remaining coefficients of the column block after the
      // peeled_k.
      if (!non_standard_patches) {
        for (; k < depth; k++) {
          block[0] = dm0.loadCoeffStandard(k);
          block[1] = dm1.loadCoeffStandard(k);
          block[2] = dm2.loadCoeffStandard(k);
          block[3] = dm3.loadCoeffStandard(k);
          block += 4;
        }
      } else {
        for (; k < depth; k++) {
          block[0] = dm0(k);
          block[1] = dm1(k);
          block[2] = dm2(k);
          block[3] = dm3(k);
          block += 4;
        }
      }
    }

    // Copy the remaining columns one at a time (nr==1).
    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2);
      for (Index k = 0; k < depth; k++) {
        *block = dm0(k);
        block += 1;
      }
    }
  }
};

// Special case for non-vectorized types such as float16.
template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
          typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, bool inner_dim_contiguous,
          bool inner_dim_reordered, int Alignment, int nr>
struct gemm_pack_rhs<
    Scalar, Index,
    TensorContractionSubMapper<
        Scalar, Index, Rhs,
        TensorEvaluator<
            const TensorReshapingOp<
                NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
            Device>,
        nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered,
        Alignment>,
    nr, ColMajor, false, false> {
  typedef TensorContractionSubMapper<
      Scalar, Index, Rhs,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered,
      Alignment>
      SubMapper;
  typedef SubMapper DataMapper;

  EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE)

  EIGEN_DEVICE_FUNC
  EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
                                    Index depth, Index cols, Index stride = 0,
                                    Index offset = 0) const {
    eigen_assert(stride == 0);
    eigen_assert(offset == 0);

    const Index packet_cols4 = (cols / 4) * 4;

    for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
      const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
      const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
      const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);

      if (!rhs.nonStandardPatches()) {
        for (Index k = 0; k < depth; k++) {
          block[0] = dm0.loadCoeffStandard(k);
          block[1] = dm1.loadCoeffStandard(k);
          block[2] = dm2.loadCoeffStandard(k);
          block[3] = dm3.loadCoeffStandard(k);
          block += 4;
        }
      } else {
        for (Index k = 0; k < depth; k++) {
          block[0] = dm0(k);
          block[1] = dm1(k);
          block[2] = dm2(k);
          block[3] = dm3(k);
          block += 4;
        }
      }
    }

    // Copy the remaining columns one at a time (nr==1).
    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2);
      for (Index k = 0; k < depth; k++) {
        *block = dm0(k);
        block += 1;
      }
    }
  }
};
}  // end namespace internal

/** SpatialConvolution
 * \ingroup CXX11_NeuralNetworks_Module
 *
 * \brief Applies a 2D convolution over a multichannel input image.
 *
 * The input parameter is expected to be a tensor with a rank of 3 or more
 * (channels, height, width, and optionally others).
 * The kernel parameter is expected to be a 4D tensor (filters, channels,
 * kernel_height, kernel_width).
 * The input and the kernel must both be in col-major layout. The result will
 * also be in col-major layout.
 *
 * If col_in_stride, row_in_stride > 1, then applies convolution with holes
 * (aka atrous convolution), sampling every col_in_stride, row_in_stride input
 * pixels.
 *
 * If padding_top, padding_bottom, padding_left, or padding_right is specified,
 * then those paddings will be used to pad the input, and padding_type must be
 * PADDING_VALID.
 *
 * The result can be assigned to a tensor of rank equal to the rank of the
 * input. The dimensions of the result will be filters, height, width (and
 * others if applicable).
 *
 * It is possible to swap the order of the width and height dimensions,
 * provided that the same order is used in the input, the kernel, and the
 * output.
 *
 * It is also possible to add an output kernel to the contraction; the output
 * kernel is called by Eigen when it "finalizes" a block of the output tensor.
 *
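 * An illustrative usage sketch (the tensor sizes and variable names below are
 * examples only, not part of this header), using the default ColMajor layout,
 * strides of 1 and SAME padding:
 *
 *   Eigen::Tensor<float, 4> input(3, 32, 32, 8);    // (channels, H, W, batch)
 *   Eigen::Tensor<float, 4> kernel(16, 3, 5, 5);    // (filters, channels, kH, kW)
 *   Eigen::Tensor<float, 4> output(16, 32, 32, 8);  // (filters, H, W, batch)
 *   output = Eigen::SpatialConvolution(input, kernel, 1, 1,
 *                                      Eigen::PADDING_SAME);
 *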
 */
template <typename Input, typename Kernel,
          typename OutputKernel = const NoOpOutputKernel>
EIGEN_DEVICE_FUNC
    EIGEN_ALWAYS_INLINE static const typename internal::conditional<
        internal::traits<Input>::Layout == ColMajor,
        TensorReshapingOp<
            const DSizes<typename internal::traits<Input>::Index,
                         internal::traits<Input>::NumDimensions>,
            const TensorContractionOp<
                const array<IndexPair<typename internal::traits<Input>::Index>,
                            1>,
                const TensorReshapingOp<
                    const DSizes<typename internal::traits<Input>::Index, 2>,
                    const Kernel>,
                const TensorReshapingOp<
                    const DSizes<typename internal::traits<Input>::Index, 2>,
                    const TensorImagePatchOp<Dynamic, Dynamic, const Input> >,
                const OutputKernel> >,
        TensorReshapingOp<
            const DSizes<typename internal::traits<Input>::Index,
                         internal::traits<Input>::NumDimensions>,
            const TensorContractionOp<
                const array<IndexPair<typename internal::traits<Input>::Index>,
                            1>,
                const TensorReshapingOp<
                    const DSizes<typename internal::traits<Input>::Index, 2>,
                    const TensorImagePatchOp<Dynamic, Dynamic, const Input> >,
                const TensorReshapingOp<
                    const DSizes<typename internal::traits<Input>::Index, 2>,
                    const Kernel>,
                const OutputKernel> > >::type
    SpatialConvolution(const Input& input, const Kernel& kernel,
                       const Index row_stride = 1, const Index col_stride = 1,
                       const PaddingType padding_type = PADDING_SAME,
                       const Index row_in_stride = 1,
                       const Index col_in_stride = 1,
                       const OutputKernel& output_kernel = OutputKernel(),
                       Index padding_top = 0, Index padding_bottom = 0,
                       Index padding_left = 0, Index padding_right = 0) {
  typedef typename internal::traits<Input>::Index TensorIndex;
  TensorRef<Tensor<typename internal::traits<Input>::Scalar,
                   internal::traits<Input>::NumDimensions,
                   internal::traits<Input>::Layout, TensorIndex> >
      in(input);
  TensorRef<Tensor<typename internal::traits<Kernel>::Scalar,
                   internal::traits<Kernel>::NumDimensions,
                   internal::traits<Kernel>::Layout, TensorIndex> >
      kern(kernel);

  EIGEN_STATIC_ASSERT(
      internal::traits<Input>::Layout == internal::traits<Kernel>::Layout,
      YOU_MADE_A_PROGRAMMING_MISTAKE)
  const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);

  const int NumDims = internal::traits<Input>::NumDimensions;

  // Number of filters to apply. This is the same as the output depth of the
  // result.
  const TensorIndex kernelFilters =
      isColMajor ? kern.dimensions()[0] : kern.dimensions()[3];
  // Number of channels. This is the same as the input depth.
  const TensorIndex kernelChannels =
      isColMajor ? kern.dimensions()[1] : kern.dimensions()[2];
  const TensorIndex kernelRows =
      isColMajor ? kern.dimensions()[2] : kern.dimensions()[1];
  const TensorIndex kernelCols =
      isColMajor ? kern.dimensions()[3] : kern.dimensions()[0];

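  // Added illustration (example values, not from the original source): with
  // in-strides > 1 the kernel is effectively dilated (atrous convolution).
  // For kernelRows == 3 and row_in_stride == 2 the formula below gives
  //   kernelRowsEff = 3 + (3 - 1) * (2 - 1) = 5,
  // i.e. the three kernel taps span five input rows.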
  const Index kernelRowsEff =
      kernelRows + (kernelRows - 1) * (row_in_stride - 1);
  const Index kernelColsEff =
      kernelCols + (kernelCols - 1) * (col_in_stride - 1);

  array<IndexPair<TensorIndex>, 1> contract_dims;
  contract_dims[0] = IndexPair<TensorIndex>(1, 0);

  const TensorIndex InputRows =
      isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
  const TensorIndex InputCols =
      isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);
  const bool padding_explicit =
      (padding_top || padding_bottom || padding_left || padding_right);

  TensorIndex out_height;
  TensorIndex out_width;
  switch (padding_type) {
    case PADDING_VALID: {
      const TensorIndex InputRowsEff = InputRows + padding_top + padding_bottom;
      const TensorIndex InputColsEff = InputCols + padding_left + padding_right;
      out_height = divup(InputRowsEff - kernelRowsEff + 1, row_stride);
      out_width = divup(InputColsEff - kernelColsEff + 1, col_stride);
      break;
    }
    case PADDING_SAME: {
      eigen_assert(!padding_explicit);
      out_height = divup(InputRows, row_stride);
      out_width = divup(InputCols, col_stride);
      break;
    }
    default: {
      // Initialize unused variables to avoid a compiler warning.
      out_height = 0;
      out_width = 0;
      eigen_assert(false && "unexpected padding");
    }
  }

  // Molds the output of the patch extraction code into a 2D tensor:
  // - the first dimension (dims[0]): the patch values to be multiplied with
  //   the kernels
  // - the second dimension (dims[1]): everything else
  DSizes<TensorIndex, 2> pre_contract_dims;
  if (isColMajor) {
    pre_contract_dims[0] = kernelChannels * kernelRows * kernelCols;
    pre_contract_dims[1] = out_height * out_width;
    for (int i = 3; i < NumDims; ++i) {
      pre_contract_dims[1] *= in.dimension(i);
    }
  } else {
    pre_contract_dims[1] = kernelChannels * kernelRows * kernelCols;
    pre_contract_dims[0] = out_height * out_width;
    for (int i = 0; i < NumDims - 3; ++i) {
      pre_contract_dims[0] *= in.dimension(i);
    }
  }

  // Molds the output of the contraction into the shape expected by the user
  // (assuming this is ColMajor):
  // - 1st dim: kernel filters
  // - 2nd dim: output height
  // - 3rd dim: output width
  // - 4th dim and beyond: everything else including batch size
  DSizes<TensorIndex, NumDims> post_contract_dims;
  if (isColMajor) {
    post_contract_dims[0] = kernelFilters;
    post_contract_dims[1] = out_height;
    post_contract_dims[2] = out_width;
    for (int i = 3; i < NumDims; ++i) {
      post_contract_dims[i] = in.dimension(i);
    }
  } else {
    post_contract_dims[NumDims - 1] = kernelFilters;
    post_contract_dims[NumDims - 2] = out_height;
    post_contract_dims[NumDims - 3] = out_width;
    for (int i = 0; i < NumDims - 3; ++i) {
      post_contract_dims[i] = in.dimension(i);
    }
  }

  DSizes<TensorIndex, 2> kernel_dims;
  if (isColMajor) {
    kernel_dims[0] = kernelFilters;
    kernel_dims[1] = kernelChannels * kernelRows * kernelCols;
  } else {
    kernel_dims[0] = kernelChannels * kernelRows * kernelCols;
    kernel_dims[1] = kernelFilters;
  }
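
  // Worked example (illustrative, ColMajor; sizes are assumptions): for an
  // input of (channels=3, height=32, width=32, batch=8), a kernel of
  // (filters=16, channels=3, 5, 5), stride 1 and PADDING_SAME:
  //   pre_contract_dims  = {3 * 5 * 5, 32 * 32 * 8} = {75, 8192}
  //   kernel_dims        = {16, 75}
  //   post_contract_dims = {16, 32, 32, 8}
  // so the contraction below is effectively a (16 x 75) * (75 x 8192) matrix
  // product whose result is reshaped back to the output tensor.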
  if (padding_explicit) {
    return choose(
        Cond<internal::traits<Input>::Layout == ColMajor>(),
        kernel.reshape(kernel_dims)
            .contract(input
                          .extract_image_patches(
                              kernelRows, kernelCols, row_stride, col_stride,
                              row_in_stride, col_in_stride,
                              /*row_inflate_stride=*/1,
                              /*col_inflate_stride=*/1, padding_top,
                              padding_bottom, padding_left, padding_right,
                              /*padding_value=*/0)
                          .reshape(pre_contract_dims),
                      contract_dims, output_kernel)
            .reshape(post_contract_dims),
        input
            .extract_image_patches(kernelRows, kernelCols, row_stride,
                                   col_stride, row_in_stride, col_in_stride,
                                   /*row_inflate_stride=*/1,
                                   /*col_inflate_stride=*/1, padding_top,
                                   padding_bottom, padding_left, padding_right,
                                   /*padding_value=*/0)
            .reshape(pre_contract_dims)
            .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel)
            .reshape(post_contract_dims));
  } else {
    return choose(
        Cond<internal::traits<Input>::Layout == ColMajor>(),
        kernel.reshape(kernel_dims)
            .contract(input
                          .extract_image_patches(
                              kernelRows, kernelCols, row_stride, col_stride,
                              row_in_stride, col_in_stride, padding_type)
                          .reshape(pre_contract_dims),
                      contract_dims, output_kernel)
            .reshape(post_contract_dims),
        input
            .extract_image_patches(kernelRows, kernelCols, row_stride,
                                   col_stride, row_in_stride, col_in_stride,
                                   padding_type)
            .reshape(pre_contract_dims)
            .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel)
            .reshape(post_contract_dims));
  }
}

}  // end namespace Eigen

#endif  // TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_INL_H_