• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2009-2015 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_SPARSE_DIAGONAL_PRODUCT_H
11 #define EIGEN_SPARSE_DIAGONAL_PRODUCT_H
12 
13 namespace Eigen {
14 
// The product of a diagonal matrix with a sparse matrix can easily be
// implemented using expression templates.
// There are two very different cases to consider:
// 1 - diag * row-major sparse
//     => each inner vector <=> scalar * sparse vector product
//     => handled by the SDP_AsScalarProduct evaluator below, whose iterator
//        scales every value of an inner vector by a single diagonal coefficient
// 2 - diag * col-major sparse
//     => each inner vector <=> dense vector * sparse vector cwise product
//     => handled by the SDP_AsCwiseProduct evaluator below, whose iterator
//        multiplies each value by the diagonal coefficient at the same inner index
// The two other cases (sparse * diag) are symmetric.
26 
27 namespace internal {
28 
// Tags selecting how a (diagonal * sparse) or (sparse * diagonal) product is
// evaluated along each inner vector of the sparse operand:
//  - SDP_AsScalarProduct: the whole inner vector is scaled by one diagonal
//    coefficient, the one selected by the inner vector's outer index;
//  - SDP_AsCwiseProduct: every stored entry is multiplied by the diagonal
//    coefficient matching its inner index.
enum {
  SDP_AsScalarProduct = 0,
  SDP_AsCwiseProduct = 1
};

// Primary template: left undefined, only the two partial specializations on
// the tag argument are implemented.
template<typename SparseType, typename DiagType, int Tag>
struct sparse_diagonal_product_evaluator;
36 
37 template<typename Lhs, typename Rhs, int ProductTag>
38 struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, DiagonalShape, SparseShape>
39   : public sparse_diagonal_product_evaluator<Rhs, typename Lhs::DiagonalVectorType, Rhs::Flags&RowMajorBit?SDP_AsScalarProduct:SDP_AsCwiseProduct>
40 {
41   typedef Product<Lhs, Rhs, DefaultProduct> XprType;
42   enum { CoeffReadCost = HugeCost, Flags = Rhs::Flags&RowMajorBit, Alignment = 0 }; // FIXME CoeffReadCost & Flags
43 
44   typedef sparse_diagonal_product_evaluator<Rhs, typename Lhs::DiagonalVectorType, Rhs::Flags&RowMajorBit?SDP_AsScalarProduct:SDP_AsCwiseProduct> Base;
45   explicit product_evaluator(const XprType& xpr) : Base(xpr.rhs(), xpr.lhs().diagonal()) {}
46 };
47 
48 template<typename Lhs, typename Rhs, int ProductTag>
49 struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, SparseShape, DiagonalShape>
50   : public sparse_diagonal_product_evaluator<Lhs, Transpose<const typename Rhs::DiagonalVectorType>, Lhs::Flags&RowMajorBit?SDP_AsCwiseProduct:SDP_AsScalarProduct>
51 {
52   typedef Product<Lhs, Rhs, DefaultProduct> XprType;
53   enum { CoeffReadCost = HugeCost, Flags = Lhs::Flags&RowMajorBit, Alignment = 0 }; // FIXME CoeffReadCost & Flags
54 
55   typedef sparse_diagonal_product_evaluator<Lhs, Transpose<const typename Rhs::DiagonalVectorType>, Lhs::Flags&RowMajorBit?SDP_AsCwiseProduct:SDP_AsScalarProduct> Base;
56   explicit product_evaluator(const XprType& xpr) : Base(xpr.lhs(), xpr.rhs().diagonal().transpose()) {}
57 };
58 
59 template<typename SparseXprType, typename DiagonalCoeffType>
60 struct sparse_diagonal_product_evaluator<SparseXprType, DiagonalCoeffType, SDP_AsScalarProduct>
61 {
62 protected:
63   typedef typename evaluator<SparseXprType>::InnerIterator SparseXprInnerIterator;
64   typedef typename SparseXprType::Scalar Scalar;
65 
66 public:
67   class InnerIterator : public SparseXprInnerIterator
68   {
69   public:
70     InnerIterator(const sparse_diagonal_product_evaluator &xprEval, Index outer)
71       : SparseXprInnerIterator(xprEval.m_sparseXprImpl, outer),
72         m_coeff(xprEval.m_diagCoeffImpl.coeff(outer))
73     {}
74 
75     EIGEN_STRONG_INLINE Scalar value() const { return m_coeff * SparseXprInnerIterator::value(); }
76   protected:
77     typename DiagonalCoeffType::Scalar m_coeff;
78   };
79 
80   sparse_diagonal_product_evaluator(const SparseXprType &sparseXpr, const DiagonalCoeffType &diagCoeff)
81     : m_sparseXprImpl(sparseXpr), m_diagCoeffImpl(diagCoeff)
82   {}
83 
84   Index nonZerosEstimate() const { return m_sparseXprImpl.nonZerosEstimate(); }
85 
86 protected:
87   evaluator<SparseXprType> m_sparseXprImpl;
88   evaluator<DiagonalCoeffType> m_diagCoeffImpl;
89 };
90 
91 
92 template<typename SparseXprType, typename DiagCoeffType>
93 struct sparse_diagonal_product_evaluator<SparseXprType, DiagCoeffType, SDP_AsCwiseProduct>
94 {
95   typedef typename SparseXprType::Scalar Scalar;
96   typedef typename SparseXprType::StorageIndex StorageIndex;
97 
98   typedef typename nested_eval<DiagCoeffType,SparseXprType::IsRowMajor ? SparseXprType::RowsAtCompileTime
99                                                                        : SparseXprType::ColsAtCompileTime>::type DiagCoeffNested;
100 
101   class InnerIterator
102   {
103     typedef typename evaluator<SparseXprType>::InnerIterator SparseXprIter;
104   public:
105     InnerIterator(const sparse_diagonal_product_evaluator &xprEval, Index outer)
106       : m_sparseIter(xprEval.m_sparseXprEval, outer), m_diagCoeffNested(xprEval.m_diagCoeffNested)
107     {}
108 
109     inline Scalar value() const { return m_sparseIter.value() * m_diagCoeffNested.coeff(index()); }
110     inline StorageIndex index() const  { return m_sparseIter.index(); }
111     inline Index outer() const  { return m_sparseIter.outer(); }
112     inline Index col() const    { return SparseXprType::IsRowMajor ? m_sparseIter.index() : m_sparseIter.outer(); }
113     inline Index row() const    { return SparseXprType::IsRowMajor ? m_sparseIter.outer() : m_sparseIter.index(); }
114 
115     EIGEN_STRONG_INLINE InnerIterator& operator++() { ++m_sparseIter; return *this; }
116     inline operator bool() const  { return m_sparseIter; }
117 
118   protected:
119     SparseXprIter m_sparseIter;
120     DiagCoeffNested m_diagCoeffNested;
121   };
122 
123   sparse_diagonal_product_evaluator(const SparseXprType &sparseXpr, const DiagCoeffType &diagCoeff)
124     : m_sparseXprEval(sparseXpr), m_diagCoeffNested(diagCoeff)
125   {}
126 
127   Index nonZerosEstimate() const { return m_sparseXprEval.nonZerosEstimate(); }
128 
129 protected:
130   evaluator<SparseXprType> m_sparseXprEval;
131   DiagCoeffNested m_diagCoeffNested;
132 };
133 
134 } // end namespace internal
135 
136 } // end namespace Eigen
137 
138 #endif // EIGEN_SPARSE_DIAGONAL_PRODUCT_H
139