/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_
#define TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/eigen_volume_patch.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#include "tensorflow/core/kernels/eigen_convolution_helpers.h"

namespace Eigen {

namespace internal {

// WARNING: Most of the code here implicitly assumes that the matrix is in
// ColMajor layout. This is guaranteed by the tensor contraction (see
// TensorContraction.h).
//
// Inside Eigen a tensor contraction is represented by a matrix multiplication.
// We don't want to actually extract volume patches and reshape the result into
// a matrix (this involves allocating huge extra memory), so the patch
// extraction and reshape operations are implicit.
//
// TensorContractionInputMapper takes a matrix index and returns the coefficient
// (or the packet) of the "virtual tensor" that would be at that index if we
// were to actually reshape the result of patch extraction.
//
// TensorContractionSubMapper provides a similar view into the "virtual matrix"
// at the given vertical and horizontal offsets.
//
// "Virtual matrix" dimensions:
//   *0: kernelChannels * kernelPlanes * kernelRows * kernelCols
//    1: out_planes * out_height * out_width * OTHERS (e.g. batches, etc.)
//
// *) extracted patches are contiguous in memory (innermost dimension, assuming
//    col major layout)
//
// With these dimensions:
//   row - offset within a single patch (in code: patchId)
//   col - index of the extracted patch (in code: patchIndex)
//         patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions)
//
template <typename NewDimension, Index Planes, Index Rows, Index Cols,
          typename ArgType, typename Device, typename Scalar_, typename Index,
          typename nocontract_t, typename contract_t, int Side, int packet_size,
          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionInputMapper<
    Scalar_, Index, Side,
    TensorEvaluator<const TensorReshapingOp<NewDimension,
                                            const TensorVolumePatchOp<
                                                Planes, Rows, Cols, ArgType> >,
                    Device>,
    nocontract_t, contract_t, packet_size, inner_dim_contiguous,
    inner_dim_reordered, Alignment> {
 public:
  typedef Scalar_ Scalar;
  typedef TensorContractionInputMapper<
      Scalar, Index, Side,
      TensorEvaluator<const TensorReshapingOp<
                          NewDimension, const TensorVolumePatchOp<
                                            Planes, Rows, Cols, ArgType> >,
                      Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      Self;
  typedef TensorContractionSubMapper<
      Scalar, Index, Side,
      TensorEvaluator<const TensorReshapingOp<
                          NewDimension, const TensorVolumePatchOp<
                                            Planes, Rows, Cols, ArgType> >,
                      Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      SubMapper;
  typedef SubMapper VectorMapper;
  typedef SubMapper LinearMapper;
  typedef typename packet_traits<Scalar>::type Packet;

  EIGEN_DEVICE_FUNC
  TensorContractionInputMapper(
      const TensorEvaluator<
          const TensorReshapingOp<
              NewDimension,
              const TensorVolumePatchOp<Planes, Rows, Cols, ArgType> >,
          Device>& tensor,
      const nocontract_t&, const nocontract_t&, const contract_t&,
      const contract_t&)
      : m_impl(tensor.impl().impl()) {
    if (internal::traits<ArgType>::Layout == ColMajor) {
      m_patch_depth = tensor.impl().dimensions()[0];
      m_patch_planes = tensor.impl().dimensions()[1];
      m_patch_rows = tensor.impl().dimensions()[2];
      m_patch_cols = tensor.impl().dimensions()[3];
      m_num_patches = tensor.impl().dimensions()[4];
    } else {
      const int NumDims = tensor.impl().dimensions().size();
      m_patch_depth = tensor.impl().dimensions()[NumDims - 1];
      m_patch_planes = tensor.impl().dimensions()[NumDims - 2];
      m_patch_rows = tensor.impl().dimensions()[NumDims - 3];
      m_patch_cols = tensor.impl().dimensions()[NumDims - 4];
      m_num_patches = tensor.impl().dimensions()[NumDims - 5];
    }

    // Strides for navigating through the single patch.
    m_patch_plane_stride = m_patch_depth;
    m_patch_row_stride = m_patch_planes * m_patch_plane_stride;
    m_patch_col_stride = m_patch_rows * m_patch_row_stride;

    // Strides for the output tensor.
    // IMPORTANT: These strides are used to locate an element in a patch at
    // depth zero (channel), which is not quite the same as a "traditional"
    // stride.
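    // For example (illustrative numbers only): for a patch of depth 5 with
    // 2 planes, 3 rows and 4 columns, m_patch_plane_stride = 5,
    // m_patch_row_stride = 10 and m_patch_col_stride = 30, while the
    // depth-zero strides below are m_rowStride = 2, m_colStride = 6 and
    // m_patchStride = 120.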
    m_rowStride = m_patch_planes;
    m_colStride = m_patch_rows * m_rowStride;
    m_patchStride = m_colStride * m_patch_cols * m_patch_depth;
    m_otherStride = m_patchStride * m_num_patches;

    m_outputPlanes = tensor.impl().outputPlanes();
    m_outputRows = tensor.impl().outputRows();
    m_outputCols = tensor.impl().outputCols();

    m_outputPlanesRows = m_outputPlanes * m_outputRows;

    m_plane_strides = tensor.impl().userPlaneStride();
    m_row_strides = tensor.impl().userRowStride();
    m_col_strides = tensor.impl().userColStride();

    m_in_plane_strides = tensor.impl().userInPlaneStride();
    m_in_row_strides = tensor.impl().userInRowStride();
    m_in_col_strides = tensor.impl().userInColStride();

    m_patch_plane_inflate_strides = tensor.impl().planeInflateStride();
    m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
    m_patch_col_inflate_strides = tensor.impl().colInflateStride();

    if (internal::traits<ArgType>::Layout == ColMajor) {
      m_inputDepth = tensor.impl().impl().dimensions()[0];
      m_inputPlanes = tensor.impl().impl().dimensions()[1];
      m_inputRows = tensor.impl().impl().dimensions()[2];
      m_inputCols = tensor.impl().impl().dimensions()[3];
    } else {
      const int NumDims = tensor.impl().impl().dimensions().size();
      m_inputDepth = tensor.impl().impl().dimensions()[NumDims - 1];
      m_inputPlanes = tensor.impl().impl().dimensions()[NumDims - 2];
      m_inputRows = tensor.impl().impl().dimensions()[NumDims - 3];
      m_inputCols = tensor.impl().impl().dimensions()[NumDims - 4];
    }

    // Strides for navigating through the input tensor.
    m_planeInputStride = m_inputDepth;
    m_rowInputStride = m_inputDepth * m_inputPlanes;
    m_colInputStride = m_inputDepth * m_inputRows * m_inputPlanes;
    m_patchInputStride =
        m_inputDepth * m_inputRows * m_inputCols * m_inputPlanes;

    m_planePaddingTop = tensor.impl().planePaddingTop();
    m_rowPaddingTop = tensor.impl().rowPaddingTop();
    m_colPaddingLeft = tensor.impl().colPaddingLeft();

    m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches);

    m_fastPatchPlaneStride =
        internal::TensorIntDivisor<Index>(m_patch_plane_stride);
    m_fastPatchRowStride =
        internal::TensorIntDivisor<Index>(m_patch_row_stride);
    m_fastPatchColStride =
        internal::TensorIntDivisor<Index>(m_patch_col_stride);

    m_fastInputPlaneStride =
        internal::TensorIntDivisor<Index>(m_patch_plane_inflate_strides);
    m_fastInputRowStride =
        internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
    m_fastInputColStride =
        internal::TensorIntDivisor<Index>(m_patch_col_inflate_strides);

    m_fastRowStride = internal::TensorIntDivisor<Index>(m_rowStride);
    m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);

    m_fastDimZero = internal::TensorIntDivisor<Index>(m_patch_depth);
    m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
    m_fastOutputPlanes = internal::TensorIntDivisor<Index>(m_outputPlanes);
    m_fastOutputCols = internal::TensorIntDivisor<Index>(m_outputCols);

    m_fastOutputPlanesRows =
        internal::TensorIntDivisor<Index>(m_outputPlanesRows);
  }

  EIGEN_DEVICE_FUNC
  TensorContractionInputMapper(const TensorContractionInputMapper& base_mapper)
      : m_impl(base_mapper.m_impl) {
    m_patch_depth = base_mapper.m_patch_depth;
    m_patch_planes = base_mapper.m_patch_planes;
    m_patch_rows = base_mapper.m_patch_rows;
    m_patch_cols = base_mapper.m_patch_cols;
    m_num_patches = base_mapper.m_num_patches;

    m_patch_plane_stride = base_mapper.m_patch_plane_stride;
    m_patch_row_stride = base_mapper.m_patch_row_stride;
    m_patch_col_stride = base_mapper.m_patch_col_stride;

    m_rowStride = base_mapper.m_rowStride;
    m_colStride = base_mapper.m_colStride;
    m_patchStride = base_mapper.m_patchStride;
    m_otherStride = base_mapper.m_otherStride;

    m_planeInputStride = base_mapper.m_planeInputStride;
    m_rowInputStride = base_mapper.m_rowInputStride;
    m_colInputStride = base_mapper.m_colInputStride;
    m_patchInputStride = base_mapper.m_patchInputStride;
    m_otherInputStride = base_mapper.m_otherInputStride;

    m_inputDepth = base_mapper.m_inputDepth;
    m_inputPlanes = base_mapper.m_inputPlanes;
    m_inputRows = base_mapper.m_inputRows;
    m_inputCols = base_mapper.m_inputCols;

    m_outputPlanes = base_mapper.m_outputPlanes;
    m_outputRows = base_mapper.m_outputRows;
    m_outputCols = base_mapper.m_outputCols;

    m_plane_strides = base_mapper.m_plane_strides;
    m_row_strides = base_mapper.m_row_strides;
    m_col_strides = base_mapper.m_col_strides;

    m_in_plane_strides = base_mapper.m_in_plane_strides;
    m_in_row_strides = base_mapper.m_in_row_strides;
    m_in_col_strides = base_mapper.m_in_col_strides;

    m_patch_plane_inflate_strides = base_mapper.m_patch_plane_inflate_strides;
    m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
    m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;

    m_planePaddingTop = base_mapper.m_planePaddingTop;
    m_rowPaddingTop = base_mapper.m_rowPaddingTop;
    m_colPaddingLeft = base_mapper.m_colPaddingLeft;

    m_outputPlanesRows = base_mapper.m_outputPlanesRows;

    m_fastNumPatches = base_mapper.m_fastNumPatches;
    m_fastPatchPlaneStride = base_mapper.m_fastPatchPlaneStride;
    m_fastPatchRowStride = base_mapper.m_fastPatchRowStride;
    m_fastPatchColStride = base_mapper.m_fastPatchColStride;
    m_fastInputPlaneStride = base_mapper.m_fastInputPlaneStride;
    m_fastInputRowStride = base_mapper.m_fastInputRowStride;
    m_fastInputColStride = base_mapper.m_fastInputColStride;
    m_fastRowStride = base_mapper.m_fastRowStride;
    m_fastColStride = base_mapper.m_fastColStride;
    m_fastOutputPlanes = base_mapper.m_fastOutputPlanes;
    m_fastOutputRows = base_mapper.m_fastOutputRows;
    m_fastOutputCols = base_mapper.m_fastOutputCols;
    m_fastDimZero = base_mapper.m_fastDimZero;
    m_fastOutputPlanesRows = base_mapper.m_fastOutputPlanesRows;
  }

  // If true, turns off some optimizations for loading packets since the image
  // patches are "non-standard", e.g. there are non-trivial strides or
  // inflations in the input.
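  // For example, a dilated (atrous) convolution sets the m_in_*_strides to
  // values other than 1, and an inflated input (input values interleaved with
  // zeros) sets the m_patch_*_inflate_strides to values other than 1; either
  // case disables the fast packet-loading paths below.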
276 EIGEN_DEVICE_FUNC nonStandardPatches()277 EIGEN_ALWAYS_INLINE bool nonStandardPatches() const { 278 return m_in_plane_strides != 1 || m_in_row_strides != 1 || 279 m_in_col_strides != 1 || m_patch_plane_inflate_strides != 1 || 280 m_patch_row_inflate_strides != 1 || m_patch_col_inflate_strides != 1; 281 } 282 283 EIGEN_DEVICE_FUNC getSubMapper(Index i,Index j)284 EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const { 285 return SubMapper(*this, i, j); 286 } 287 288 EIGEN_DEVICE_FUNC getLinearMapper(Index i,Index j)289 EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const { 290 return LinearMapper(*this, i, j); 291 } 292 293 EIGEN_DEVICE_FUNC operator()294 EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const { 295 Index planeIndex, rowIndex, colIndex, otherIndex; 296 computeBaseIndices(0, planeIndex, rowIndex, colIndex, otherIndex); 297 return loadCoeff(row, planeIndex, rowIndex, colIndex, otherIndex); 298 } 299 300 // Load the coefficient at the patchIndex location instead of the usual 301 // m_rowIndex, m_colIndex, m_otherIndex. This is currently only used by the 302 // gpu code. 303 EIGEN_DEVICE_FUNC operator()304 EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const { 305 Index planeIndex, rowIndex, colIndex, otherIndex; 306 computeBaseIndices(patchIndex, planeIndex, rowIndex, colIndex, otherIndex); 307 return loadCoeff(row, planeIndex, rowIndex, colIndex, otherIndex); 308 } 309 310 EIGEN_DEVICE_FUNC loadPacket(Index row)311 EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const { 312 Index planeIndex, rowIndex, colIndex, otherIndex; 313 computeBaseIndices(0, planeIndex, rowIndex, colIndex, otherIndex); 314 return loadPacket(row, planeIndex, rowIndex, colIndex, otherIndex); 315 } 316 317 // Load the packet at the patchIndex location instead of the usual m_rowIndex, 318 // m_colIndex, m_otherIndex. This is currently only used by the gpu code. 319 EIGEN_DEVICE_FUNC loadPacket(Index row,Index patchIndex)320 EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const { 321 Index planeIndex, rowIndex, colIndex, otherIndex; 322 computeBaseIndices(patchIndex, planeIndex, rowIndex, colIndex, otherIndex); 323 return loadPacket(row, planeIndex, rowIndex, colIndex, otherIndex); 324 } 325 326 EIGEN_DEVICE_FUNC impl()327 EIGEN_ALWAYS_INLINE const TensorEvaluator<ArgType, Device>& impl() const { 328 return m_impl; 329 } 330 331 EIGEN_DEVICE_FUNC patchDepth()332 EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_planeInputStride; } 333 EIGEN_DEVICE_FUNC patchPlanes()334 EIGEN_ALWAYS_INLINE Index patchPlanes() const { return m_rowStride; } 335 EIGEN_DEVICE_FUNC patchRows()336 EIGEN_ALWAYS_INLINE Index patchRows() const { return m_patch_rows; } 337 EIGEN_DEVICE_FUNC patchCols()338 EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; } 339 340 private: 341 friend class TensorContractionSubMapper< 342 Scalar, Index, Side, 343 TensorEvaluator<const TensorReshapingOp< 344 NewDimension, const TensorVolumePatchOp< 345 Planes, Rows, Cols, ArgType> >, 346 Device>, 347 nocontract_t, contract_t, packet_size, inner_dim_contiguous, 348 inner_dim_reordered, Alignment>; 349 350 // Load coefficient from a patch specified by the "within patch offset" 351 // (patchId) and the precomputed indices of the first element of the patch. 
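  // Within a patch, patchId decomposes as
  //   patchId = depth + patchDepth() * (planeOffset + patchPlanes() *
  //             (rowOffset + patchRows() * colOffset));
  // the divisions by m_fastDimZero, m_fastRowStride and m_fastColStride below
  // recover these offsets.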
352 EIGEN_DEVICE_FUNC loadCoeff(Index patchId,Index planeIndex,Index rowIndex,Index colIndex,Index otherIndex)353 EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index planeIndex, 354 Index rowIndex, Index colIndex, 355 Index otherIndex) const { 356 // Find the offset of the element wrt the location of the first element. 357 const Index patchOffset = patchId / m_fastDimZero; 358 359 const Index colOffset = patchOffset / m_fastColStride; 360 const Index inputCol = colIndex + colOffset * m_in_col_strides; 361 const Index origInputCol = 362 (m_patch_col_inflate_strides == 1) 363 ? inputCol 364 : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0); 365 366 const Index rowOffset = 367 (patchOffset - colOffset * m_colStride) / m_fastRowStride; 368 const Index inputRow = rowIndex + rowOffset * m_in_row_strides; 369 const Index origInputRow = 370 (m_patch_row_inflate_strides == 1) 371 ? inputRow 372 : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0); 373 374 const Index planeOffset = 375 patchOffset - colOffset * m_colStride - rowOffset * m_rowStride; 376 const Index inputPlane = planeIndex + planeOffset * m_in_plane_strides; 377 const Index origInputPlane = 378 (m_patch_plane_inflate_strides == 1) 379 ? inputPlane 380 : ((inputPlane >= 0) ? (inputPlane / m_fastInputPlaneStride) : 0); 381 382 if (origInputCol < 0 || origInputRow < 0 || origInputPlane < 0 || 383 origInputCol >= m_inputCols || origInputRow >= m_inputRows || 384 origInputPlane >= m_inputPlanes || 385 (inputCol != origInputCol * m_patch_col_inflate_strides) || 386 (inputRow != origInputRow * m_patch_row_inflate_strides) || 387 (inputPlane != origInputPlane * m_patch_plane_inflate_strides)) { 388 return Scalar(0); 389 } 390 391 const Index depth = patchId - patchOffset * patchDepth(); 392 const Index inputIndex = depth + origInputPlane * m_planeInputStride + 393 origInputRow * m_rowInputStride + 394 origInputCol * m_colInputStride + otherIndex; 395 396 return m_impl.coeff(inputIndex); 397 } 398 399 // This is the same as loadCoeff(...), but optimized for all `inflate_strides` 400 // and `in_strides` equal to 1 (template specialization without templates). 401 EIGEN_DEVICE_FUNC loadCoeffStandard(Index patchId,Index planeIndex,Index rowIndex,Index colIndex,Index otherIndex)402 EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index planeIndex, 403 Index rowIndex, Index colIndex, 404 Index otherIndex) const { 405 eigen_assert(!nonStandardPatches()); 406 407 // Find the offset of the element wrt the location of the first element. 
408 const Index patchOffset = patchId / m_fastDimZero; 409 410 const Index colOffset = patchOffset / m_fastColStride; 411 const Index rowOffset = 412 (patchOffset - colOffset * m_colStride) / m_fastRowStride; 413 const Index planeOffset = 414 patchOffset - colOffset * m_colStride - rowOffset * m_rowStride; 415 416 const Index inputCol = colIndex + colOffset; 417 const Index inputRow = rowIndex + rowOffset; 418 const Index inputPlane = planeIndex + planeOffset; 419 420 if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 || 421 inputRow >= m_inputRows || inputPlane < 0 || 422 inputPlane >= m_inputPlanes) { 423 return Scalar(0); 424 } 425 426 const Index depth = patchId - patchOffset * patchDepth(); 427 const Index inputIndex = depth + inputPlane * m_planeInputStride + 428 inputRow * m_rowInputStride + 429 inputCol * m_colInputStride + otherIndex; 430 431 return m_impl.coeff(inputIndex); 432 } 433 434 // Load packet from a patch specified by the "within patch offset" 435 // (patchId) and the precomputed indices of the first element of the patch. 436 EIGEN_DEVICE_FUNC loadPacket(Index patchId,Index planeIndex,Index rowIndex,Index colIndex,Index otherIndex)437 EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index planeIndex, 438 Index rowIndex, Index colIndex, 439 Index otherIndex) const { 440 const Index packetSize = internal::unpacket_traits<Packet>::size; 441 442 EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) 443 eigen_assert(patchId < 444 patchDepth() * patchPlanes() * patchRows() * patchCols()); 445 446 if (nonStandardPatches()) { 447 return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex, 448 otherIndex); 449 } 450 typedef decltype(m_impl) TensorEvaluatorT; 451 return loadPacketStandard<Packet, TensorEvaluatorT>( 452 patchId, planeIndex, rowIndex, colIndex, otherIndex); 453 } 454 455 // Helper function to load a 'partial' packet - this is the single row part of 456 // a packet that is split across two rows (but single column). In the 457 // 'partial' packet, the elements corresponding to the row (specified through 458 // rowOffset) are loaded and the rest of the elements are zero-filled into the 459 // 'partial' packet. This function is called from 460 // loadPacketStandardFromSingleColumnTwoRows(). This code path is exercised 461 // only when the packet type supports masked load and when the partial packet 462 // load is available in the TensorEvaluator. 
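  // For instance (hypothetical values): with a packet size of 8 and
  // span == {3, 6}, elements 3..6 of the returned packet come from the input
  // and elements 0..2 and 7 are zero-filled.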
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPartialPacketStandard(
      Index planeIndex, Index rowIndex, Index colIndex, Index otherIndex,
      Index patchId, const Index span[], const Index patchOffsets[],
      Index colOffset, Index rowOffset) const {
    const Index inputCol = colIndex + colOffset;
    const Index inputRow = rowIndex + rowOffset;
    const Index planeOffsets[2] = {
        patchOffsets[0] - colOffset * m_colStride - rowOffset * m_rowStride,
        patchOffsets[1] - colOffset * m_colStride - rowOffset * m_rowStride};
    const Index inputPlanes[2] = {planeIndex + planeOffsets[0],
                                  planeIndex + planeOffsets[1]};

    if (inputRow >= m_inputRows || inputRow < 0 || inputCol >= m_inputCols ||
        inputCol < 0 || inputPlanes[0] >= m_inputPlanes || inputPlanes[1] < 0) {
      // Partial packet is all zeros.
      return internal::pset1<Packet>(Scalar(0));
    } else if (inputPlanes[0] >= 0 && inputPlanes[1] < m_inputPlanes) {
      // From inputIndex-span[0], we need to load elements starting from index
      // span[0] all the way up to (and including) span[1].
      const Index depth = patchId - patchOffsets[0] * patchDepth();
      const Index inputIndex = depth + inputPlanes[0] * m_planeInputStride +
                               inputRow * m_rowInputStride +
                               inputCol * m_colInputStride + otherIndex;
      return m_impl.template partialPacket<Packet>(
          inputIndex - span[0], mask<Packet>(span[0], span[1] + 1));
    } else {
      // Using slow path for this partial packet.
      // We need to load elements starting from index span[0] all the way up to
      // (and including) span[1]. We split this load into 3 parts:
      //   0 : span[0]-1            - Zeros will be loaded for these indices
      //   span[0] : span[1]        - Elements will be loaded for these indices
      //   span[1]+1 : packetSize-1 - Zeros will be loaded for these indices
      const Index packetSize = internal::unpacket_traits<Packet>::size;
      EIGEN_ALIGN_MAX
      typename internal::remove_const<Scalar>::type values[packetSize];
      for (int i = 0; i < span[0]; ++i) values[i] = Scalar(0);
      for (int i = span[0]; i < span[1] + 1; ++i)
        values[i] = loadCoeff(patchId - span[0] + i, planeIndex, rowIndex,
                              colIndex, otherIndex);
      for (int i = span[1] + 1; i < packetSize; ++i) values[i] = Scalar(0);
      return internal::pload<Packet>(values);
    }
  }

  // Helper function to load a packet that is split across two rows (but single
  // column). If required, this function is called from loadPacketStandard()
  // when the packet type supports masked load and when the partial packet load
  // is available in the TensorEvaluator.
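  // The split point (patchIdSplit below) is the last element of the packet
  // that still falls into the first of the two rows; the two partial packets
  // are then combined with a bitwise OR, relying on the zero-filling done in
  // loadPartialPacketStandard().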
512 EIGEN_DEVICE_FUNC loadPacketStandardFromSingleColumnTwoRows(Index patchId,Index planeIndex,Index rowIndex,Index colIndex,Index otherIndex,const Index patchOffsets[],const Index colOffsets[],const Index rowOffsets[])513 EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromSingleColumnTwoRows( 514 Index patchId, Index planeIndex, Index rowIndex, Index colIndex, 515 Index otherIndex, const Index patchOffsets[], const Index colOffsets[], 516 const Index rowOffsets[]) const { 517 eigen_assert(colOffsets[1] == colOffsets[0] && 518 rowOffsets[1] == rowOffsets[0] + 1); 519 const Index packetSize = internal::unpacket_traits<Packet>::size; 520 521 // Packet to load will be split into 2 parts where each part spans a single 522 // row and both the parts span the same column. 523 // First determine where to split. 524 const Index patchIdSplit = 525 (((rowOffsets[1] * m_rowStride) + (colOffsets[0] * m_colStride)) * 526 m_patch_depth) - 527 1; 528 const Index patchOffsetSplit = patchIdSplit / m_fastDimZero; 529 530 // patchIds[i]: patchId corresponding to partial packet i 531 // spans[i]: Start and end indices corresponding to the elements 532 // to be loaded for partial packet i 533 // patchOffsets2Cols[i]: patchOffsets corresponding to partial packet i 534 const Index patchIds[2] = {patchId, patchIdSplit + 1}; 535 const Index spans[2][2] = {{0, patchIdSplit - patchId}, 536 {patchIdSplit - patchId + 1, packetSize - 1}}; 537 const Index patchOffsets2Cols[2][2] = { 538 {patchOffsets[0], patchOffsetSplit}, 539 {patchOffsetSplit + 1, patchOffsets[1]}}; 540 541 // Load partial packets and do bit-wise OR to generate required packet 542 return internal::por<Packet>( 543 loadPartialPacketStandard(planeIndex, rowIndex, colIndex, otherIndex, 544 patchIds[0], spans[0], patchOffsets2Cols[0], 545 colOffsets[0], rowOffsets[0]), 546 loadPartialPacketStandard(planeIndex, rowIndex, colIndex, otherIndex, 547 patchIds[1], spans[1], patchOffsets2Cols[1], 548 colOffsets[1], rowOffsets[1])); 549 } 550 551 // Helper function to load a packet that is present in a single column and 552 // row. If required, this function is called from loadPacketStandard(). 
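  // If the plane range touched by the packet lies entirely inside the input,
  // one unaligned packet read suffices (consecutive planes at a fixed row and
  // column are contiguous in memory); if it is only partially inside, the
  // element-by-element packetWithPossibleZero() fallback is used, and if it
  // is entirely outside, zeros are returned.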
553 EIGEN_DEVICE_FUNC loadPacketStandardFromSingleColumnSingleRow(Index patchId,Index planeIndex,Index rowIndex,Index colIndex,Index otherIndex,const Index patchOffsets[],const Index colOffsets[],const Index rowOffsets[],const Index inputCols[],const Index inputRows[])554 EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromSingleColumnSingleRow( 555 Index patchId, Index planeIndex, Index rowIndex, Index colIndex, 556 Index otherIndex, const Index patchOffsets[], const Index colOffsets[], 557 const Index rowOffsets[], const Index inputCols[], 558 const Index inputRows[]) const { 559 eigen_assert(colOffsets[1] == colOffsets[0] && 560 rowOffsets[1] == rowOffsets[0]); 561 const Index planeOffsets[2] = { 562 patchOffsets[0] - colOffsets[0] * m_colStride - 563 rowOffsets[0] * m_rowStride, 564 patchOffsets[1] - colOffsets[1] * m_colStride - 565 rowOffsets[1] * m_rowStride}; 566 eigen_assert(planeOffsets[0] <= planeOffsets[1]); 567 const Index inputPlanes[2] = {planeIndex + planeOffsets[0], 568 planeIndex + planeOffsets[1]}; 569 570 if (inputPlanes[0] >= m_inputPlanes || inputPlanes[1] < 0) { 571 return internal::pset1<Packet>(Scalar(0)); 572 } 573 if (inputPlanes[0] >= 0 && inputPlanes[1] < m_inputPlanes) { 574 const Index depth = patchId - patchOffsets[0] * patchDepth(); 575 const Index inputIndex = depth + inputPlanes[0] * m_planeInputStride + 576 inputRows[0] * m_rowInputStride + 577 inputCols[0] * m_colInputStride + otherIndex; 578 return m_impl.template packet<Unaligned>(inputIndex); 579 } 580 return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex, 581 otherIndex); 582 } 583 584 // Load standard packet from a patch specified by the "within patch offset" 585 // (patchId) and the precomputed indices of the first element of the patch. 586 // This function will be called if partial packet loading is not available 587 // for the TensorEvaluator or if the packet type does not support masked 588 // load. 589 template <typename PacketT, typename TensorEvaluatorT> 590 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if< 591 !TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value, 592 PacketT>::type loadPacketStandard(Index patchId,Index planeIndex,Index rowIndex,Index colIndex,Index otherIndex)593 loadPacketStandard(Index patchId, Index planeIndex, Index rowIndex, 594 Index colIndex, Index otherIndex) const { 595 const Index packetSize = internal::unpacket_traits<Packet>::size; 596 EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) 597 eigen_assert(patchId < 598 patchDepth() * patchPlanes() * patchRows() * patchCols()); 599 eigen_assert(!nonStandardPatches()); 600 601 if ((patchDepth() % packetSize) == 0) { 602 return loadPacketFast(patchId, planeIndex, rowIndex, colIndex, 603 otherIndex); 604 } else { 605 // Offsets and input calculation here are identical to 606 // loadCoeffStandard(...), but repeated twice. 
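      // (Once for the first and once for the last element of the packet, to
      // detect whether the packet stays within a single column and row.)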
607 608 const Index patchOffsets[2] = { 609 patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero}; 610 611 const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride, 612 patchOffsets[1] / m_fastColStride}; 613 eigen_assert(colOffsets[0] <= colOffsets[1]); 614 615 const Index inputCols[2] = {colIndex + colOffsets[0], 616 colIndex + colOffsets[1]}; 617 if (inputCols[0] >= m_inputCols || inputCols[1] < 0) { 618 return internal::pset1<Packet>(Scalar(0)); 619 } 620 621 if (inputCols[0] == inputCols[1]) { 622 const Index rowOffsets[2] = { 623 (patchOffsets[0] - colOffsets[0] * m_colStride) / m_fastRowStride, 624 (patchOffsets[1] - colOffsets[1] * m_colStride) / m_fastRowStride}; 625 eigen_assert(rowOffsets[0] <= rowOffsets[1]); 626 const Index inputRows[2] = {rowIndex + rowOffsets[0], 627 rowIndex + rowOffsets[1]}; 628 629 if (inputRows[0] >= m_inputRows || inputRows[1] < 0) { 630 return internal::pset1<Packet>(Scalar(0)); 631 } 632 633 if (inputRows[0] == inputRows[1]) { 634 return loadPacketStandardFromSingleColumnSingleRow( 635 patchId, planeIndex, rowIndex, colIndex, otherIndex, patchOffsets, 636 colOffsets, rowOffsets, inputCols, inputRows); 637 } 638 } 639 } 640 641 return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex, 642 otherIndex); 643 } 644 645 // Load standard packet from a patch specified by the "within patch offset" 646 // (patchId) and the precomputed indices of the first element of the patch. 647 // This function will be called if partial packet loading is available for 648 // the TensorEvaluator and if the packet type supports masked load. 649 // The only difference between this and the other case is that if the packet 650 // to load is split across two rows (but in same column), then in this case 651 // instead of going to the slow (element-by-element) load, we load two packets 652 // - each containing elements from one of the rows (rest of the elements of 653 // the packets are zeroes), and then combine these two packets to generate the 654 // required packet. The idea is to enable fast load (if possible) of these 655 // 'partial' packets. 656 template <typename PacketT, typename TensorEvaluatorT> 657 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if< 658 TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value, 659 PacketT>::type loadPacketStandard(Index patchId,Index planeIndex,Index rowIndex,Index colIndex,Index otherIndex)660 loadPacketStandard(Index patchId, Index planeIndex, Index rowIndex, 661 Index colIndex, Index otherIndex) const { 662 const Index packetSize = internal::unpacket_traits<Packet>::size; 663 EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) 664 eigen_assert(patchId < 665 patchDepth() * patchPlanes() * patchRows() * patchCols()); 666 eigen_assert(!nonStandardPatches()); 667 668 if ((patchDepth() % packetSize) == 0) { 669 return loadPacketFast(patchId, planeIndex, rowIndex, colIndex, 670 otherIndex); 671 } else { 672 // Offsets and input calculation here are identical to 673 // loadCoeffStandard(...), but repeated twice. 
674 675 const Index patchOffsets[2] = { 676 patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero}; 677 678 const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride, 679 patchOffsets[1] / m_fastColStride}; 680 eigen_assert(colOffsets[0] <= colOffsets[1]); 681 682 const Index inputCols[2] = {colIndex + colOffsets[0], 683 colIndex + colOffsets[1]}; 684 if (inputCols[0] >= m_inputCols || inputCols[1] < 0) { 685 return internal::pset1<Packet>(Scalar(0)); 686 } 687 688 if (inputCols[0] == inputCols[1]) { 689 const Index rowOffsets[2] = { 690 (patchOffsets[0] - colOffsets[0] * m_colStride) / m_fastRowStride, 691 (patchOffsets[1] - colOffsets[1] * m_colStride) / m_fastRowStride}; 692 eigen_assert(rowOffsets[0] <= rowOffsets[1]); 693 const Index inputRows[2] = {rowIndex + rowOffsets[0], 694 rowIndex + rowOffsets[1]}; 695 696 if (inputRows[0] >= m_inputRows || inputRows[1] < 0) { 697 return internal::pset1<Packet>(Scalar(0)); 698 } 699 700 if (inputRows[0] == inputRows[1]) { 701 return loadPacketStandardFromSingleColumnSingleRow( 702 patchId, planeIndex, rowIndex, colIndex, otherIndex, patchOffsets, 703 colOffsets, rowOffsets, inputCols, inputRows); 704 } 705 if (inputRows[0] + 1 == inputRows[1]) { 706 return loadPacketStandardFromSingleColumnTwoRows( 707 patchId, planeIndex, rowIndex, colIndex, otherIndex, patchOffsets, 708 colOffsets, rowOffsets); 709 } 710 } 711 } 712 713 return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex, 714 otherIndex); 715 } 716 717 EIGEN_DEVICE_FUNC loadPacketFast(Index patchId,Index planeIndex,Index rowIndex,Index colIndex,Index otherIndex)718 EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index planeIndex, 719 Index rowIndex, Index colIndex, 720 Index otherIndex) const { 721 const Index packetSize = internal::unpacket_traits<Packet>::size; 722 EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) 723 eigen_assert(patchId < 724 patchDepth() * patchPlanes() * patchRows() * patchCols()); 725 726 eigen_assert(!nonStandardPatches()); 727 eigen_assert((patchDepth() % packetSize) == 0); 728 729 // Find the offset of the element wrt the location of the first element. 
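    // Because the packet size evenly divides patchDepth() (asserted above) and
    // the packet does not cross a depth boundary (asserted below), all packet
    // elements share the same (plane, row, col) location and differ only in
    // depth, so a single bounds check covers the whole packet.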
730 const Index patchOffset = patchId / m_fastDimZero; 731 eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset); 732 733 const Index colOffset = patchOffset / m_fastColStride; 734 const Index rowOffset = 735 (patchOffset - colOffset * m_colStride) / m_fastRowStride; 736 const Index planeOffset = 737 patchOffset - colOffset * m_colStride - rowOffset * m_rowStride; 738 739 const Index inputCol = colIndex + colOffset; 740 const Index inputRow = rowIndex + rowOffset; 741 const Index inputPlane = planeIndex + planeOffset; 742 743 if (inputCol < 0 || inputRow < 0 || inputPlane < 0 || 744 inputCol >= m_inputCols || inputRow >= m_inputRows || 745 inputPlane >= m_inputPlanes) { 746 return internal::pset1<Packet>(Scalar(0)); 747 } 748 749 const Index depth = patchId - patchOffset * patchDepth(); 750 const Index inputIndex = depth + inputPlane * m_planeInputStride + 751 inputRow * m_rowInputStride + 752 inputCol * m_colInputStride + otherIndex; 753 return m_impl.template packet<Unaligned>(inputIndex); 754 } 755 756 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet packetWithPossibleZero(Index patchId,Index planeIndex,Index rowIndex,Index colIndex,Index otherIndex)757 packetWithPossibleZero(Index patchId, Index planeIndex, Index rowIndex, 758 Index colIndex, Index otherIndex) const { 759 const int packetSize = internal::unpacket_traits<Packet>::size; 760 EIGEN_ALIGN_MAX 761 typename internal::remove_const<Scalar>::type values[packetSize]; 762 for (int i = 0; i < packetSize; ++i) { 763 values[i] = 764 loadCoeff(patchId + i, planeIndex, rowIndex, colIndex, otherIndex); 765 } 766 Packet rslt = internal::pload<Packet>(values); 767 return rslt; 768 } 769 770 // Precompute the indices (plane, row, col, other) of the first element of 771 // the given patch index, within the output tensor of the TensorVolumePatchOp. computeBaseIndices(Index patchIndex,Index & planeIndex,Index & rowIndex,Index & colIndex,Index & otherIndex)772 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices( 773 Index patchIndex, Index& planeIndex, Index& rowIndex, Index& colIndex, 774 Index& otherIndex) const { 775 const size_t NumInputDims = array_size< 776 typename TensorEvaluator<ArgType, Device>::Dimensions>::value; 777 778 // Check if patchIndex might contain batch and other dimensions. 779 otherIndex = (NumInputDims == 4) ? 0 : patchIndex / m_fastNumPatches; 780 781 // Compute index of the patch within the batch (and other dimensions). 782 const Index patch3DIndex = (NumInputDims == 4) 783 ? patchIndex 784 : (patchIndex - otherIndex * m_num_patches); 785 786 otherIndex *= m_patchInputStride; 787 788 colIndex = patch3DIndex / m_fastOutputPlanesRows; 789 rowIndex = 790 (patch3DIndex - colIndex * m_outputPlanesRows) / m_fastOutputPlanes; 791 planeIndex = 792 patch3DIndex - (colIndex * m_outputRows + rowIndex) * m_outputPlanes; 793 794 colIndex = colIndex * m_col_strides - m_colPaddingLeft; 795 rowIndex = rowIndex * m_row_strides - m_rowPaddingTop; 796 planeIndex = planeIndex * m_plane_strides - m_planePaddingTop; 797 } 798 799 Index m_patch_depth; // number of channels in the patch 800 Index m_patch_planes; // number of planes in the patch 801 Index m_patch_rows; // number of rows in the patch 802 Index m_patch_cols; // number of columns in the patch 803 Index m_num_patches; // number of patches to extract 804 805 // Strides for navigating through the single patch. 
806 Index m_patch_plane_stride; 807 Index m_patch_row_stride; 808 Index m_patch_col_stride; 809 810 // Strides for the output tensor (depth is not the part of the stride). 811 Index m_rowStride; 812 Index m_colStride; 813 Index m_patchStride; 814 Index m_otherStride; 815 816 Index m_planeInputStride; // Plane stride in the input tensor 817 Index m_rowInputStride; // Row stride in the input tensor 818 Index m_colInputStride; // Col stride in the input tensor 819 Index m_patchInputStride; // Patch stride in the input tensor 820 Index m_otherInputStride; 821 822 Index m_inputDepth; // Depth of the input tensor 823 Index m_inputPlanes; // Number of planes in the input tensor 824 Index m_inputRows; // Number of rows in the input tensor 825 Index m_inputCols; // Number of cols in the input tensor 826 827 Index m_outputPlanes; // Number of output planes 828 Index m_outputRows; // Number of output rows 829 Index m_outputCols; // Number of output cols 830 Index m_outputPlanesRows; // Cached outputPlanes * outputRows. 831 832 Index m_plane_strides; // User specified plane stride 833 Index m_row_strides; // User specified row stride 834 Index m_col_strides; // User specified col stride 835 836 // User specified plane/row/col atrous convolution strides. 837 Index m_in_plane_strides; 838 Index m_in_row_strides; 839 Index m_in_col_strides; 840 841 // User specified plane/row/col inflation strides in the image patch. 842 Index m_patch_plane_inflate_strides; 843 Index m_patch_row_inflate_strides; 844 Index m_patch_col_inflate_strides; 845 846 Index m_planePaddingTop; // Plane padding 847 Index m_rowPaddingTop; // Row padding 848 Index m_colPaddingLeft; // Column padding 849 850 // Fast representation of various divisors. 851 internal::TensorIntDivisor<Index> m_fastNumPatches; 852 853 internal::TensorIntDivisor<Index> m_fastPatchPlaneStride; 854 internal::TensorIntDivisor<Index> m_fastPatchRowStride; 855 internal::TensorIntDivisor<Index> m_fastPatchColStride; 856 857 internal::TensorIntDivisor<Index> m_fastInputPlaneStride; 858 internal::TensorIntDivisor<Index> m_fastInputRowStride; 859 internal::TensorIntDivisor<Index> m_fastInputColStride; 860 861 internal::TensorIntDivisor<Index> m_fastRowStride; 862 internal::TensorIntDivisor<Index> m_fastColStride; 863 864 internal::TensorIntDivisor<Index> m_fastDimZero; // aka output depth 865 internal::TensorIntDivisor<Index> m_fastOutputPlanes; 866 internal::TensorIntDivisor<Index> m_fastOutputRows; 867 internal::TensorIntDivisor<Index> m_fastOutputCols; 868 internal::TensorIntDivisor<Index> m_fastOutputPlanesRows; 869 870 const TensorEvaluator<ArgType, Device> m_impl; 871 }; 872 873 template <typename NewDimension, Index Planes, Index Rows, Index Cols, 874 typename ArgType, typename Device, typename Scalar, typename Index, 875 typename nocontract_t, typename contract_t, int Side, int packet_size, 876 bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment> 877 class TensorContractionSubMapper< 878 Scalar, Index, Side, 879 TensorEvaluator<const TensorReshapingOp<NewDimension, 880 const TensorVolumePatchOp< 881 Planes, Rows, Cols, ArgType> >, 882 Device>, 883 nocontract_t, contract_t, packet_size, inner_dim_contiguous, 884 inner_dim_reordered, Alignment> { 885 public: 886 typedef typename packet_traits<Scalar>::type Packet; 887 typedef typename packet_traits<Scalar>::half HalfPacket; 888 889 typedef TensorContractionInputMapper< 890 Scalar, Index, Side, 891 TensorEvaluator<const TensorReshapingOp< 892 NewDimension, const TensorVolumePatchOp< 893 Planes, 
Rows, Cols, ArgType> >, 894 Device>, 895 nocontract_t, contract_t, packet_size, inner_dim_contiguous, 896 inner_dim_reordered, Alignment> 897 ParentMapper; 898 typedef TensorContractionSubMapper< 899 Scalar, Index, Side, 900 TensorEvaluator<const TensorReshapingOp< 901 NewDimension, const TensorVolumePatchOp< 902 Planes, Rows, Cols, ArgType> >, 903 Device>, 904 nocontract_t, contract_t, packet_size, inner_dim_contiguous, 905 inner_dim_reordered, Alignment> 906 Self; 907 typedef Self LinearMapper; 908 TensorContractionSubMapper(const ParentMapper & base_mapper,Index vert_offset,Index horiz_offset)909 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper( 910 const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset) 911 : m_base_mapper(base_mapper), 912 m_depth_offset(vert_offset), 913 m_col_offset(horiz_offset) { 914 m_base_mapper.computeBaseIndices(m_col_offset, m_planeIndex, m_rowIndex, 915 m_colIndex, m_otherIndex); 916 } TensorContractionSubMapper(const Self & base_mapper,Index vert_offset,Index horiz_offset)917 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper( 918 const Self& base_mapper, Index vert_offset, Index horiz_offset) 919 : m_base_mapper(base_mapper.m_base_mapper), 920 m_depth_offset(vert_offset + base_mapper.m_depth_offset), 921 m_col_offset(horiz_offset + base_mapper.m_col_offset) { 922 m_base_mapper.computeBaseIndices(m_col_offset, m_planeIndex, m_rowIndex, 923 m_colIndex, m_otherIndex); 924 } operator()925 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const { 926 return m_base_mapper.loadCoeff(i + m_depth_offset, m_planeIndex, m_rowIndex, 927 m_colIndex, m_otherIndex); 928 } operator()929 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, 930 Index j) const { 931 return m_base_mapper(i + m_depth_offset, j + m_col_offset); 932 } 933 loadPacket(Index i)934 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const { 935 return m_base_mapper.loadPacket(i + m_depth_offset, m_planeIndex, 936 m_rowIndex, m_colIndex, m_otherIndex); 937 } loadPacket(Index i,Index j)938 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, 939 Index j) const { 940 return m_base_mapper.template loadPacket<Alignment>(i + m_depth_offset, 941 j + m_col_offset); 942 } 943 944 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar loadCoeffStandard(Index i)945 loadCoeffStandard(Index i) const { 946 return m_base_mapper.loadCoeffStandard( 947 i + m_depth_offset, m_planeIndex, m_rowIndex, m_colIndex, m_otherIndex); 948 } 949 loadPacketFast(Index i)950 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const { 951 return m_base_mapper.loadPacketFast(i + m_depth_offset, m_planeIndex, 952 m_rowIndex, m_colIndex, m_otherIndex); 953 } 954 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketStandard(Index i)955 loadPacketStandard(Index i) const { 956 typedef decltype(m_base_mapper.m_impl) TensorEvaluatorT; 957 return m_base_mapper.template loadPacketStandard<Packet, TensorEvaluatorT>( 958 i + m_depth_offset, m_planeIndex, m_rowIndex, m_colIndex, m_otherIndex); 959 } 960 template <typename Packet> aligned(Index)961 EIGEN_DEVICE_FUNC bool aligned(Index) const { 962 return false; 963 } 964 965 EIGEN_DEVICE_FUNC nonStandardPatches()966 EIGEN_ALWAYS_INLINE bool nonStandardPatches() const { 967 return m_base_mapper.nonStandardPatches(); 968 } 969 970 // Max(Col|Row|Plane|Depth): compute the upper limit for the column, row, 971 // plane and depth index respectively that fits into the peeled_k elements 972 // 
starting at m_depth_offset. 973 974 EIGEN_DEVICE_FUNC maxCol(const Index peeled_k)975 EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const { 976 const Index max_col = 977 fastPatchColStride().divide(m_depth_offset + peeled_k); 978 return std::min<Index>(1 + max_col, patchCols()); 979 } 980 981 EIGEN_DEVICE_FUNC maxRow(const Index peeled_k,const Index col)982 EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k, 983 const Index col) const { 984 const Index max_row = fastPatchRowStride().divide( 985 m_depth_offset + peeled_k - col * patchColStride()); 986 return std::min<Index>(1 + max_row, patchRows()); 987 } 988 989 EIGEN_DEVICE_FUNC maxPlane(const Index peeled_k,const Index col,const Index row)990 EIGEN_ALWAYS_INLINE Index maxPlane(const Index peeled_k, const Index col, 991 const Index row) const { 992 const Index max_plane = fastPatchPlaneStride().divide( 993 m_depth_offset + peeled_k - col * patchColStride() - 994 row * patchRowStride()); 995 return std::min<Index>(1 + max_plane, patchPlanes()); 996 } 997 998 // MaxDepth uses only the remaining number of elements in the peeled_k. 999 EIGEN_DEVICE_FUNC maxDepth(const Index num_elements,const Index start_depth)1000 EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements, 1001 const Index start_depth) const { 1002 return std::min<Index>(start_depth + num_elements, patchDepth()); 1003 } 1004 1005 // Every register matters in this code, so sometimes to prevent register 1006 // spilling, instead of the variable that you would expect to see, we use 1007 // another one, that is guaranteed to have the same value. E.g. patch depth is 1008 // always the same as input depth, and it's also the same as input plane 1009 // stride. Bunch of other parameters have similar relations. 1010 1011 typedef internal::TensorIntDivisor<Index> IndexDivisor; 1012 1013 EIGEN_DEVICE_FUNC patchDepth()1014 EIGEN_ALWAYS_INLINE Index patchDepth() const { 1015 eigen_assert(m_base_mapper.m_patch_depth == 1016 m_base_mapper.m_planeInputStride && 1017 "Patch depth must be equal to plane input stride."); 1018 return m_base_mapper.m_planeInputStride; 1019 } 1020 1021 EIGEN_DEVICE_FUNC patchPlanes()1022 EIGEN_ALWAYS_INLINE Index patchPlanes() const { 1023 eigen_assert(m_base_mapper.m_patch_planes == m_base_mapper.m_rowStride && 1024 "Patch planes must be equal to row stride."); 1025 return m_base_mapper.m_rowStride; 1026 } 1027 EIGEN_DEVICE_FUNC patchRows()1028 EIGEN_ALWAYS_INLINE Index patchRows() const { 1029 return m_base_mapper.m_patch_rows; 1030 } 1031 EIGEN_DEVICE_FUNC patchCols()1032 EIGEN_ALWAYS_INLINE Index patchCols() const { 1033 return m_base_mapper.m_patch_cols; 1034 } 1035 1036 EIGEN_DEVICE_FUNC patchPlaneStride()1037 EIGEN_ALWAYS_INLINE Index patchPlaneStride() const { 1038 eigen_assert(patchDepth() == m_base_mapper.m_patch_plane_stride && 1039 "Patch depth must be equal to patch plane stride."); 1040 return patchDepth(); 1041 } 1042 EIGEN_DEVICE_FUNC patchRowStride()1043 EIGEN_ALWAYS_INLINE Index patchRowStride() const { 1044 return m_base_mapper.m_patch_row_stride; 1045 } 1046 EIGEN_DEVICE_FUNC patchColStride()1047 EIGEN_ALWAYS_INLINE Index patchColStride() const { 1048 return m_base_mapper.m_patch_col_stride; 1049 } 1050 1051 EIGEN_DEVICE_FUNC fastPatchPlaneStride()1052 EIGEN_ALWAYS_INLINE IndexDivisor fastPatchPlaneStride() const { 1053 eigen_assert(patchDepth() == m_base_mapper.m_patch_plane_stride && 1054 "Patch depth must be equal to patch plane stride."); 1055 return m_base_mapper.m_fastDimZero; // patch_depth 1056 } 1057 EIGEN_DEVICE_FUNC 
fastPatchRowStride()1058 EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const { 1059 return m_base_mapper.m_fastPatchRowStride; 1060 } 1061 EIGEN_DEVICE_FUNC fastPatchColStride()1062 EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const { 1063 return m_base_mapper.m_fastPatchColStride; 1064 } 1065 1066 EIGEN_DEVICE_FUNC packetNoPadding(const Index depth,const Index baseIndex)1067 EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth, 1068 const Index baseIndex) const { 1069 const Index inputIndex = depth + baseIndex; 1070 return m_base_mapper.m_impl.template packet<Unaligned>(inputIndex); 1071 } 1072 EIGEN_DEVICE_FUNC coeffNoPadding(const Index depth,const Index baseIndex)1073 EIGEN_ALWAYS_INLINE Scalar coeffNoPadding(const Index depth, 1074 const Index baseIndex) const { 1075 const Index inputIndex = depth + baseIndex; 1076 return m_base_mapper.m_impl.coeff(inputIndex); 1077 } 1078 1079 EIGEN_DEVICE_FUNC padPlane(const Index plane)1080 EIGEN_ALWAYS_INLINE bool padPlane(const Index plane) const { 1081 const Index p = m_planeIndex + plane; 1082 return p < 0 || p >= m_base_mapper.m_inputPlanes; 1083 } 1084 EIGEN_DEVICE_FUNC padRow(const Index row)1085 EIGEN_ALWAYS_INLINE bool padRow(const Index row) const { 1086 const Index r = m_rowIndex + row; 1087 return r < 0 || r >= m_base_mapper.m_inputRows; 1088 } 1089 EIGEN_DEVICE_FUNC padCol(const Index col)1090 EIGEN_ALWAYS_INLINE bool padCol(const Index col) const { 1091 const Index c = m_colIndex + col; 1092 return c < 0 || c >= m_base_mapper.m_inputCols; 1093 } 1094 EIGEN_DEVICE_FUNC baseIndex(const Index plane,const Index row,const Index col)1095 EIGEN_ALWAYS_INLINE Index baseIndex(const Index plane, const Index row, 1096 const Index col) const { 1097 const Index p = m_planeIndex + plane; 1098 const Index r = m_rowIndex + row; 1099 const Index c = m_colIndex + col; 1100 return p * m_base_mapper.m_planeInputStride + 1101 r * m_base_mapper.m_rowInputStride + 1102 c * m_base_mapper.m_colInputStride + m_otherIndex; 1103 } 1104 1105 EIGEN_DEVICE_FUNC planeOffset()1106 EIGEN_ALWAYS_INLINE Index planeOffset() const { 1107 const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero; 1108 const Index colOffset = patchOffset / m_base_mapper.m_fastColStride; 1109 const Index rowOffset = 1110 (patchOffset - colOffset * m_base_mapper.m_colStride) / 1111 m_base_mapper.m_fastRowStride; 1112 const Index planeOffset = patchOffset - 1113 colOffset * m_base_mapper.m_colStride - 1114 rowOffset * m_base_mapper.m_rowStride; 1115 return planeOffset; 1116 } 1117 1118 EIGEN_DEVICE_FUNC rowOffset()1119 EIGEN_ALWAYS_INLINE Index rowOffset() const { 1120 const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero; 1121 const Index colOffset = patchOffset / m_base_mapper.m_fastColStride; 1122 const Index rowOffset = 1123 (patchOffset - colOffset * m_base_mapper.m_colStride) / 1124 m_base_mapper.m_fastRowStride; 1125 return rowOffset; 1126 } 1127 1128 EIGEN_DEVICE_FUNC colOffset()1129 EIGEN_ALWAYS_INLINE Index colOffset() const { 1130 const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero; 1131 const Index colOffset = patchOffset / m_base_mapper.m_fastColStride; 1132 return colOffset; 1133 } 1134 1135 EIGEN_DEVICE_FUNC depthOffset()1136 EIGEN_ALWAYS_INLINE Index depthOffset() const { 1137 return m_depth_offset % patchDepth(); 1138 } 1139 1140 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i,Index j)1141 getLinearMapper(Index i, Index j) const { 1142 return LinearMapper(m_base_mapper, i + 
m_depth_offset, j + m_col_offset); 1143 } 1144 1145 private: 1146 const ParentMapper m_base_mapper; // Keeping a copy instead of a reference 1147 // performs better in benchmarks. 1148 1149 Index m_depth_offset; // First row in the input matrix 1150 Index m_col_offset; // First col in the input matrix 1151 1152 // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base 1153 // indices for the first element in a patch specified by col_offset 1154 // (see computeBaseIndices(...) for details). 1155 Index m_planeIndex; 1156 Index m_rowIndex; 1157 Index m_colIndex; 1158 Index m_otherIndex; 1159 }; 1160 1161 // Arrange a block of the right input matrix (in our case it's always a "virtual 1162 // matrix" constructed from extracted volume patches) in contiguous memory. 1163 // 1164 // Given column major input (A0 beside A1 in memory): 1165 // A0 B0 C0 D0 E0 F0 G0 H0 ... Z0 1166 // A1 B1 C1 D1 E1 F1 G1 H1 ... Z1 1167 // A2 B2 C2 D2 E2 F2 G2 H2 ... Z2 1168 // A3 B3 C3 D3 E3 F3 G3 H3 ... Z3 1169 // A4 B4 C4 D4 E4 F4 G4 H4 ... Z4 1170 // A5 B5 C5 D5 E5 F5 G5 H5 ... Z5 1171 // A6 B6 C6 D6 E6 F6 G6 H6 ... Z6 1172 // A7 B7 C7 D7 E7 F7 G7 H7 ... Z7 1173 // A8 ... 1174 // ... 1175 // 1176 // *) A, B, C, ... - patches extracted from the original input. 1177 // *) A0, A1, A2 ... - values from the same patch at different offsets. 1178 // 1179 // The traversal (packed rhs memory) order (B0 besides A0 in memory): 1180 // A0 B0 C0 D0 A1 B1 C1 D1 ... 1181 // E0 F0 G0 H0 E1 F1 G1 H1 ... 1182 // ... 1183 // Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4) 1184 // 1185 // This traversal order must be the same as in default gemm_pack_rhs defined in 1186 // GeneralBlockPanelKernel.h. 1187 // 1188 // *) nr - number of registers along the 'n' dimension. 1189 // See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix 1190 // Multiplication" paper. 1191 // 1192 // TODO(ezhulenev): Add support for squeezing reads along two innermost 1193 // dimensions (see eigen_spatial_convolutions). 
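// In the fast path below, the packed block is filled by walking each patch in
// (col, row, plane, depth) order and transposing packet_size x 4 tiles, so
// that values from the four neighbouring patches dm0..dm3 become interleaved
// in the same order as the default gemm_pack_rhs traversal described above.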
1194 template <typename NewDimension, Index Planes, Index Rows, Index Cols, 1195 typename ArgType, typename Device, typename Scalar, typename Index, 1196 typename nocontract_t, typename contract_t, int packet_size, 1197 bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, 1198 int nr> 1199 struct gemm_pack_rhs< 1200 Scalar, Index, 1201 TensorContractionSubMapper< 1202 Scalar, Index, Rhs, 1203 TensorEvaluator<const TensorReshapingOp< 1204 NewDimension, const TensorVolumePatchOp< 1205 Planes, Rows, Cols, ArgType> >, 1206 Device>, 1207 nocontract_t, contract_t, packet_size, inner_dim_contiguous, 1208 inner_dim_reordered, Alignment>, 1209 nr, ColMajor, false, false> { 1210 typedef TensorContractionSubMapper< 1211 Scalar, Index, Rhs, 1212 TensorEvaluator<const TensorReshapingOp< 1213 NewDimension, const TensorVolumePatchOp< 1214 Planes, Rows, Cols, ArgType> >, 1215 Device>, 1216 nocontract_t, contract_t, packet_size, inner_dim_contiguous, 1217 inner_dim_reordered, Alignment> 1218 SubMapper; 1219 1220 typedef SubMapper DataMapper; 1221 typedef typename packet_traits<Scalar>::type Packet; 1222 1223 EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE); 1224 1225 EIGEN_DEVICE_FUNC 1226 EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs, 1227 Index depth, Index cols, Index stride = 0, 1228 Index offset = 0) const { 1229 eigen_assert(stride == 0); 1230 eigen_assert(offset == 0); 1231 1232 const Index packet_cols4 = (cols / 4) * 4; 1233 const Index peeled_k = (depth / packet_size) * packet_size; 1234 const bool non_standard_patches = rhs.nonStandardPatches(); 1235 1236 for (Index j2 = 0; j2 < packet_cols4; j2 += 4) { 1237 const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0); 1238 const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1); 1239 const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2); 1240 const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3); 1241 1242 Index k = 0; 1243 if ((packet_size % 4) == 0 && !non_standard_patches) { 1244 // FAST PATH: 1245 // Iterate over patch columns, rows and planes if we know that a single 1246 // packet do not span across multiple planes, rows or columns. 1247 if ((rhs.patchDepth() % packet_size) == 0) { 1248 const Index start_col = rhs.colOffset(); 1249 const Index max_col = rhs.maxCol(peeled_k); 1250 1251 for (Index c = start_col; c < max_col; ++c) { 1252 eigen_assert(k <= peeled_k); 1253 1254 const Index start_row = (c == start_col) ? rhs.rowOffset() : 0; 1255 const Index max_row = rhs.maxRow(peeled_k, c); 1256 1257 const bool pad_col0 = dm0.padCol(c); 1258 const bool pad_col1 = dm1.padCol(c); 1259 const bool pad_col2 = dm2.padCol(c); 1260 const bool pad_col3 = dm3.padCol(c); 1261 1262 for (Index r = start_row; r < max_row; ++r) { 1263 eigen_assert(k <= peeled_k); 1264 1265 const Index start_plane = ((c == start_col) && (r == start_row)) 1266 ? 
rhs.planeOffset() 1267 : 0; 1268 const Index max_plane = rhs.maxPlane(peeled_k, c, r); 1269 1270 const bool pad_row0 = pad_col0 || dm0.padRow(r); 1271 const bool pad_row1 = pad_col1 || dm1.padRow(r); 1272 const bool pad_row2 = pad_col2 || dm2.padRow(r); 1273 const bool pad_row3 = pad_col3 || dm3.padRow(r); 1274 1275 for (Index p = start_plane; p < max_plane; ++p) { 1276 eigen_assert(k <= peeled_k); 1277 1278 const bool pad0 = pad_row0 || dm0.padPlane(p); 1279 const bool pad1 = pad_row1 || dm1.padPlane(p); 1280 const bool pad2 = pad_row2 || dm2.padPlane(p); 1281 const bool pad3 = pad_row3 || dm3.padPlane(p); 1282 1283 const Index idx0 = dm0.baseIndex(p, r, c); 1284 const Index idx1 = dm1.baseIndex(p, r, c); 1285 const Index idx2 = dm2.baseIndex(p, r, c); 1286 const Index idx3 = dm3.baseIndex(p, r, c); 1287 1288 const Index start_depth = 1289 ((c == start_col) && (r == start_row) && (p == start_plane)) 1290 ? rhs.depthOffset() 1291 : 0; 1292 const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth); 1293 eigen_assert((max_depth - start_depth) % packet_size == 0); 1294 1295 for (Index d = start_depth; d < max_depth; d += packet_size) { 1296 eigen_assert(k < peeled_k); 1297 PacketBlock<Packet, 4> kernel; 1298 kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0)) 1299 : rhs.packetNoPadding(d, idx0); 1300 kernel.packet[1] = pad1 ? pset1<Packet>(Scalar(0)) 1301 : rhs.packetNoPadding(d, idx1); 1302 kernel.packet[2] = pad2 ? pset1<Packet>(Scalar(0)) 1303 : rhs.packetNoPadding(d, idx2); 1304 kernel.packet[3] = pad3 ? pset1<Packet>(Scalar(0)) 1305 : rhs.packetNoPadding(d, idx3); 1306 ptranspose(kernel); 1307 pstoreu(block + 0 * packet_size, kernel.packet[0]); 1308 pstoreu(block + 1 * packet_size, kernel.packet[1]); 1309 pstoreu(block + 2 * packet_size, kernel.packet[2]); 1310 pstoreu(block + 3 * packet_size, kernel.packet[3]); 1311 block += 4 * packet_size; 1312 k += packet_size; 1313 } 1314 } 1315 } 1316 } 1317 1318 // The loop above should fill peeled_k elements. 1319 eigen_assert(peeled_k == k); 1320 1321 } else { 1322 // Packet can span multiple planes, rows or columns, so we have to go 1323 // though the slower "standard" path. 1324 for (; k < peeled_k; k += packet_size) { 1325 PacketBlock<Packet, 4> kernel; 1326 kernel.packet[0] = dm0.loadPacketStandard(k); 1327 kernel.packet[1] = dm1.loadPacketStandard(k); 1328 kernel.packet[2] = dm2.loadPacketStandard(k); 1329 kernel.packet[3] = dm3.loadPacketStandard(k); 1330 ptranspose(kernel); 1331 pstoreu(block + 0 * packet_size, kernel.packet[0]); 1332 pstoreu(block + 1 * packet_size, kernel.packet[1]); 1333 pstoreu(block + 2 * packet_size, kernel.packet[2]); 1334 pstoreu(block + 3 * packet_size, kernel.packet[3]); 1335 block += 4 * packet_size; 1336 } 1337 } 1338 } 1339 1340 // Copy the remaining coefficients of the column block after the peeled_k. 1341 if (!non_standard_patches) { 1342 for (; k < depth; k++) { 1343 block[0] = dm0.loadCoeffStandard(k); 1344 block[1] = dm1.loadCoeffStandard(k); 1345 block[2] = dm2.loadCoeffStandard(k); 1346 block[3] = dm3.loadCoeffStandard(k); 1347 block += 4; 1348 } 1349 } else { 1350 for (; k < depth; k++) { 1351 block[0] = dm0(k); 1352 block[1] = dm1(k); 1353 block[2] = dm2(k); 1354 block[3] = dm3(k); 1355 block += 4; 1356 } 1357 } 1358 } 1359 1360 // Copy the remaining columns one at a time (nr==1). 
    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2);
      for (Index k = 0; k < depth; k++) {
        *block = dm0(k);
        block += 1;
      }
    }
  }
};

// Template specialization for packet_size = 2. We must special-case packet
// blocks with nr > packet_size, e.g. PacketBlock<Packet2d, 4>.
//
// TODO(ezhulenev): Add support for squeezing reads along two innermost
// dimensions (see eigen_spatial_convolutions).
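//
// Note: even though two PacketBlock<Packet, 2> transposes are used below, the
// packed `block` receives the same layout as in the generic specialization
// above: coefficient k of columns 0..3, followed by coefficient k+1 of
// columns 0..3.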
template <typename NewDimension, Index Planes, Index Rows, Index Cols,
          typename ArgType, typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, bool inner_dim_contiguous,
          bool inner_dim_reordered, int Alignment, int nr>
struct gemm_pack_rhs<
    Scalar, Index,
    TensorContractionSubMapper<
        Scalar, Index, Rhs,
        TensorEvaluator<const TensorReshapingOp<
                            NewDimension, const TensorVolumePatchOp<
                                              Planes, Rows, Cols, ArgType> >,
                        Device>,
        nocontract_t, contract_t, /*packet_size*/ 2, inner_dim_contiguous,
        inner_dim_reordered, Alignment>,
    nr, ColMajor, false, false> {
  typedef TensorContractionSubMapper<
      Scalar, Index, Rhs,
      TensorEvaluator<const TensorReshapingOp<
                          NewDimension, const TensorVolumePatchOp<
                                            Planes, Rows, Cols, ArgType> >,
                      Device>,
      nocontract_t, contract_t, /*packet_size*/ 2, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      SubMapper;
  typedef SubMapper DataMapper;
  typedef typename packet_traits<Scalar>::type Packet;

  EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);

  EIGEN_DEVICE_FUNC
  EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
                                    Index depth, Index cols, Index stride = 0,
                                    Index offset = 0) const {
    eigen_assert(stride == 0);
    eigen_assert(offset == 0);

    const int packet_size = 2;

    const Index packet_cols4 = (cols / 4) * 4;
    const Index peeled_k = (depth / packet_size) * packet_size;
    const bool non_standard_patches = rhs.nonStandardPatches();

    for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
      const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
      const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
      const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);

      Index k = 0;
      if (!non_standard_patches) {
        // FAST PATH:
        // Iterate over patch columns, rows and planes if we know that a single
        // packet does not span across multiple planes, rows or columns.
        if ((rhs.patchDepth() % packet_size) == 0) {
          const Index start_col = rhs.colOffset();
          const Index max_col = rhs.maxCol(peeled_k);

          for (Index c = start_col; c < max_col; ++c) {
            eigen_assert(k <= peeled_k);

            const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
            const Index max_row = rhs.maxRow(peeled_k, c);

            const bool pad_col0 = dm0.padCol(c);
            const bool pad_col1 = dm1.padCol(c);
            const bool pad_col2 = dm2.padCol(c);
            const bool pad_col3 = dm3.padCol(c);

            for (Index r = start_row; r < max_row; ++r) {
              eigen_assert(k <= peeled_k);

              const Index start_plane = ((c == start_col) && (r == start_row))
                                            ? rhs.planeOffset()
                                            : 0;
              const Index max_plane = rhs.maxPlane(peeled_k, c, r);

              const bool pad_row0 = dm0.padRow(r);
              const bool pad_row1 = dm1.padRow(r);
              const bool pad_row2 = dm2.padRow(r);
              const bool pad_row3 = dm3.padRow(r);

              for (Index p = start_plane; p < max_plane; ++p) {
                eigen_assert(k <= peeled_k);

                const bool pad0 = pad_col0 || pad_row0 || dm0.padPlane(p);
                const bool pad1 = pad_col1 || pad_row1 || dm1.padPlane(p);
                const bool pad2 = pad_col2 || pad_row2 || dm2.padPlane(p);
                const bool pad3 = pad_col3 || pad_row3 || dm3.padPlane(p);

                const Index idx0 = dm0.baseIndex(p, r, c);
                const Index idx1 = dm1.baseIndex(p, r, c);
                const Index idx2 = dm2.baseIndex(p, r, c);
                const Index idx3 = dm3.baseIndex(p, r, c);

                const Index start_depth =
                    ((c == start_col) && (r == start_row) && (p == start_plane))
                        ? rhs.depthOffset()
                        : 0;
                const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
                eigen_assert((max_depth - start_depth) % packet_size == 0);

                for (Index d = start_depth; d < max_depth; d += packet_size) {
                  eigen_assert(k < peeled_k);
                  PacketBlock<Packet, 2> kernel0;
                  PacketBlock<Packet, 2> kernel1;
                  kernel0.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
                                           : rhs.packetNoPadding(d, idx0);
                  kernel0.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
                                           : rhs.packetNoPadding(d, idx1);
                  kernel1.packet[0] = pad2 ? pset1<Packet>(Scalar(0))
                                           : rhs.packetNoPadding(d, idx2);
                  kernel1.packet[1] = pad3 ? pset1<Packet>(Scalar(0))
                                           : rhs.packetNoPadding(d, idx3);
                  ptranspose(kernel0);
                  ptranspose(kernel1);
                  pstoreu(block + 0 * packet_size, kernel0.packet[0]);
                  pstoreu(block + 1 * packet_size, kernel1.packet[0]);
                  pstoreu(block + 2 * packet_size, kernel0.packet[1]);
                  pstoreu(block + 3 * packet_size, kernel1.packet[1]);
                  block += 4 * packet_size;
                  k += packet_size;
                }
              }
            }
          }

          // The loop above should fill peeled_k elements.
          eigen_assert(peeled_k == k);

        } else {
          for (; k < peeled_k; k += packet_size) {
            PacketBlock<Packet, 2> kernel0;
            PacketBlock<Packet, 2> kernel1;
            kernel0.packet[0] = dm0.loadPacketStandard(k);
            kernel0.packet[1] = dm1.loadPacketStandard(k);
            kernel1.packet[0] = dm2.loadPacketStandard(k);
            kernel1.packet[1] = dm3.loadPacketStandard(k);
            ptranspose(kernel0);
            ptranspose(kernel1);
            pstoreu(block + 0 * packet_size, kernel0.packet[0]);
            pstoreu(block + 1 * packet_size, kernel1.packet[0]);
            pstoreu(block + 2 * packet_size, kernel0.packet[1]);
            pstoreu(block + 3 * packet_size, kernel1.packet[1]);
            block += 4 * packet_size;
          }
        }
      }

      // Copy the remaining coefficients of the column block after peeled_k.
      if (!rhs.nonStandardPatches()) {
        for (; k < depth; k++) {
          block[0] = dm0.loadCoeffStandard(k);
          block[1] = dm1.loadCoeffStandard(k);
          block[2] = dm2.loadCoeffStandard(k);
          block[3] = dm3.loadCoeffStandard(k);
          block += 4;
        }
      } else {
        for (; k < depth; k++) {
          block[0] = dm0(k);
          block[1] = dm1(k);
          block[2] = dm2(k);
          block[3] = dm3(k);
          block += 4;
        }
      }
    }

    // Copy the remaining columns one at a time (nr==1).
    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2);
      for (Index k = 0; k < depth; k++) {
        *block = dm0(k);
        block += 1;
      }
    }
  }
};

// Special case for non-vectorized types such as float16 (packet_size = 1).
template <typename NewDimension, Index Planes, Index Rows, Index Cols,
          typename ArgType, typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, bool inner_dim_contiguous,
          bool inner_dim_reordered, int Alignment, int nr>
struct gemm_pack_rhs<
    Scalar, Index,
    TensorContractionSubMapper<
        Scalar, Index, Rhs,
        TensorEvaluator<const TensorReshapingOp<
                            NewDimension, const TensorVolumePatchOp<
                                              Planes, Rows, Cols, ArgType> >,
                        Device>,
        nocontract_t, contract_t, /*packet_size*/ 1, inner_dim_contiguous,
        inner_dim_reordered, Alignment>,
    nr, ColMajor, false, false> {
  typedef TensorContractionSubMapper<
      Scalar, Index, Rhs,
      TensorEvaluator<const TensorReshapingOp<
                          NewDimension, const TensorVolumePatchOp<
                                            Planes, Rows, Cols, ArgType> >,
                      Device>,
      nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered,
      Alignment>
      SubMapper;
  typedef SubMapper DataMapper;

  EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);

  EIGEN_DEVICE_FUNC
  EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
                                    Index depth, Index cols, Index stride = 0,
                                    Index offset = 0) const {
    eigen_assert(stride == 0);
    eigen_assert(offset == 0);

    const Index packet_cols4 = (cols / 4) * 4;

    for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
      const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
      const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
      const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);

      if (!rhs.nonStandardPatches()) {
        for (Index k = 0; k < depth; k++) {
          block[0] = dm0.loadCoeffStandard(k);
          block[1] = dm1.loadCoeffStandard(k);
          block[2] = dm2.loadCoeffStandard(k);
          block[3] = dm3.loadCoeffStandard(k);
          block += 4;
        }
      } else {
        for (Index k = 0; k < depth; k++) {
          block[0] = dm0(k);
          block[1] = dm1(k);
          block[2] = dm2(k);
          block[3] = dm3(k);
          block += 4;
        }
      }
    }

    // Copy the remaining columns one at a time (nr==1).
    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2);
      for (Index k = 0; k < depth; k++) {
        *block = dm0(k);
        block += 1;
      }
    }
  }
};

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
// Pack a block of the right input matrix (in our case it's always a "virtual
// matrix" constructed from extracted image patches) into a contiguous block in
// column-major storage order. Knowing the properties of the original patch op,
// we can do this more efficiently than the default gemm_pack_colmajor_block.
//
// TODO(ezhulenev): gemm_pack_colmajor_block for spatial convolutions supports
// squeezing reads along the 2 innermost dimensions, add it here if needed.
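//
// The block is filled column by column: for each column of the "virtual
// matrix" the patch is walked in (column, row, plane, depth) order, with depth
// innermost. Packet loads are issued only along the depth dimension, so they
// never cross a plane, row or column boundary; leftover coefficients are
// copied with scalar loads (see packStandardPatches below).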
template <typename NewDimension, Index Planes, Index Rows, Index Cols,
          typename ArgType, typename Device, typename Scalar,
          typename StorageIndex, typename nocontract_t, typename contract_t,
          int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered,
          int Alignment>
struct gemm_pack_colmajor_block<
    Scalar, StorageIndex,
    TensorContractionSubMapper<
        Scalar, StorageIndex, Rhs,
        TensorEvaluator<const TensorReshapingOp<
                            NewDimension, const TensorVolumePatchOp<
                                              Planes, Rows, Cols, ArgType> >,
                        Device>,
        nocontract_t, contract_t, packet_size, inner_dim_contiguous,
        inner_dim_reordered, Alignment>,
    ColMajor> {
  typedef TensorContractionSubMapper<
      Scalar, StorageIndex, Rhs,
      TensorEvaluator<const TensorReshapingOp<
                          NewDimension, const TensorVolumePatchOp<
                                            Planes, Rows, Cols, ArgType> >,
                      Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      SubMapper;

  typedef SubMapper DataMapper;
  typedef typename packet_traits<Scalar>::type Packet;

  EIGEN_DONT_INLINE
  void operator()(Scalar* block, const DataMapper& rhs, StorageIndex rows,
                  StorageIndex cols) {
    const bool standard_patches = !rhs.nonStandardPatches();

    if (standard_patches && rhs.patchDepth() % packet_size == 0) {
      packStandardPatches<true>(block, rhs, rows, cols);

    } else if (standard_patches) {
      packStandardPatches<false>(block, rhs, rows, cols);

    } else {
      // With non-standard patches we don't do any vectorized loads.
      // TODO(ezhulenev): It doesn't look like we should completely give up
      // on packets. Make this code path faster!
      for (StorageIndex col = 0; col < cols; ++col) {
        SubMapper lm = rhs.getLinearMapper(0, col);
        for (StorageIndex i = 0; i < rows; ++i) {
          *block = lm(i);
          ++block;
        }
      }
    }
  }

 private:
  // Pack standard volume patches:
  //
  // - patch_depth_is_multiple_of_packet_size=true: We are guaranteed that the
  //   depth dimension size is a multiple of the packet size, so we can skip
  //   all non-vectorized loads and checks.
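  //
  // - patch_depth_is_multiple_of_packet_size=false: Packets are used only
  //   while more than packet_size coefficients remain in the current depth
  //   run; the remaining coefficients are copied with scalar loads.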
  //
  template <bool patch_depth_is_multiple_of_packet_size>
  EIGEN_ALWAYS_INLINE void packStandardPatches(Scalar* block,
                                               const DataMapper& rhs,
                                               StorageIndex rows,
                                               StorageIndex cols) {
    eigen_assert(!rhs.nonStandardPatches());

    // `peeled_k` is the name used for the vectorized part of the rows in all
    // gemm_pack_rhs specializations above.
    const Index peeled_k = (rows / packet_size) * packet_size;

    const Index start_col = rhs.colOffset();
    const Index max_col = rhs.maxCol(peeled_k);

    for (StorageIndex col = 0; col < cols; ++col) {
      SubMapper lm = rhs.getLinearMapper(0, col);

      Index k = 0;
      for (Index c = start_col; c < max_col; ++c) {
        eigen_assert(k <= peeled_k);

        const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
        const Index max_row = rhs.maxRow(peeled_k, c);
        const bool pad_col = lm.padCol(c);

        for (Index r = start_row; r < max_row; ++r) {
          eigen_assert(k <= peeled_k);

          const Index start_plane =
              ((c == start_col) && (r == start_row)) ? rhs.planeOffset() : 0;
          const Index max_plane = rhs.maxPlane(peeled_k, c, r);
          const bool pad_row = pad_col || lm.padRow(r);

          for (Index p = start_plane; p < max_plane; ++p) {
            eigen_assert(k <= peeled_k);

            const Index start_depth =
                ((c == start_col) && (r == start_row) && (p == start_plane))
                    ? rhs.depthOffset()
                    : 0;
            const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);

            const bool pad = pad_col || pad_row || lm.padPlane(p);
            const Index base_idx = lm.baseIndex(p, r, c);

            if (patch_depth_is_multiple_of_packet_size)
              eigen_assert((max_depth - start_depth) % packet_size == 0);

            // If patch depth is a multiple of packet size, it's guaranteed
            // that we can process all values in the depth dimension with
            // packets.
            const Index max_vectorized_depth =
                patch_depth_is_multiple_of_packet_size
                    ? max_depth
                    : max_depth - packet_size;

            Index d = start_depth;

            // 1. Process depth dimension with vectorized instructions.
            for (; d < max_vectorized_depth; d += packet_size) {
              eigen_assert(k < peeled_k);
              const Packet packet = pad ? pset1<Packet>(Scalar(0))
                                        : rhs.packetNoPadding(d, base_idx);
              internal::pstoreu(block, packet);
              block += packet_size;
              k += packet_size;
            }

            // 2. Finish with coefficients.
            if (!patch_depth_is_multiple_of_packet_size) {
              for (; d < max_depth; d++) {
                eigen_assert(k < peeled_k);
                *block = pad ? Scalar(0) : rhs.coeffNoPadding(d, base_idx);
                ++block;
                ++k;
              }
            }
          }
        }
      }

      // The loop above should fill peeled_k elements.
      eigen_assert(peeled_k == k);

      // Fill remaining elements using loadCoeffStandard.
      for (; k < rows; ++k) {
        *block = lm.loadCoeffStandard(k);
        ++block;
      }
    }
  }
};
#endif  // defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)

}  // namespace internal

/** CuboidConvolution
 * \ingroup CXX11_NeuralNetworks_Module
 *
 * \brief Applies a 3D convolution over a multichannel input voxel block.
 *
 * The input parameter is expected to be a tensor with a rank of 4 or more
 * (channels, depth, height, width, and optionally others).
 * The kernel parameter is expected to be a 5D tensor (filters, channels,
 * kernel_depth, kernel_height, kernel_width).
 * The result can be assigned to a tensor of rank equal to the rank of the
 * input. The dimensions of the result will be filters, depth, height, width
 * (and others if applicable).
 *
 * The input and kernel have to be in the same layout, and both row-major and
 * col-major are supported. The shapes given above are for col-major layout.
 * For row-major, all dimensions should be reversed.
 *
 * It is possible to swap the order of the depth, width, and height dimensions
 * provided that the same order is used in the input, the kernel, and the
 * output.
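 *
 * A minimal usage sketch (ColMajor layout; the tensor names and sizes below
 * are illustrative only, using the default strides and PADDING_SAME):
 *
 * \code
 * // Input: channels = 3, depth x height x width = 10x10x10, batch = 7.
 * Eigen::Tensor<float, 5> input(3, 10, 10, 10, 7);
 * // Kernel: filters = 8, channels = 3, 2x2x2 spatial size.
 * Eigen::Tensor<float, 5> filters(8, 3, 2, 2, 2);
 * input.setRandom();
 * filters.setRandom();
 *
 * // With PADDING_SAME and unit strides the spatial output size matches the
 * // input, so the result has dimensions (8, 10, 10, 10, 7).
 * Eigen::Tensor<float, 5> output = Eigen::CuboidConvolution(input, filters);
 * \endcode
 *
 * Internally (for ColMajor) the kernel above is reshaped to an 8 x 24 matrix,
 * the extracted volume patches are reshaped to a 24 x 7000 matrix, and the two
 * are contracted along the size-24 dimension before the result is reshaped
 * back to the output dimensions.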
 */
template <typename Input, typename Kernel>
EIGEN_ALWAYS_INLINE static const typename internal::conditional<
    internal::traits<Input>::Layout == ColMajor,
    TensorReshapingOp<
        const DSizes<typename internal::traits<Input>::Index,
                     internal::traits<Input>::NumDimensions>,
        const TensorContractionOp<
            const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<Input>::Index, 2>,
                const Kernel>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<Input>::Index, 2>,
                const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
                                          const Input> > > >,
    TensorReshapingOp<
        const DSizes<typename internal::traits<Input>::Index,
                     internal::traits<Input>::NumDimensions>,
        const TensorContractionOp<
            const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<Input>::Index, 2>,
                const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
                                          const Input> >,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<Input>::Index, 2>,
                const Kernel> > > >::type
CuboidConvolution(const Input& input, const Kernel& kernel,
                  const Index stridePlanes = 1, const Index strideRows = 1,
                  const Index strideCols = 1,
                  const PaddingType padding_type = PADDING_SAME) {
  typedef typename internal::traits<Input>::Index TensorIndex;
  TensorRef<Tensor<typename internal::traits<Input>::Scalar,
                   internal::traits<Input>::NumDimensions,
                   internal::traits<Input>::Layout, TensorIndex> >
      in(input);
  TensorRef<Tensor<typename internal::traits<Kernel>::Scalar,
                   internal::traits<Kernel>::NumDimensions,
                   internal::traits<Kernel>::Layout, TensorIndex> >
      kern(kernel);

  EIGEN_STATIC_ASSERT(
      internal::traits<Input>::Layout == internal::traits<Kernel>::Layout,
      YOU_MADE_A_PROGRAMMING_MISTAKE);
  static const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);
  static const int NumDims = internal::traits<Input>::NumDimensions;

  // Number of filters to apply. This is the same as the output depth of the
  // result.
  const TensorIndex kernelFilters =
      isColMajor ? kern.dimensions()[0] : kern.dimensions()[4];
  const TensorIndex kernelChannels =
      isColMajor ? kern.dimensions()[1] : kern.dimensions()[3];

  // Spatial size of the kernel.
  const TensorIndex kernelPlanes =
      isColMajor ? kern.dimensions()[2] : kern.dimensions()[2];
  const TensorIndex kernelRows =
      isColMajor ? kern.dimensions()[3] : kern.dimensions()[1];
  const TensorIndex kernelCols =
      isColMajor ? kern.dimensions()[4] : kern.dimensions()[0];

  if (isColMajor) {
    eigen_assert(kernelChannels == in.dimension(0));
  } else {
    eigen_assert(kernelChannels == in.dimension(NumDims - 1));
  }

  const TensorIndex inputPlanes =
      isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
  const TensorIndex inputRows =
      isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);
  const TensorIndex inputCols =
      isColMajor ? in.dimension(3) : in.dimension(NumDims - 4);

  TensorIndex out_planes;
  TensorIndex out_height;
  TensorIndex out_width;
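  // For example, with inputRows = 10, kernelRows = 3 and strideRows = 2:
  // PADDING_VALID gives divup(10 - 3 + 1, 2) = 4 output rows, while
  // PADDING_SAME gives divup(10, 2) = 5.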
  switch (padding_type) {
    case PADDING_VALID:
      out_planes = Eigen::divup(inputPlanes - kernelPlanes + 1,
                                static_cast<TensorIndex>(stridePlanes));
      out_height = Eigen::divup(inputRows - kernelRows + 1,
                                static_cast<TensorIndex>(strideRows));
      out_width = Eigen::divup(inputCols - kernelCols + 1,
                               static_cast<TensorIndex>(strideCols));
      break;
    case PADDING_SAME:
      out_planes =
          Eigen::divup(inputPlanes, static_cast<TensorIndex>(stridePlanes));
      out_height =
          Eigen::divup(inputRows, static_cast<TensorIndex>(strideRows));
      out_width = Eigen::divup(inputCols, static_cast<TensorIndex>(strideCols));
      break;
    default:
      out_planes = 0;
      out_height = 0;
      out_width = 0;
      eigen_assert(false && "unexpected padding");
  }

  DSizes<TensorIndex, 2> kernel_dims;
  if (isColMajor) {
    kernel_dims[0] = kernelFilters;
    kernel_dims[1] = kernelChannels * kernelPlanes * kernelRows * kernelCols;
  } else {
    kernel_dims[0] = kernelChannels * kernelPlanes * kernelRows * kernelCols;
    kernel_dims[1] = kernelFilters;
  }

  // Molds the output of the patch extraction into a 2D tensor:
  // - the first dimension (dims[0]): the patch values to be multiplied with
  //   the kernels
  // - the second dimension (dims[1]): everything else
  DSizes<TensorIndex, 2> pre_contract_dims;
  if (isColMajor) {
    pre_contract_dims[0] =
        kernelChannels * kernelPlanes * kernelRows * kernelCols;
    pre_contract_dims[1] = out_planes * out_height * out_width;
    for (int i = 4; i < NumDims; ++i) {
      pre_contract_dims[1] *= in.dimension(i);
    }
  } else {
    pre_contract_dims[1] =
        kernelChannels * kernelPlanes * kernelRows * kernelCols;
    pre_contract_dims[0] = out_planes * out_height * out_width;
    for (int i = 0; i < NumDims - 4; ++i) {
      pre_contract_dims[0] *= in.dimension(i);
    }
  }

  array<IndexPair<TensorIndex>, 1> contract_dims;
  contract_dims[0] = IndexPair<TensorIndex>(1, 0);

  // Molds the output of the contraction into the shape expected by the user
  // (assuming ColMajor):
  // - 1st dim: kernel filters
  // - 2nd dim: output depth (planes)
  // - 3rd dim: output height
  // - 4th dim: output width
  // - 5th dim and beyond: everything else including batch size
  DSizes<TensorIndex, NumDims> post_contract_dims;
  if (isColMajor) {
    post_contract_dims[0] = kernelFilters;
    post_contract_dims[1] = out_planes;
    post_contract_dims[2] = out_height;
    post_contract_dims[3] = out_width;
    for (int i = 4; i < NumDims; ++i) {
      post_contract_dims[i] = in.dimension(i);
    }
  } else {
    post_contract_dims[NumDims - 1] = kernelFilters;
    post_contract_dims[NumDims - 2] = out_planes;
    post_contract_dims[NumDims - 3] = out_height;
    post_contract_dims[NumDims - 4] = out_width;
    for (int i = 0; i < NumDims - 4; ++i) {
      post_contract_dims[i] = in.dimension(i);
    }
  }

  return choose(
      Cond<internal::traits<Input>::Layout == ColMajor>(),
      kernel.reshape(kernel_dims)
          .contract(input
                        .extract_volume_patches(
                            kernelPlanes, kernelRows, kernelCols, stridePlanes,
                            strideRows, strideCols, padding_type)
                        .reshape(pre_contract_dims),
                    contract_dims)
          .reshape(post_contract_dims),
      input
          .extract_volume_patches(kernelPlanes, kernelRows, kernelCols,
                                  stridePlanes, strideRows, strideCols,
                                  padding_type)
          .reshape(pre_contract_dims)
          .contract(kernel.reshape(kernel_dims), contract_dims)
          .reshape(post_contract_dims));
}

}  // end namespace Eigen

#endif  // TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_