1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2009-2015 Gael Guennebaud <gael.guennebaud@inria.fr> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_SPARSE_DIAGONAL_PRODUCT_H 11 #define EIGEN_SPARSE_DIAGONAL_PRODUCT_H 12 13 namespace Eigen { 14 15 // The product of a diagonal matrix with a sparse matrix can be easily 16 // implemented using expression template. 17 // We have two consider very different cases: 18 // 1 - diag * row-major sparse 19 // => each inner vector <=> scalar * sparse vector product 20 // => so we can reuse CwiseUnaryOp::InnerIterator 21 // 2 - diag * col-major sparse 22 // => each inner vector <=> densevector * sparse vector cwise product 23 // => again, we can reuse specialization of CwiseBinaryOp::InnerIterator 24 // for that particular case 25 // The two other cases are symmetric. 26 27 namespace internal { 28 29 enum { 30 SDP_AsScalarProduct, 31 SDP_AsCwiseProduct 32 }; 33 34 template<typename SparseXprType, typename DiagonalCoeffType, int SDP_Tag> 35 struct sparse_diagonal_product_evaluator; 36 37 template<typename Lhs, typename Rhs, int ProductTag> 38 struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, DiagonalShape, SparseShape> 39 : public sparse_diagonal_product_evaluator<Rhs, typename Lhs::DiagonalVectorType, Rhs::Flags&RowMajorBit?SDP_AsScalarProduct:SDP_AsCwiseProduct> 40 { 41 typedef Product<Lhs, Rhs, DefaultProduct> XprType; 42 enum { CoeffReadCost = HugeCost, Flags = Rhs::Flags&RowMajorBit, Alignment = 0 }; // FIXME CoeffReadCost & Flags 43 44 typedef sparse_diagonal_product_evaluator<Rhs, typename Lhs::DiagonalVectorType, Rhs::Flags&RowMajorBit?SDP_AsScalarProduct:SDP_AsCwiseProduct> Base; 45 explicit product_evaluator(const XprType& xpr) : Base(xpr.rhs(), xpr.lhs().diagonal()) {} 46 }; 47 48 template<typename Lhs, typename Rhs, int ProductTag> 49 struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, SparseShape, DiagonalShape> 50 : public sparse_diagonal_product_evaluator<Lhs, Transpose<const typename Rhs::DiagonalVectorType>, Lhs::Flags&RowMajorBit?SDP_AsCwiseProduct:SDP_AsScalarProduct> 51 { 52 typedef Product<Lhs, Rhs, DefaultProduct> XprType; 53 enum { CoeffReadCost = HugeCost, Flags = Lhs::Flags&RowMajorBit, Alignment = 0 }; // FIXME CoeffReadCost & Flags 54 55 typedef sparse_diagonal_product_evaluator<Lhs, Transpose<const typename Rhs::DiagonalVectorType>, Lhs::Flags&RowMajorBit?SDP_AsCwiseProduct:SDP_AsScalarProduct> Base; 56 explicit product_evaluator(const XprType& xpr) : Base(xpr.lhs(), xpr.rhs().diagonal().transpose()) {} 57 }; 58 59 template<typename SparseXprType, typename DiagonalCoeffType> 60 struct sparse_diagonal_product_evaluator<SparseXprType, DiagonalCoeffType, SDP_AsScalarProduct> 61 { 62 protected: 63 typedef typename evaluator<SparseXprType>::InnerIterator SparseXprInnerIterator; 64 typedef typename SparseXprType::Scalar Scalar; 65 66 public: 67 class InnerIterator : public SparseXprInnerIterator 68 { 69 public: 70 InnerIterator(const sparse_diagonal_product_evaluator &xprEval, Index outer) 71 : SparseXprInnerIterator(xprEval.m_sparseXprImpl, outer), 72 m_coeff(xprEval.m_diagCoeffImpl.coeff(outer)) 73 {} 74 75 EIGEN_STRONG_INLINE Scalar value() const { return m_coeff * SparseXprInnerIterator::value(); } 76 protected: 77 typename DiagonalCoeffType::Scalar m_coeff; 78 }; 79 80 sparse_diagonal_product_evaluator(const SparseXprType &sparseXpr, const DiagonalCoeffType &diagCoeff) 81 : m_sparseXprImpl(sparseXpr), m_diagCoeffImpl(diagCoeff) 82 {} 83 84 Index nonZerosEstimate() const { return m_sparseXprImpl.nonZerosEstimate(); } 85 86 protected: 87 evaluator<SparseXprType> m_sparseXprImpl; 88 evaluator<DiagonalCoeffType> m_diagCoeffImpl; 89 }; 90 91 92 template<typename SparseXprType, typename DiagCoeffType> 93 struct sparse_diagonal_product_evaluator<SparseXprType, DiagCoeffType, SDP_AsCwiseProduct> 94 { 95 typedef typename SparseXprType::Scalar Scalar; 96 typedef typename SparseXprType::StorageIndex StorageIndex; 97 98 typedef typename nested_eval<DiagCoeffType,SparseXprType::IsRowMajor ? SparseXprType::RowsAtCompileTime 99 : SparseXprType::ColsAtCompileTime>::type DiagCoeffNested; 100 101 class InnerIterator 102 { 103 typedef typename evaluator<SparseXprType>::InnerIterator SparseXprIter; 104 public: 105 InnerIterator(const sparse_diagonal_product_evaluator &xprEval, Index outer) 106 : m_sparseIter(xprEval.m_sparseXprEval, outer), m_diagCoeffNested(xprEval.m_diagCoeffNested) 107 {} 108 109 inline Scalar value() const { return m_sparseIter.value() * m_diagCoeffNested.coeff(index()); } 110 inline StorageIndex index() const { return m_sparseIter.index(); } 111 inline Index outer() const { return m_sparseIter.outer(); } 112 inline Index col() const { return SparseXprType::IsRowMajor ? m_sparseIter.index() : m_sparseIter.outer(); } 113 inline Index row() const { return SparseXprType::IsRowMajor ? m_sparseIter.outer() : m_sparseIter.index(); } 114 115 EIGEN_STRONG_INLINE InnerIterator& operator++() { ++m_sparseIter; return *this; } 116 inline operator bool() const { return m_sparseIter; } 117 118 protected: 119 SparseXprIter m_sparseIter; 120 DiagCoeffNested m_diagCoeffNested; 121 }; 122 123 sparse_diagonal_product_evaluator(const SparseXprType &sparseXpr, const DiagCoeffType &diagCoeff) 124 : m_sparseXprEval(sparseXpr), m_diagCoeffNested(diagCoeff) 125 {} 126 127 Index nonZerosEstimate() const { return m_sparseXprEval.nonZerosEstimate(); } 128 129 protected: 130 evaluator<SparseXprType> m_sparseXprEval; 131 DiagCoeffNested m_diagCoeffNested; 132 }; 133 134 } // end namespace internal 135 136 } // end namespace Eigen 137 138 #endif // EIGEN_SPARSE_DIAGONAL_PRODUCT_H 139