1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H 11 #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H 12 13 14 namespace Eigen { 15 namespace internal { 16 17 enum { 18 ShardByRow = 0, 19 ShardByCol = 1 20 }; 21 22 23 // Default Blocking Strategy 24 template <typename LhsMapper, typename RhsMapper, typename Index, int ShardingType=ShardByCol> 25 class TensorContractionBlocking { 26 public: 27 28 typedef typename LhsMapper::Scalar LhsScalar; 29 typedef typename RhsMapper::Scalar RhsScalar; 30 31 EIGEN_DEVICE_FUNC TensorContractionBlocking(Index k, Index m, Index n, Index num_threads = 1) : kc_(k)32 kc_(k), mc_(m), nc_(n) 33 { 34 if (ShardingType == ShardByCol) { 35 computeProductBlockingSizes<LhsScalar, RhsScalar, 1>(kc_, mc_, nc_, num_threads); 36 } 37 else { 38 computeProductBlockingSizes<LhsScalar, RhsScalar, 1>(kc_, nc_, mc_, num_threads); 39 } 40 } 41 kc()42 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index kc() const { return kc_; } mc()43 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index mc() const { return mc_; } nc()44 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index nc() const { return nc_; } 45 46 private: 47 Index kc_; 48 Index mc_; 49 Index nc_; 50 }; 51 52 53 } // end namespace internal 54 } // end namespace Eigen 55 56 #endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H 57