/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_H_
#define TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

namespace Eigen {

namespace internal {

// TODO: Consolidate this part of the code with the image patch extraction code
// since they are both very similar.
template <typename NewDimension, DenseIndex Rows, DenseIndex Cols,
          typename ArgType, typename Device, typename Scalar_, typename Index,
          typename nocontract_t, typename contract_t, int Side, int packet_size,
          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionInputMapper<
    Scalar_, Index, Side,
    TensorEvaluator<
        const TensorReshapingOp<NewDimension,
                                const TensorImagePatchOp<Rows, Cols, ArgType> >,
        Device>,
    nocontract_t, contract_t, packet_size, inner_dim_contiguous,
    inner_dim_reordered, Alignment> {
 public:
  typedef Scalar_ Scalar;
  typedef TensorContractionInputMapper<
      Scalar, Index, Side,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      Self;
  typedef TensorContractionSubMapper<
      Scalar, Index, Side,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      SubMapper;
  typedef SubMapper VectorMapper;
  typedef SubMapper LinearMapper;
  typedef typename packet_traits<Scalar>::type Packet;

  EIGEN_DEVICE_FUNC
  TensorContractionInputMapper(
      const TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>& tensor,
      const nocontract_t&, const nocontract_t&, const contract_t&,
      const contract_t&)
      : m_impl(tensor.impl().impl()) {
    Index patch_rows;
    Index patch_depth;
    if (internal::traits<ArgType>::Layout == ColMajor) {
      patch_depth = tensor.impl().dimensions()[0];
      patch_rows = tensor.impl().dimensions()[1];
      m_patch_cols = tensor.impl().dimensions()[2];
      m_num_patches = tensor.impl().dimensions()[3];
    } else {
      const int NumDims = tensor.impl().dimensions().size();
      patch_depth = tensor.impl().dimensions()[NumDims - 1];
      patch_rows = tensor.impl().dimensions()[NumDims - 2];
      m_patch_cols = tensor.impl().dimensions()[NumDims - 3];
      m_num_patches = tensor.impl().dimensions()[NumDims - 4];
    }
    m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
    m_patch_col_inflate_strides = tensor.impl().colInflateStride();

    m_colStride = patch_rows;

    m_outputRows = tensor.impl().outputRows();
    m_row_strides = tensor.impl().userRowStride();
    m_col_strides = tensor.impl().userColStride();

    m_in_row_strides = tensor.impl().userInRowStride();
    m_in_col_strides = tensor.impl().userInColStride();

    if (internal::traits<ArgType>::Layout == ColMajor) {
      m_inputRows = tensor.impl().impl().dimensions()[1];
      m_inputCols = tensor.impl().impl().dimensions()[2];
    } else {
      const int NumDims = tensor.impl().impl().dimensions().size();
      m_inputRows = tensor.impl().impl().dimensions()[NumDims - 2];
      m_inputCols = tensor.impl().impl().dimensions()[NumDims - 3];
    }

    m_rowInputStride = patch_depth;
    m_colInputStride = patch_depth * m_inputRows;
    m_patchInputStride = patch_depth * m_inputRows * m_inputCols;

    m_rowPaddingTop = tensor.impl().rowPaddingTop();
    m_colPaddingLeft = tensor.impl().colPaddingLeft();

    m_fastInputRowStride =
        internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
    m_fastInputColStride =
        internal::TensorIntDivisor<Index>(m_patch_col_inflate_strides);
    m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches);
    m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);
    m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
    m_fastDimZero = internal::TensorIntDivisor<Index>(patch_depth);
  }

  EIGEN_DEVICE_FUNC
  TensorContractionInputMapper(const TensorContractionInputMapper& base_mapper)
      : m_impl(base_mapper.m_impl) {
    m_patch_cols = base_mapper.m_patch_cols;
    m_num_patches = base_mapper.m_num_patches;
    m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
    m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;

    m_colStride = base_mapper.m_colStride;

    m_rowInputStride = base_mapper.m_rowInputStride;
    m_colInputStride = base_mapper.m_colInputStride;
    m_patchInputStride = base_mapper.m_patchInputStride;

    m_inputRows = base_mapper.m_inputRows;
    m_inputCols = base_mapper.m_inputCols;

    m_outputRows = base_mapper.m_outputRows;
    m_row_strides = base_mapper.m_row_strides;
    m_col_strides = base_mapper.m_col_strides;

    m_in_row_strides = base_mapper.m_in_row_strides;
    m_in_col_strides = base_mapper.m_in_col_strides;

    m_rowPaddingTop = base_mapper.m_rowPaddingTop;
    m_colPaddingLeft = base_mapper.m_colPaddingLeft;

    m_fastInputRowStride = base_mapper.m_fastInputRowStride;
    m_fastInputColStride = base_mapper.m_fastInputColStride;
    m_fastNumPatches = base_mapper.m_fastNumPatches;
    m_fastColStride = base_mapper.m_fastColStride;
    m_fastOutputRows = base_mapper.m_fastOutputRows;
    m_fastDimZero = base_mapper.m_fastDimZero;
  }

  // If true, turns off some optimizations for loading packets since the image
  // patches are "non-standard", e.g. there are non-trivial strides or
  // inflations in the input.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
    return m_in_row_strides != 1 || m_in_col_strides != 1 ||
           m_patch_row_inflate_strides != 1 || m_patch_col_inflate_strides != 1;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
    return SubMapper(*this, i, j);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
    return LinearMapper(*this, i, j);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const {
    Index rowIndex, colIndex, otherIndex;
    computeBaseIndices(0, rowIndex, colIndex, otherIndex);
    return loadCoeff(row, rowIndex, colIndex, otherIndex);
  }

  // Load the coefficient at the patchIndex location instead of the usual
  // m_rowIndex, m_colIndex, m_otherIndex. This is currently only used by the
  // gpu code.
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const {
    Index rowIndex, colIndex, otherIndex;
    computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex);
    return loadCoeff(row, rowIndex, colIndex, otherIndex);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const {
    Index rowIndex, colIndex, otherIndex;
    computeBaseIndices(0, rowIndex, colIndex, otherIndex);
    return loadPacket(row, rowIndex, colIndex, otherIndex);
  }

  // Load the packet at the patchIndex location instead of the usual
  // m_rowIndex, m_colIndex, m_otherIndex. This is currently only used by the
  // gpu code.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const {
    Index rowIndex, colIndex, otherIndex;
    computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex);
    return loadPacket(row, rowIndex, colIndex, otherIndex);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE const TensorEvaluator<ArgType, Device>& impl() const {
    return m_impl;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_rowInputStride; }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchRows() const { return m_colStride; }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
                                             const Index baseIndex) const {
    const Index inputIndex = depth + baseIndex;
    return m_impl.template packet<Unaligned>(inputIndex);
  }

 private:
  friend class TensorContractionSubMapper<
      Scalar, Index, Side,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>;

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index rowIndex,
                                       Index colIndex, Index otherIndex) const {
    // Find the offset of the element wrt the location of the first element.
    const Index patchOffset = patchId / m_fastDimZero;

    const Index colOffset = patchOffset / m_fastColStride;
    const Index inputCol = colIndex + colOffset * m_in_col_strides;
    const Index origInputCol =
        (m_patch_col_inflate_strides == 1)
            ? inputCol
            : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
    const Index rowOffset = patchOffset - colOffset * m_colStride;
    const Index inputRow = rowIndex + rowOffset * m_in_row_strides;
    const Index origInputRow =
        (m_patch_row_inflate_strides == 1)
            ? inputRow
            : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
    if (origInputCol < 0 || origInputRow < 0 || origInputCol >= m_inputCols ||
        origInputRow >= m_inputRows ||
        (inputCol != origInputCol * m_patch_col_inflate_strides) ||
        (inputRow != origInputRow * m_patch_row_inflate_strides)) {
      return Scalar(0);
    }
    const Index depth = patchId - patchOffset * patchDepth();
    const Index inputIndex = depth + origInputRow * m_rowInputStride +
                             origInputCol * m_colInputStride + otherIndex;
    return m_impl.coeff(inputIndex);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index rowIndex,
                                               Index colIndex,
                                               Index otherIndex) const {
    eigen_assert(!nonStandardPatches());

    // Find the offset of the element wrt the location of the first element.
    const Index patchOffset = patchId / m_fastDimZero;

    const Index colOffset = patchOffset / m_fastColStride;
    const Index inputCol = colIndex + colOffset;
    const Index rowOffset = patchOffset - colOffset * m_colStride;
    const Index inputRow = rowIndex + rowOffset;
    if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 ||
        inputRow >= m_inputRows) {
      return Scalar(0);
    }
    const Index depth = patchId - patchOffset * patchDepth();
    const Index inputIndex = depth + inputRow * m_rowInputStride +
                             inputCol * m_colInputStride + otherIndex;
    return m_impl.coeff(inputIndex);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index rowIndex,
                                        Index colIndex,
                                        Index otherIndex) const {
    const Index packetSize = internal::unpacket_traits<Packet>::size;
    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);

    if (nonStandardPatches()) {
      return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
    }
    return loadPacketStandard(patchId, rowIndex, colIndex, otherIndex);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacketStandard(Index patchId, Index rowIndex,
                                                Index colIndex,
                                                Index otherIndex) const {
    const Index packetSize = internal::unpacket_traits<Packet>::size;
    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);

    eigen_assert(!nonStandardPatches());

    if ((patchDepth() % packetSize) == 0) {
      return loadPacketFast(patchId, rowIndex, colIndex, otherIndex);
    } else {
      const Index patchOffsets[2] = {
          patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero};

      const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
                                   patchOffsets[1] / m_fastColStride};

      const Index inputCols[2] = {colIndex + colOffsets[0],
                                  colIndex + colOffsets[1]};
      if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
        // all zeros
        return internal::pset1<Packet>(Scalar(0));
      }

      if (inputCols[0] == inputCols[1]) {
        const Index rowOffsets[2] = {
            patchOffsets[0] - colOffsets[0] * m_colStride,
            patchOffsets[1] - colOffsets[1] * m_colStride};
        eigen_assert(rowOffsets[0] <= rowOffsets[1]);
        const Index inputRows[2] = {rowIndex + rowOffsets[0],
                                    rowIndex + rowOffsets[1]};

        if (inputRows[0] >= m_inputRows || inputRows[1] < 0) {
          // all zeros
          return internal::pset1<Packet>(Scalar(0));
        }

        if (inputRows[0] >= 0 && inputRows[1] < m_inputRows) {
          // no padding
          const Index depth = patchId - patchOffsets[0] * patchDepth();
          const Index inputIndex = depth + inputRows[0] * m_rowInputStride +
                                   inputCols[0] * m_colInputStride + otherIndex;
          return m_impl.template packet<Unaligned>(inputIndex);
        }
      }
    }
    return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index rowIndex,
                                            Index colIndex,
                                            Index otherIndex) const {
    const Index packetSize = internal::unpacket_traits<Packet>::size;
    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);

    eigen_assert(!nonStandardPatches());
    eigen_assert((patchDepth() % packetSize) == 0);
    // Find the offset of the element wrt the location of the first element.
    const Index patchOffset = patchId / m_fastDimZero;
    eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);

    const Index colOffset = patchOffset / m_fastColStride;
    const Index inputCol = colIndex + colOffset;
    const Index rowOffset = patchOffset - colOffset * m_colStride;
    const Index inputRow = rowIndex + rowOffset;
    if (inputCol < 0 || inputRow < 0 || inputCol >= m_inputCols ||
        inputRow >= m_inputRows) {
      // all zeros
      return internal::pset1<Packet>(Scalar(0));
    }
    // no padding
    const Index depth = patchId - patchOffset * patchDepth();
    const Index inputIndex = depth + inputRow * m_rowInputStride +
                             inputCol * m_colInputStride + otherIndex;
    return m_impl.template packet<Unaligned>(inputIndex);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet packetWithPossibleZero(
      Index patchId, Index rowIndex, Index colIndex, Index otherIndex) const {
    const int packetSize = internal::unpacket_traits<Packet>::size;
    EIGEN_ALIGN_MAX
    typename internal::remove_const<Scalar>::type values[packetSize];
    for (int i = 0; i < packetSize; ++i) {
      values[i] = loadCoeff(patchId + i, rowIndex, colIndex, otherIndex);
    }
    Packet rslt = internal::pload<Packet>(values);
    return rslt;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices(
      Index patchIndex, Index& rowIndex, Index& colIndex,
      Index& otherIndex) const {
    const int NumInputDims = array_size<
        typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
    otherIndex = (NumInputDims == 3) ? 0 : patchIndex / m_fastNumPatches;
    const Index patch2DIndex = (NumInputDims == 3)
                                   ? patchIndex
                                   : (patchIndex - otherIndex * m_num_patches);
    otherIndex *= m_patchInputStride;
    colIndex = patch2DIndex / m_fastOutputRows;
    rowIndex = patch2DIndex - colIndex * m_outputRows;
    colIndex = colIndex * m_col_strides - m_colPaddingLeft;
    rowIndex = rowIndex * m_row_strides - m_rowPaddingTop;
  }

  Index m_patch_cols;                 // number of columns in the patch
  Index m_num_patches;                // number of patches to extract.
  Index m_patch_row_inflate_strides;  // the strides for row inflation in the
                                      // image patch
  Index m_patch_col_inflate_strides;  // the strides for col inflation in the
                                      // image patch
  // Fast representation of inflation strides.
  internal::TensorIntDivisor<Index> m_fastInputRowStride;
  internal::TensorIntDivisor<Index> m_fastInputColStride;

  Index m_otherStride;
  Index m_colStride;
  internal::TensorIntDivisor<Index> m_fastNumPatches;
  internal::TensorIntDivisor<Index> m_fastColStride;

  Index m_rowInputStride;    // row stride in the input tensor
  Index m_colInputStride;    // col stride in the input tensor
  Index m_patchInputStride;  // patch stride in the input tensor

  Index m_inputRows;  // Number of rows in the input tensor
  Index m_inputCols;  // Number of cols in the input tensor

  Index m_outputRows;  // Number of patch rows

  Index m_row_strides;  // User specified row stride
  Index m_col_strides;  // User specified col stride

  Index m_in_row_strides;  // User specified input row stride
  Index m_in_col_strides;  // User specified input col stride

  Index m_rowPaddingTop;   // Row padding
  Index m_colPaddingLeft;  // Column padding

  internal::TensorIntDivisor<Index> m_fastOutputRows;
  internal::TensorIntDivisor<Index> m_fastDimZero;

  const TensorEvaluator<ArgType, Device> m_impl;
};

template <typename NewDimension, DenseIndex Rows, DenseIndex Cols,
          typename ArgType, typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, int Side, int packet_size,
          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionSubMapper<
    Scalar, Index, Side,
    TensorEvaluator<
        const TensorReshapingOp<NewDimension,
                                const TensorImagePatchOp<Rows, Cols, ArgType> >,
        Device>,
    nocontract_t, contract_t, packet_size, inner_dim_contiguous,
    inner_dim_reordered, Alignment> {
 public:
  typedef typename packet_traits<Scalar>::type Packet;
  typedef typename packet_traits<Scalar>::half HalfPacket;

  typedef TensorContractionInputMapper<
      Scalar, Index, Side,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      ParentMapper;
  typedef TensorContractionSubMapper<
      Scalar, Index, Side,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      Self;
  typedef Self LinearMapper;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
      const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
      : m_base_mapper(base_mapper),
        m_depth_offset(vert_offset),
        m_col_offset(horiz_offset) {
    m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex,
                                     m_otherIndex);
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
      const Self& base_mapper, Index vert_offset, Index horiz_offset)
      : m_base_mapper(base_mapper.m_base_mapper),
        m_depth_offset(vert_offset + base_mapper.m_depth_offset),
        m_col_offset(horiz_offset + base_mapper.m_col_offset) {
    m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex,
                                     m_otherIndex);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
    return m_base_mapper.loadCoeff(i + m_depth_offset, m_rowIndex, m_colIndex,
                                   m_otherIndex);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i,
                                                          Index j) const {
    return m_base_mapper(i + m_depth_offset, j + m_col_offset);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
    return m_base_mapper.loadPacket(i + m_depth_offset, m_rowIndex, m_colIndex,
                                    m_otherIndex);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i,
                                                          Index j) const {
    return m_base_mapper.template loadPacket<Alignment>(i + m_depth_offset,
                                                        j + m_col_offset);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar
  loadCoeffStandard(Index i) const {
    return m_base_mapper.loadCoeffStandard(i + m_depth_offset, m_rowIndex,
                                           m_colIndex, m_otherIndex);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const {
    return m_base_mapper.loadPacketFast(i + m_depth_offset, m_rowIndex,
                                        m_colIndex, m_otherIndex);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet
  loadPacketStandard(Index i) const {
    return m_base_mapper.loadPacketStandard(i + m_depth_offset, m_rowIndex,
                                            m_colIndex, m_otherIndex);
  }
  template <typename Packet>
  EIGEN_DEVICE_FUNC bool aligned(Index) const {
    return false;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
    return m_base_mapper.nonStandardPatches();
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchDepth() const {
    return m_base_mapper.m_rowInputStride;
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchRows() const {
    return m_base_mapper.m_colStride;
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchCols() const {
    return m_base_mapper.m_patch_cols;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
                                             const Index baseIndex) const {
    const Index inputIndex = depth + baseIndex;
    return m_base_mapper.m_impl.template packet<Unaligned>(inputIndex);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool padRow(const Index row) const {
    const Index r = m_rowIndex + row;
    return r < 0 || r >= m_base_mapper.m_inputRows;
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool padCol(const Index col) const {
    const Index c = m_colIndex + col;
    return c < 0 || c >= m_base_mapper.m_inputCols;
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index baseIndex(const Index row, const Index col) const {
    const Index r = m_rowIndex + row;
    const Index c = m_colIndex + col;
    return r * m_base_mapper.m_rowInputStride +
           c * m_base_mapper.m_colInputStride + m_otherIndex;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index rowOffset() const {
    const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
    const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
    return patchOffset - colOffset * m_base_mapper.m_colStride;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index colOffset() const {
    const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
    const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
    return colOffset;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index depthOffset() const {
    const Index patchOffset = m_depth_offset % m_base_mapper.patchDepth();
    return patchOffset;
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper
  getLinearMapper(Index i, Index j) const {
    return LinearMapper(m_base_mapper, i + m_depth_offset, j + m_col_offset);
  }

 private:
  const ParentMapper& m_base_mapper;  // that was a reference before
  Index m_depth_offset;               // First row in the input matrix
  Index m_col_offset;                 // First col in the input matrix

  Index m_rowIndex;    // precomputed row index corresponding to the col offset
  Index m_colIndex;    // precomputed col index corresponding to the col offset
  Index m_otherIndex;  // precomputed other index corresponding to the col
                       // offset
};

template <typename NewDimension, DenseIndex Rows, DenseIndex Cols,
          typename ArgType, typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, int packet_size,
          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
          int nr>
struct gemm_pack_rhs<
    Scalar, Index,
    TensorContractionSubMapper<
        Scalar, Index, Rhs,
        TensorEvaluator<
            const TensorReshapingOp<
                NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
            Device>,
        nocontract_t, contract_t, packet_size, inner_dim_contiguous,
        inner_dim_reordered, Alignment>,
    nr, ColMajor, false, false> {
  typedef TensorContractionSubMapper<
      Scalar, Index, Rhs,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      SubMapper;
  typedef SubMapper DataMapper;

  EIGEN_DEVICE_FUNC
  static inline Index ceil_div(Index a, Index b) { return (a + b - 1) / b; }

  EIGEN_DEVICE_FUNC
  EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
                                    Index depth, Index cols, Index stride = 0,
                                    Index offset = 0) const {
    eigen_assert(stride == 0);
    eigen_assert(offset == 0);

    EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
    typedef typename packet_traits<Scalar>::type Packet;

    const Index packet_cols4 = (cols / 4) * 4;
    const Index peeled_k = (depth / packet_size) * packet_size;
    const bool non_standard_patches = rhs.nonStandardPatches();

    for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
      const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
      const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
      const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);

      Index k = 0;
      if ((packet_size % 4) == 0 && !non_standard_patches) {
        const Index patch_depth = rhs.patchDepth();
        if ((patch_depth % packet_size) == 0) {
          const Index patch_cols = rhs.patchCols();
          const Index patch_rows = rhs.patchRows();

          const Index startCol = rhs.colOffset();
          const Index max_cols = std::min<Index>(
              ceil_div(peeled_k, patch_rows * patch_depth) + startCol,
              patch_cols);

          for (Index c = startCol; c < max_cols; ++c) {
            eigen_assert(k < peeled_k);
            const Index startRow = (c == startCol) ? rhs.rowOffset() : 0;
            const Index max_rows = std::min<Index>(
                ceil_div(peeled_k - c * patch_rows * patch_depth, patch_depth) +
                    startRow,
                patch_rows);

            const bool pad_col0 = dm0.padCol(c);
            const bool pad_col1 = dm1.padCol(c);
            const bool pad_col2 = dm2.padCol(c);
            const bool pad_col3 = dm3.padCol(c);
            for (Index r = startRow; r < max_rows; ++r) {
              eigen_assert(k < peeled_k);
              const bool pad0 = pad_col0 || dm0.padRow(r);
              const bool pad1 = pad_col1 || dm1.padRow(r);
              const bool pad2 = pad_col2 || dm2.padRow(r);
              const bool pad3 = pad_col3 || dm3.padRow(r);

              const Index idx0 = dm0.baseIndex(r, c);
              const Index idx1 = dm1.baseIndex(r, c);
              const Index idx2 = dm2.baseIndex(r, c);
              const Index idx3 = dm3.baseIndex(r, c);

              const Index startDepth =
                  ((c == startCol) && (r == startRow)) ? rhs.depthOffset() : 0;
              const Index max_depth =
                  std::min<Index>(peeled_k - c * patch_rows * patch_depth -
                                      r * patch_depth + startDepth,
                                  patch_depth);
              eigen_assert((max_depth - startDepth) % packet_size == 0);
              for (Index d = startDepth; d < max_depth; d += packet_size) {
                eigen_assert(k < peeled_k);
                PacketBlock<Packet, 4> kernel;
                kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
                                        : rhs.packetNoPadding(d, idx0);
                kernel.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
                                        : rhs.packetNoPadding(d, idx1);
                kernel.packet[2] = pad2 ? pset1<Packet>(Scalar(0))
                                        : rhs.packetNoPadding(d, idx2);
                kernel.packet[3] = pad3 ? pset1<Packet>(Scalar(0))
                                        : rhs.packetNoPadding(d, idx3);
                ptranspose(kernel);
                pstoreu(block + 0 * packet_size, kernel.packet[0]);
                pstoreu(block + 1 * packet_size, kernel.packet[1]);
                pstoreu(block + 2 * packet_size, kernel.packet[2]);
                pstoreu(block + 3 * packet_size, kernel.packet[3]);
                block += 4 * packet_size;
                k += packet_size;
              }
            }
          }

          for (; k < peeled_k; k += packet_size) {
            PacketBlock<Packet, 4> kernel;
            kernel.packet[0] = dm0.loadPacketFast(k);
            kernel.packet[1] = dm1.loadPacketFast(k);
            kernel.packet[2] = dm2.loadPacketFast(k);
            kernel.packet[3] = dm3.loadPacketFast(k);
            ptranspose(kernel);
            pstoreu(block + 0 * packet_size, kernel.packet[0]);
            pstoreu(block + 1 * packet_size, kernel.packet[1]);
            pstoreu(block + 2 * packet_size, kernel.packet[2]);
            pstoreu(block + 3 * packet_size, kernel.packet[3]);
            block += 4 * packet_size;
          }
        } else {
          for (; k < peeled_k; k += packet_size) {
            PacketBlock<Packet, 4> kernel;
            kernel.packet[0] = dm0.loadPacketStandard(k);
            kernel.packet[1] = dm1.loadPacketStandard(k);
            kernel.packet[2] = dm2.loadPacketStandard(k);
            kernel.packet[3] = dm3.loadPacketStandard(k);
            ptranspose(kernel);
            pstoreu(block + 0 * packet_size, kernel.packet[0]);
            pstoreu(block + 1 * packet_size, kernel.packet[1]);
            pstoreu(block + 2 * packet_size, kernel.packet[2]);
            pstoreu(block + 3 * packet_size, kernel.packet[3]);
            block += 4 * packet_size;
          }
        }
      }
      if (!rhs.nonStandardPatches()) {
        for (; k < depth; k++) {
          block[0] = dm0.loadCoeffStandard(k);
          block[1] = dm1.loadCoeffStandard(k);
          block[2] = dm2.loadCoeffStandard(k);
          block[3] = dm3.loadCoeffStandard(k);
          block += 4;
        }
      } else {
        for (; k < depth; k++) {
          block[0] = dm0(k);
          block[1] = dm1(k);
          block[2] = dm2(k);
          block[3] = dm3(k);
          block += 4;
        }
      }
    }

    // copy the remaining columns one at a time (nr==1)
    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2);
      for (Index k = 0; k < depth; k++) {
        *block = dm0(k);
        block += 1;
      }
    }
  }
};

// Special case for non-vectorized types such as float16.
template <typename NewDimension, DenseIndex Rows, DenseIndex Cols,
          typename ArgType, typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, bool inner_dim_contiguous,
          bool inner_dim_reordered, int Alignment, int nr>
struct gemm_pack_rhs<
    Scalar, Index,
    TensorContractionSubMapper<
        Scalar, Index, Rhs,
        TensorEvaluator<
            const TensorReshapingOp<
                NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
            Device>,
        nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered,
        Alignment>,
    nr, ColMajor, false, false> {
  typedef TensorContractionSubMapper<
      Scalar, Index, Rhs,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered,
      Alignment>
      SubMapper;
  typedef SubMapper DataMapper;

  EIGEN_DEVICE_FUNC
  static inline Index ceil_div(Index a, Index b) { return (a + b - 1) / b; }

  EIGEN_DEVICE_FUNC
  EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
                                    Index depth, Index cols, Index stride = 0,
                                    Index offset = 0) const {
    eigen_assert(stride == 0);
    eigen_assert(offset == 0);

    EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);

    const Index packet_cols4 = (cols / 4) * 4;

    for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
      const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
      const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
      const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);

      if (!rhs.nonStandardPatches()) {
        for (Index k = 0; k < depth; k++) {
          block[0] = dm0.loadCoeffStandard(k);
          block[1] = dm1.loadCoeffStandard(k);
          block[2] = dm2.loadCoeffStandard(k);
          block[3] = dm3.loadCoeffStandard(k);
          block += 4;
        }
      } else {
        for (Index k = 0; k < depth; k++) {
          block[0] = dm0(k);
          block[1] = dm1(k);
          block[2] = dm2(k);
          block[3] = dm3(k);
          block += 4;
        }
      }
    }

    // copy the remaining columns one at a time (nr==1)
    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2);
      for (Index k = 0; k < depth; k++) {
        *block = dm0(k);
        block += 1;
      }
    }
  }
};

}  // end namespace internal

/** SpatialConvolution
 * \ingroup CXX11_NeuralNetworks_Module
 *
 * \brief Applies a 2D convolution over a multichannel input image.
 *
 * The input parameter is expected to be a tensor with a rank of 3 or more
 * (channels, height, width, and optionally others).
 * The kernel parameter is expected to be a 4D tensor (filters, channels,
 * kernel_height, kernel_width).
 * The input and the kernel must both be in col-major layout. The result will
 * also be in col-major layout.
 *
 * If col_in_stride, row_in_stride > 1, then applies convolution with holes
 * (aka atrous convolution), sampling every col_in_stride, row_in_stride input
 * pixels.
 *
 * The result can be assigned to a tensor of rank equal to the rank of the
 * input. The dimensions of the result will be filters, height, width (and
 * others if applicable).
 *
 * It is possible to swap the order of the width and height dimensions provided
 * that the same order is used in the input, the kernel, and the output.
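 *
 * A minimal usage sketch (illustrative only; the concrete shapes below are
 * assumptions, not part of this header): with a col-major float input of
 * shape (channels, height, width, batch) and a kernel of shape (filters,
 * channels, kernel_height, kernel_width), the default stride of 1 and
 * PADDING_SAME preserve the spatial dimensions:
 *
 * \code
 * Eigen::Tensor<float, 4> input(3, 32, 32, 16);   // C, H, W, N
 * Eigen::Tensor<float, 4> kernel(8, 3, 5, 5);     // F, C, Kh, Kw
 * input.setRandom();
 * kernel.setRandom();
 * Eigen::Tensor<float, 4> result(8, 32, 32, 16);  // F, H, W, N
 * result = Eigen::SpatialConvolution(input, kernel);
 * \endcode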
 *
 */
template <typename Input, typename Kernel>
EIGEN_DEVICE_FUNC
    EIGEN_ALWAYS_INLINE static const typename internal::conditional<
        internal::traits<Input>::Layout == ColMajor,
        TensorReshapingOp<
            const DSizes<typename internal::traits<Input>::Index,
                         internal::traits<Input>::NumDimensions>,
            const TensorContractionOp<
                const array<IndexPair<typename internal::traits<Input>::Index>,
                            1>,
                const TensorReshapingOp<
                    const DSizes<typename internal::traits<Input>::Index, 2>,
                    const Kernel>,
                const TensorReshapingOp<
                    const DSizes<typename internal::traits<Input>::Index, 2>,
                    const TensorImagePatchOp<Dynamic, Dynamic,
                                             const Input> > > >,
        TensorReshapingOp<
            const DSizes<typename internal::traits<Input>::Index,
                         internal::traits<Input>::NumDimensions>,
            const TensorContractionOp<
                const array<IndexPair<typename internal::traits<Input>::Index>,
                            1>,
                const TensorReshapingOp<
                    const DSizes<typename internal::traits<Input>::Index, 2>,
                    const TensorImagePatchOp<Dynamic, Dynamic, const Input> >,
                const TensorReshapingOp<
                    const DSizes<typename internal::traits<Input>::Index, 2>,
                    const Kernel> > > >::type
    SpatialConvolution(const Input& input, const Kernel& kernel,
                       const DenseIndex row_stride = 1,
                       const DenseIndex col_stride = 1,
                       const PaddingType padding_type = PADDING_SAME,
                       const DenseIndex row_in_stride = 1,
                       const DenseIndex col_in_stride = 1) {
  typedef typename internal::traits<Input>::Index TensorIndex;
  TensorRef<Tensor<typename internal::traits<Input>::Scalar,
                   internal::traits<Input>::NumDimensions,
                   internal::traits<Input>::Layout, TensorIndex> >
      in(input);
  TensorRef<Tensor<typename internal::traits<Kernel>::Scalar,
                   internal::traits<Kernel>::NumDimensions,
                   internal::traits<Kernel>::Layout, TensorIndex> >
      kern(kernel);

  EIGEN_STATIC_ASSERT(
      internal::traits<Input>::Layout == internal::traits<Kernel>::Layout,
      YOU_MADE_A_PROGRAMMING_MISTAKE);
  const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);

  const int NumDims = internal::traits<Input>::NumDimensions;

  // Number of filters to apply. This is the same as the output depth of the
  // result
  const TensorIndex kernelFilters =
      isColMajor ? kern.dimensions()[0] : kern.dimensions()[3];
  // Number of channels. This is the same as the input depth.
  const TensorIndex kernelChannels =
      isColMajor ? kern.dimensions()[1] : kern.dimensions()[2];
  const TensorIndex kernelRows =
      isColMajor ? kern.dimensions()[2] : kern.dimensions()[1];
  const TensorIndex kernelCols =
      isColMajor ? kern.dimensions()[3] : kern.dimensions()[0];

  const DenseIndex kernelRowsEff =
      kernelRows + (kernelRows - 1) * (row_in_stride - 1);
  const DenseIndex kernelColsEff =
      kernelCols + (kernelCols - 1) * (col_in_stride - 1);

  array<IndexPair<TensorIndex>, 1> contract_dims;
  contract_dims[0] = IndexPair<TensorIndex>(1, 0);

  const TensorIndex InputRows =
      isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
  const TensorIndex InputCols =
      isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);

  TensorIndex out_height;
  TensorIndex out_width;
  switch (padding_type) {
    case PADDING_VALID:
      out_height = numext::ceil((InputRows - kernelRowsEff + 1.f) /
                                static_cast<float>(row_stride));
      out_width = numext::ceil((InputCols - kernelColsEff + 1.f) /
                               static_cast<float>(col_stride));
      break;
    case PADDING_SAME:
      out_height = numext::ceil(InputRows / static_cast<float>(row_stride));
      out_width = numext::ceil(InputCols / static_cast<float>(col_stride));
      break;
    default:
      // Initialize unused variables to avoid a compiler warning
      out_height = 0;
      out_width = 0;
      eigen_assert(false && "unexpected padding");
  }

  // Molds the output of the patch extraction code into a 2d tensor:
  // - the first dimension (dims[0]): the patch values to be multiplied with
  //   the kernels
  // - the second dimension (dims[1]): everything else
  DSizes<TensorIndex, 2> pre_contract_dims;
  if (isColMajor) {
    pre_contract_dims[0] = kernelChannels * kernelRows * kernelCols;
    pre_contract_dims[1] = out_height * out_width;
    for (int i = 3; i < NumDims; ++i) {
      pre_contract_dims[1] *= in.dimension(i);
    }
  } else {
    pre_contract_dims[1] = kernelChannels * kernelRows * kernelCols;
    pre_contract_dims[0] = out_height * out_width;
    for (int i = 0; i < NumDims - 3; ++i) {
      pre_contract_dims[0] *= in.dimension(i);
    }
  }

  // Molds the output of the contraction into the shape expected by the user
  // (assuming this is ColMajor):
  // - 1st dim: kernel filters
  // - 2nd dim: output height
  // - 3rd dim: output width
  // - 4th dim and beyond: everything else including batch size
  DSizes<TensorIndex, NumDims> post_contract_dims;
  if (isColMajor) {
    post_contract_dims[0] = kernelFilters;
    post_contract_dims[1] = out_height;
    post_contract_dims[2] = out_width;
    for (int i = 3; i < NumDims; ++i) {
      post_contract_dims[i] = in.dimension(i);
    }
  } else {
    post_contract_dims[NumDims - 1] = kernelFilters;
    post_contract_dims[NumDims - 2] = out_height;
    post_contract_dims[NumDims - 3] = out_width;
    for (int i = 0; i < NumDims - 3; ++i) {
      post_contract_dims[i] = in.dimension(i);
    }
  }

  DSizes<TensorIndex, 2> kernel_dims;
  if (isColMajor) {
    kernel_dims[0] = kernelFilters;
    kernel_dims[1] = kernelChannels * kernelRows * kernelCols;
  } else {
    kernel_dims[0] = kernelChannels * kernelRows * kernelCols;
    kernel_dims[1] = kernelFilters;
  }
  // TODO(yangke): choose() is defined in TensorContraction.h -- consider
  // moving it to somewhere more "common".
  return choose(
      Cond<internal::traits<Input>::Layout == ColMajor>(),
      kernel.reshape(kernel_dims)
          .contract(input
                        .extract_image_patches(
                            kernelRows, kernelCols, row_stride, col_stride,
                            row_in_stride, col_in_stride, padding_type)
                        .reshape(pre_contract_dims),
                    contract_dims)
          .reshape(post_contract_dims),
      input
          .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride,
                                 row_in_stride, col_in_stride, padding_type)
          .reshape(pre_contract_dims)
          .contract(kernel.reshape(kernel_dims), contract_dims)
          .reshape(post_contract_dims));
}

}  // end namespace Eigen

#endif  // TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_H_