• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2008-2014 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
11 #define EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
12 
13 namespace Eigen {
14 
15 namespace internal {
16 
17 
18 // perform a pseudo in-place sparse * sparse product assuming all matrices are col major
19 template<typename Lhs, typename Rhs, typename ResultType>
sparse_sparse_product_with_pruning_impl(const Lhs & lhs,const Rhs & rhs,ResultType & res,const typename ResultType::RealScalar & tolerance)20 static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res, const typename ResultType::RealScalar& tolerance)
21 {
22   // return sparse_sparse_product_with_pruning_impl2(lhs,rhs,res);
23 
24   typedef typename remove_all<Rhs>::type::Scalar RhsScalar;
25   typedef typename remove_all<ResultType>::type::Scalar ResScalar;
26   typedef typename remove_all<Lhs>::type::StorageIndex StorageIndex;
27 
28   // make sure to call innerSize/outerSize since we fake the storage order.
29   Index rows = lhs.innerSize();
30   Index cols = rhs.outerSize();
31   //Index size = lhs.outerSize();
32   eigen_assert(lhs.outerSize() == rhs.innerSize());
33 
34   // allocate a temporary buffer
35   AmbiVector<ResScalar,StorageIndex> tempVector(rows);
36 
37   // mimics a resizeByInnerOuter:
38   if(ResultType::IsRowMajor)
39     res.resize(cols, rows);
40   else
41     res.resize(rows, cols);
42 
43   evaluator<Lhs> lhsEval(lhs);
44   evaluator<Rhs> rhsEval(rhs);
45 
46   // estimate the number of non zero entries
47   // given a rhs column containing Y non zeros, we assume that the respective Y columns
48   // of the lhs differs in average of one non zeros, thus the number of non zeros for
49   // the product of a rhs column with the lhs is X+Y where X is the average number of non zero
50   // per column of the lhs.
51   // Therefore, we have nnz(lhs*rhs) = nnz(lhs) + nnz(rhs)
52   Index estimated_nnz_prod = lhsEval.nonZerosEstimate() + rhsEval.nonZerosEstimate();
53 
54   res.reserve(estimated_nnz_prod);
55   double ratioColRes = double(estimated_nnz_prod)/(double(lhs.rows())*double(rhs.cols()));
56   for (Index j=0; j<cols; ++j)
57   {
58     // FIXME:
59     //double ratioColRes = (double(rhs.innerVector(j).nonZeros()) + double(lhs.nonZeros())/double(lhs.cols()))/double(lhs.rows());
60     // let's do a more accurate determination of the nnz ratio for the current column j of res
61     tempVector.init(ratioColRes);
62     tempVector.setZero();
63     for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt)
64     {
65       // FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index())
66       tempVector.restart();
67       RhsScalar x = rhsIt.value();
68       for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, rhsIt.index()); lhsIt; ++lhsIt)
69       {
70         tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x;
71       }
72     }
73     res.startVec(j);
74     for (typename AmbiVector<ResScalar,StorageIndex>::Iterator it(tempVector,tolerance); it; ++it)
75       res.insertBackByOuterInner(j,it.index()) = it.value();
76   }
77   res.finalize();
78 }
79 
80 template<typename Lhs, typename Rhs, typename ResultType,
81   int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit,
82   int RhsStorageOrder = traits<Rhs>::Flags&RowMajorBit,
83   int ResStorageOrder = traits<ResultType>::Flags&RowMajorBit>
84 struct sparse_sparse_product_with_pruning_selector;
85 
86 template<typename Lhs, typename Rhs, typename ResultType>
87 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor>
88 {
89   typedef typename ResultType::RealScalar RealScalar;
90 
91   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
92   {
93     typename remove_all<ResultType>::type _res(res.rows(), res.cols());
94     internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,ResultType>(lhs, rhs, _res, tolerance);
95     res.swap(_res);
96   }
97 };
98 
99 template<typename Lhs, typename Rhs, typename ResultType>
100 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor>
101 {
102   typedef typename ResultType::RealScalar RealScalar;
103   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
104   {
105     // we need a col-major matrix to hold the result
106     typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> SparseTemporaryType;
107     SparseTemporaryType _res(res.rows(), res.cols());
108     internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, _res, tolerance);
109     res = _res;
110   }
111 };
112 
113 template<typename Lhs, typename Rhs, typename ResultType>
114 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor>
115 {
116   typedef typename ResultType::RealScalar RealScalar;
117   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
118   {
119     // let's transpose the product to get a column x column product
120     typename remove_all<ResultType>::type _res(res.rows(), res.cols());
121     internal::sparse_sparse_product_with_pruning_impl<Rhs,Lhs,ResultType>(rhs, lhs, _res, tolerance);
122     res.swap(_res);
123   }
124 };
125 
126 template<typename Lhs, typename Rhs, typename ResultType>
127 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor>
128 {
129   typedef typename ResultType::RealScalar RealScalar;
130   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
131   {
132     typedef SparseMatrix<typename Lhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixLhs;
133     typedef SparseMatrix<typename Rhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixRhs;
134     ColMajorMatrixLhs colLhs(lhs);
135     ColMajorMatrixRhs colRhs(rhs);
136     internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,ColMajorMatrixRhs,ResultType>(colLhs, colRhs, res, tolerance);
137 
138     // let's transpose the product to get a column x column product
139 //     typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType;
140 //     SparseTemporaryType _res(res.cols(), res.rows());
141 //     sparse_sparse_product_with_pruning_impl<Rhs,Lhs,SparseTemporaryType>(rhs, lhs, _res);
142 //     res = _res.transpose();
143   }
144 };
145 
146 template<typename Lhs, typename Rhs, typename ResultType>
147 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,RowMajor>
148 {
149   typedef typename ResultType::RealScalar RealScalar;
150   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
151   {
152     typedef SparseMatrix<typename Lhs::Scalar,RowMajor,typename Lhs::StorageIndex> RowMajorMatrixLhs;
153     RowMajorMatrixLhs rowLhs(lhs);
154     sparse_sparse_product_with_pruning_selector<RowMajorMatrixLhs,Rhs,ResultType,RowMajor,RowMajor>(rowLhs,rhs,res,tolerance);
155   }
156 };
157 
158 template<typename Lhs, typename Rhs, typename ResultType>
159 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,RowMajor>
160 {
161   typedef typename ResultType::RealScalar RealScalar;
162   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
163   {
164     typedef SparseMatrix<typename Rhs::Scalar,RowMajor,typename Lhs::StorageIndex> RowMajorMatrixRhs;
165     RowMajorMatrixRhs rowRhs(rhs);
166     sparse_sparse_product_with_pruning_selector<Lhs,RowMajorMatrixRhs,ResultType,RowMajor,RowMajor,RowMajor>(lhs,rowRhs,res,tolerance);
167   }
168 };
169 
170 template<typename Lhs, typename Rhs, typename ResultType>
171 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,ColMajor>
172 {
173   typedef typename ResultType::RealScalar RealScalar;
174   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
175   {
176     typedef SparseMatrix<typename Rhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixRhs;
177     ColMajorMatrixRhs colRhs(rhs);
178     internal::sparse_sparse_product_with_pruning_impl<Lhs,ColMajorMatrixRhs,ResultType>(lhs, colRhs, res, tolerance);
179   }
180 };
181 
182 template<typename Lhs, typename Rhs, typename ResultType>
183 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,ColMajor>
184 {
185   typedef typename ResultType::RealScalar RealScalar;
186   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
187   {
188     typedef SparseMatrix<typename Lhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixLhs;
189     ColMajorMatrixLhs colLhs(lhs);
190     internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,Rhs,ResultType>(colLhs, rhs, res, tolerance);
191   }
192 };
193 
194 } // end namespace internal
195 
196 } // end namespace Eigen
197 
198 #endif // EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
199