1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2008-2015 Gael Guennebaud <gael.guennebaud@inria.fr> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_SPARSEDENSEPRODUCT_H 11 #define EIGEN_SPARSEDENSEPRODUCT_H 12 13 namespace Eigen { 14 15 namespace internal { 16 17 template <> struct product_promote_storage_type<Sparse,Dense, OuterProduct> { typedef Sparse ret; }; 18 template <> struct product_promote_storage_type<Dense,Sparse, OuterProduct> { typedef Sparse ret; }; 19 20 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType, 21 typename AlphaType, 22 int LhsStorageOrder = ((SparseLhsType::Flags&RowMajorBit)==RowMajorBit) ? RowMajor : ColMajor, 23 bool ColPerCol = ((DenseRhsType::Flags&RowMajorBit)==0) || DenseRhsType::ColsAtCompileTime==1> 24 struct sparse_time_dense_product_impl; 25 26 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType> 27 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar, RowMajor, true> 28 { 29 typedef typename internal::remove_all<SparseLhsType>::type Lhs; 30 typedef typename internal::remove_all<DenseRhsType>::type Rhs; 31 typedef typename internal::remove_all<DenseResType>::type Res; 32 typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator; 33 typedef evaluator<Lhs> LhsEval; 34 static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha) 35 { 36 LhsEval lhsEval(lhs); 37 38 Index n = lhs.outerSize(); 39 #ifdef EIGEN_HAS_OPENMP 40 Eigen::initParallel(); 41 Index threads = Eigen::nbThreads(); 42 #endif 43 44 for(Index c=0; c<rhs.cols(); ++c) 45 { 46 #ifdef EIGEN_HAS_OPENMP 47 // This 20000 threshold has been found experimentally on 2D and 3D Poisson problems. 48 // It basically represents the minimal amount of work to be done to be worth it. 49 if(threads>1 && lhsEval.nonZerosEstimate() > 20000) 50 { 51 #pragma omp parallel for schedule(dynamic,(n+threads*4-1)/(threads*4)) num_threads(threads) 52 for(Index i=0; i<n; ++i) 53 processRow(lhsEval,rhs,res,alpha,i,c); 54 } 55 else 56 #endif 57 { 58 for(Index i=0; i<n; ++i) 59 processRow(lhsEval,rhs,res,alpha,i,c); 60 } 61 } 62 } 63 64 static void processRow(const LhsEval& lhsEval, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha, Index i, Index col) 65 { 66 typename Res::Scalar tmp(0); 67 for(LhsInnerIterator it(lhsEval,i); it ;++it) 68 tmp += it.value() * rhs.coeff(it.index(),col); 69 res.coeffRef(i,col) += alpha * tmp; 70 } 71 72 }; 73 74 // FIXME: what is the purpose of the following specialization? Is it for the BlockedSparse format? 75 // -> let's disable it for now as it is conflicting with generic scalar*matrix and matrix*scalar operators 76 // template<typename T1, typename T2/*, int _Options, typename _StrideType*/> 77 // struct ScalarBinaryOpTraits<T1, Ref<T2/*, _Options, _StrideType*/> > 78 // { 79 // enum { 80 // Defined = 1 81 // }; 82 // typedef typename CwiseUnaryOp<scalar_multiple2_op<T1, typename T2::Scalar>, T2>::PlainObject ReturnType; 83 // }; 84 85 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType, typename AlphaType> 86 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, AlphaType, ColMajor, true> 87 { 88 typedef typename internal::remove_all<SparseLhsType>::type Lhs; 89 typedef typename internal::remove_all<DenseRhsType>::type Rhs; 90 typedef typename internal::remove_all<DenseResType>::type Res; 91 typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator; 92 static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha) 93 { 94 evaluator<Lhs> lhsEval(lhs); 95 for(Index c=0; c<rhs.cols(); ++c) 96 { 97 for(Index j=0; j<lhs.outerSize(); ++j) 98 { 99 // typename Res::Scalar rhs_j = alpha * rhs.coeff(j,c); 100 typename ScalarBinaryOpTraits<AlphaType, typename Rhs::Scalar>::ReturnType rhs_j(alpha * rhs.coeff(j,c)); 101 for(LhsInnerIterator it(lhsEval,j); it ;++it) 102 res.coeffRef(it.index(),c) += it.value() * rhs_j; 103 } 104 } 105 } 106 }; 107 108 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType> 109 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar, RowMajor, false> 110 { 111 typedef typename internal::remove_all<SparseLhsType>::type Lhs; 112 typedef typename internal::remove_all<DenseRhsType>::type Rhs; 113 typedef typename internal::remove_all<DenseResType>::type Res; 114 typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator; 115 static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha) 116 { 117 evaluator<Lhs> lhsEval(lhs); 118 for(Index j=0; j<lhs.outerSize(); ++j) 119 { 120 typename Res::RowXpr res_j(res.row(j)); 121 for(LhsInnerIterator it(lhsEval,j); it ;++it) 122 res_j += (alpha*it.value()) * rhs.row(it.index()); 123 } 124 } 125 }; 126 127 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType> 128 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar, ColMajor, false> 129 { 130 typedef typename internal::remove_all<SparseLhsType>::type Lhs; 131 typedef typename internal::remove_all<DenseRhsType>::type Rhs; 132 typedef typename internal::remove_all<DenseResType>::type Res; 133 typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator; 134 static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha) 135 { 136 evaluator<Lhs> lhsEval(lhs); 137 for(Index j=0; j<lhs.outerSize(); ++j) 138 { 139 typename Rhs::ConstRowXpr rhs_j(rhs.row(j)); 140 for(LhsInnerIterator it(lhsEval,j); it ;++it) 141 res.row(it.index()) += (alpha*it.value()) * rhs_j; 142 } 143 } 144 }; 145 146 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,typename AlphaType> 147 inline void sparse_time_dense_product(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha) 148 { 149 sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, AlphaType>::run(lhs, rhs, res, alpha); 150 } 151 152 } // end namespace internal 153 154 namespace internal { 155 156 template<typename Lhs, typename Rhs, int ProductType> 157 struct generic_product_impl<Lhs, Rhs, SparseShape, DenseShape, ProductType> 158 : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,SparseShape,DenseShape,ProductType> > 159 { 160 typedef typename Product<Lhs,Rhs>::Scalar Scalar; 161 162 template<typename Dest> 163 static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) 164 { 165 typedef typename nested_eval<Lhs,((Rhs::Flags&RowMajorBit)==0) ? 1 : Rhs::ColsAtCompileTime>::type LhsNested; 166 typedef typename nested_eval<Rhs,((Lhs::Flags&RowMajorBit)==0) ? 1 : Dynamic>::type RhsNested; 167 LhsNested lhsNested(lhs); 168 RhsNested rhsNested(rhs); 169 internal::sparse_time_dense_product(lhsNested, rhsNested, dst, alpha); 170 } 171 }; 172 173 template<typename Lhs, typename Rhs, int ProductType> 174 struct generic_product_impl<Lhs, Rhs, SparseTriangularShape, DenseShape, ProductType> 175 : generic_product_impl<Lhs, Rhs, SparseShape, DenseShape, ProductType> 176 {}; 177 178 template<typename Lhs, typename Rhs, int ProductType> 179 struct generic_product_impl<Lhs, Rhs, DenseShape, SparseShape, ProductType> 180 : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,DenseShape,SparseShape,ProductType> > 181 { 182 typedef typename Product<Lhs,Rhs>::Scalar Scalar; 183 184 template<typename Dst> 185 static void scaleAndAddTo(Dst& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) 186 { 187 typedef typename nested_eval<Lhs,((Rhs::Flags&RowMajorBit)==0) ? Dynamic : 1>::type LhsNested; 188 typedef typename nested_eval<Rhs,((Lhs::Flags&RowMajorBit)==RowMajorBit) ? 1 : Lhs::RowsAtCompileTime>::type RhsNested; 189 LhsNested lhsNested(lhs); 190 RhsNested rhsNested(rhs); 191 192 // transpose everything 193 Transpose<Dst> dstT(dst); 194 internal::sparse_time_dense_product(rhsNested.transpose(), lhsNested.transpose(), dstT, alpha); 195 } 196 }; 197 198 template<typename Lhs, typename Rhs, int ProductType> 199 struct generic_product_impl<Lhs, Rhs, DenseShape, SparseTriangularShape, ProductType> 200 : generic_product_impl<Lhs, Rhs, DenseShape, SparseShape, ProductType> 201 {}; 202 203 template<typename LhsT, typename RhsT, bool NeedToTranspose> 204 struct sparse_dense_outer_product_evaluator 205 { 206 protected: 207 typedef typename conditional<NeedToTranspose,RhsT,LhsT>::type Lhs1; 208 typedef typename conditional<NeedToTranspose,LhsT,RhsT>::type ActualRhs; 209 typedef Product<LhsT,RhsT,DefaultProduct> ProdXprType; 210 211 // if the actual left-hand side is a dense vector, 212 // then build a sparse-view so that we can seamlessly iterate over it. 213 typedef typename conditional<is_same<typename internal::traits<Lhs1>::StorageKind,Sparse>::value, 214 Lhs1, SparseView<Lhs1> >::type ActualLhs; 215 typedef typename conditional<is_same<typename internal::traits<Lhs1>::StorageKind,Sparse>::value, 216 Lhs1 const&, SparseView<Lhs1> >::type LhsArg; 217 218 typedef evaluator<ActualLhs> LhsEval; 219 typedef evaluator<ActualRhs> RhsEval; 220 typedef typename evaluator<ActualLhs>::InnerIterator LhsIterator; 221 typedef typename ProdXprType::Scalar Scalar; 222 223 public: 224 enum { 225 Flags = NeedToTranspose ? RowMajorBit : 0, 226 CoeffReadCost = HugeCost 227 }; 228 229 class InnerIterator : public LhsIterator 230 { 231 public: 232 InnerIterator(const sparse_dense_outer_product_evaluator &xprEval, Index outer) 233 : LhsIterator(xprEval.m_lhsXprImpl, 0), 234 m_outer(outer), 235 m_empty(false), 236 m_factor(get(xprEval.m_rhsXprImpl, outer, typename internal::traits<ActualRhs>::StorageKind() )) 237 {} 238 239 EIGEN_STRONG_INLINE Index outer() const { return m_outer; } 240 EIGEN_STRONG_INLINE Index row() const { return NeedToTranspose ? m_outer : LhsIterator::index(); } 241 EIGEN_STRONG_INLINE Index col() const { return NeedToTranspose ? LhsIterator::index() : m_outer; } 242 243 EIGEN_STRONG_INLINE Scalar value() const { return LhsIterator::value() * m_factor; } 244 EIGEN_STRONG_INLINE operator bool() const { return LhsIterator::operator bool() && (!m_empty); } 245 246 protected: 247 Scalar get(const RhsEval &rhs, Index outer, Dense = Dense()) const 248 { 249 return rhs.coeff(outer); 250 } 251 252 Scalar get(const RhsEval &rhs, Index outer, Sparse = Sparse()) 253 { 254 typename RhsEval::InnerIterator it(rhs, outer); 255 if (it && it.index()==0 && it.value()!=Scalar(0)) 256 return it.value(); 257 m_empty = true; 258 return Scalar(0); 259 } 260 261 Index m_outer; 262 bool m_empty; 263 Scalar m_factor; 264 }; 265 266 sparse_dense_outer_product_evaluator(const Lhs1 &lhs, const ActualRhs &rhs) 267 : m_lhs(lhs), m_lhsXprImpl(m_lhs), m_rhsXprImpl(rhs) 268 { 269 EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); 270 } 271 272 // transpose case 273 sparse_dense_outer_product_evaluator(const ActualRhs &rhs, const Lhs1 &lhs) 274 : m_lhs(lhs), m_lhsXprImpl(m_lhs), m_rhsXprImpl(rhs) 275 { 276 EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); 277 } 278 279 protected: 280 const LhsArg m_lhs; 281 evaluator<ActualLhs> m_lhsXprImpl; 282 evaluator<ActualRhs> m_rhsXprImpl; 283 }; 284 285 // sparse * dense outer product 286 template<typename Lhs, typename Rhs> 287 struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, OuterProduct, SparseShape, DenseShape> 288 : sparse_dense_outer_product_evaluator<Lhs,Rhs, Lhs::IsRowMajor> 289 { 290 typedef sparse_dense_outer_product_evaluator<Lhs,Rhs, Lhs::IsRowMajor> Base; 291 292 typedef Product<Lhs, Rhs> XprType; 293 typedef typename XprType::PlainObject PlainObject; 294 295 explicit product_evaluator(const XprType& xpr) 296 : Base(xpr.lhs(), xpr.rhs()) 297 {} 298 299 }; 300 301 template<typename Lhs, typename Rhs> 302 struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, OuterProduct, DenseShape, SparseShape> 303 : sparse_dense_outer_product_evaluator<Lhs,Rhs, Rhs::IsRowMajor> 304 { 305 typedef sparse_dense_outer_product_evaluator<Lhs,Rhs, Rhs::IsRowMajor> Base; 306 307 typedef Product<Lhs, Rhs> XprType; 308 typedef typename XprType::PlainObject PlainObject; 309 310 explicit product_evaluator(const XprType& xpr) 311 : Base(xpr.lhs(), xpr.rhs()) 312 {} 313 314 }; 315 316 } // end namespace internal 317 318 } // end namespace Eigen 319 320 #endif // EIGEN_SPARSEDENSEPRODUCT_H 321