• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2008 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_SPARSETRIANGULARSOLVER_H
11 #define EIGEN_SPARSETRIANGULARSOLVER_H
12 
13 namespace Eigen {
14 
15 namespace internal {
16 
17 template<typename Lhs, typename Rhs, int Mode,
18   int UpLo = (Mode & Lower)
19            ? Lower
20            : (Mode & Upper)
21            ? Upper
22            : -1,
23   int StorageOrder = int(traits<Lhs>::Flags) & RowMajorBit>
24 struct sparse_solve_triangular_selector;
25 
26 // forward substitution, row-major
27 template<typename Lhs, typename Rhs, int Mode>
28 struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Lower,RowMajor>
29 {
30   typedef typename Rhs::Scalar Scalar;
31   typedef evaluator<Lhs> LhsEval;
32   typedef typename evaluator<Lhs>::InnerIterator LhsIterator;
33   static void run(const Lhs& lhs, Rhs& other)
34   {
35     LhsEval lhsEval(lhs);
36     for(Index col=0 ; col<other.cols() ; ++col)
37     {
38       for(Index i=0; i<lhs.rows(); ++i)
39       {
40         Scalar tmp = other.coeff(i,col);
41         Scalar lastVal(0);
42         Index lastIndex = 0;
43         for(LhsIterator it(lhsEval, i); it; ++it)
44         {
45           lastVal = it.value();
46           lastIndex = it.index();
47           if(lastIndex==i)
48             break;
49           tmp -= lastVal * other.coeff(lastIndex,col);
50         }
51         if (Mode & UnitDiag)
52           other.coeffRef(i,col) = tmp;
53         else
54         {
55           eigen_assert(lastIndex==i);
56           other.coeffRef(i,col) = tmp/lastVal;
57         }
58       }
59     }
60   }
61 };
62 
63 // backward substitution, row-major
64 template<typename Lhs, typename Rhs, int Mode>
65 struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Upper,RowMajor>
66 {
67   typedef typename Rhs::Scalar Scalar;
68   typedef evaluator<Lhs> LhsEval;
69   typedef typename evaluator<Lhs>::InnerIterator LhsIterator;
70   static void run(const Lhs& lhs, Rhs& other)
71   {
72     LhsEval lhsEval(lhs);
73     for(Index col=0 ; col<other.cols() ; ++col)
74     {
75       for(Index i=lhs.rows()-1 ; i>=0 ; --i)
76       {
77         Scalar tmp = other.coeff(i,col);
78         Scalar l_ii(0);
79         LhsIterator it(lhsEval, i);
80         while(it && it.index()<i)
81           ++it;
82         if(!(Mode & UnitDiag))
83         {
84           eigen_assert(it && it.index()==i);
85           l_ii = it.value();
86           ++it;
87         }
88         else if (it && it.index() == i)
89           ++it;
90         for(; it; ++it)
91         {
92           tmp -= it.value() * other.coeff(it.index(),col);
93         }
94 
95         if (Mode & UnitDiag)  other.coeffRef(i,col) = tmp;
96         else                  other.coeffRef(i,col) = tmp/l_ii;
97       }
98     }
99   }
100 };
101 
102 // forward substitution, col-major
103 template<typename Lhs, typename Rhs, int Mode>
104 struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Lower,ColMajor>
105 {
106   typedef typename Rhs::Scalar Scalar;
107   typedef evaluator<Lhs> LhsEval;
108   typedef typename evaluator<Lhs>::InnerIterator LhsIterator;
109   static void run(const Lhs& lhs, Rhs& other)
110   {
111     LhsEval lhsEval(lhs);
112     for(Index col=0 ; col<other.cols() ; ++col)
113     {
114       for(Index i=0; i<lhs.cols(); ++i)
115       {
116         Scalar& tmp = other.coeffRef(i,col);
117         if (tmp!=Scalar(0)) // optimization when other is actually sparse
118         {
119           LhsIterator it(lhsEval, i);
120           while(it && it.index()<i)
121             ++it;
122           if(!(Mode & UnitDiag))
123           {
124             eigen_assert(it && it.index()==i);
125             tmp /= it.value();
126           }
127           if (it && it.index()==i)
128             ++it;
129           for(; it; ++it)
130             other.coeffRef(it.index(), col) -= tmp * it.value();
131         }
132       }
133     }
134   }
135 };
136 
137 // backward substitution, col-major
138 template<typename Lhs, typename Rhs, int Mode>
139 struct sparse_solve_triangular_selector<Lhs,Rhs,Mode,Upper,ColMajor>
140 {
141   typedef typename Rhs::Scalar Scalar;
142   typedef evaluator<Lhs> LhsEval;
143   typedef typename evaluator<Lhs>::InnerIterator LhsIterator;
144   static void run(const Lhs& lhs, Rhs& other)
145   {
146     LhsEval lhsEval(lhs);
147     for(Index col=0 ; col<other.cols() ; ++col)
148     {
149       for(Index i=lhs.cols()-1; i>=0; --i)
150       {
151         Scalar& tmp = other.coeffRef(i,col);
152         if (tmp!=Scalar(0)) // optimization when other is actually sparse
153         {
154           if(!(Mode & UnitDiag))
155           {
156             // TODO replace this by a binary search. make sure the binary search is safe for partially sorted elements
157             LhsIterator it(lhsEval, i);
158             while(it && it.index()!=i)
159               ++it;
160             eigen_assert(it && it.index()==i);
161             other.coeffRef(i,col) /= it.value();
162           }
163           LhsIterator it(lhsEval, i);
164           for(; it && it.index()<i; ++it)
165             other.coeffRef(it.index(), col) -= tmp * it.value();
166         }
167       }
168     }
169   }
170 };
171 
172 } // end namespace internal
173 
174 #ifndef EIGEN_PARSED_BY_DOXYGEN
175 
176 template<typename ExpressionType,unsigned int Mode>
177 template<typename OtherDerived>
178 void TriangularViewImpl<ExpressionType,Mode,Sparse>::solveInPlace(MatrixBase<OtherDerived>& other) const
179 {
180   eigen_assert(derived().cols() == derived().rows() && derived().cols() == other.rows());
181   eigen_assert((!(Mode & ZeroDiag)) && bool(Mode & (Upper|Lower)));
182 
183   enum { copy = internal::traits<OtherDerived>::Flags & RowMajorBit };
184 
185   typedef typename internal::conditional<copy,
186     typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&>::type OtherCopy;
187   OtherCopy otherCopy(other.derived());
188 
189   internal::sparse_solve_triangular_selector<ExpressionType, typename internal::remove_reference<OtherCopy>::type, Mode>::run(derived().nestedExpression(), otherCopy);
190 
191   if (copy)
192     other = otherCopy;
193 }
194 #endif
195 
196 // pure sparse path
197 
198 namespace internal {
199 
200 template<typename Lhs, typename Rhs, int Mode,
201   int UpLo = (Mode & Lower)
202            ? Lower
203            : (Mode & Upper)
204            ? Upper
205            : -1,
206   int StorageOrder = int(Lhs::Flags) & (RowMajorBit)>
207 struct sparse_solve_triangular_sparse_selector;
208 
209 // forward substitution, col-major
210 template<typename Lhs, typename Rhs, int Mode, int UpLo>
211 struct sparse_solve_triangular_sparse_selector<Lhs,Rhs,Mode,UpLo,ColMajor>
212 {
213   typedef typename Rhs::Scalar Scalar;
214   typedef typename promote_index_type<typename traits<Lhs>::StorageIndex,
215                                       typename traits<Rhs>::StorageIndex>::type StorageIndex;
216   static void run(const Lhs& lhs, Rhs& other)
217   {
218     const bool IsLower = (UpLo==Lower);
219     AmbiVector<Scalar,StorageIndex> tempVector(other.rows()*2);
220     tempVector.setBounds(0,other.rows());
221 
222     Rhs res(other.rows(), other.cols());
223     res.reserve(other.nonZeros());
224 
225     for(Index col=0 ; col<other.cols() ; ++col)
226     {
227       // FIXME estimate number of non zeros
228       tempVector.init(.99/*float(other.col(col).nonZeros())/float(other.rows())*/);
229       tempVector.setZero();
230       tempVector.restart();
231       for (typename Rhs::InnerIterator rhsIt(other, col); rhsIt; ++rhsIt)
232       {
233         tempVector.coeffRef(rhsIt.index()) = rhsIt.value();
234       }
235 
236       for(Index i=IsLower?0:lhs.cols()-1;
237           IsLower?i<lhs.cols():i>=0;
238           i+=IsLower?1:-1)
239       {
240         tempVector.restart();
241         Scalar& ci = tempVector.coeffRef(i);
242         if (ci!=Scalar(0))
243         {
244           // find
245           typename Lhs::InnerIterator it(lhs, i);
246           if(!(Mode & UnitDiag))
247           {
248             if (IsLower)
249             {
250               eigen_assert(it.index()==i);
251               ci /= it.value();
252             }
253             else
254               ci /= lhs.coeff(i,i);
255           }
256           tempVector.restart();
257           if (IsLower)
258           {
259             if (it.index()==i)
260               ++it;
261             for(; it; ++it)
262               tempVector.coeffRef(it.index()) -= ci * it.value();
263           }
264           else
265           {
266             for(; it && it.index()<i; ++it)
267               tempVector.coeffRef(it.index()) -= ci * it.value();
268           }
269         }
270       }
271 
272 
273       Index count = 0;
274       // FIXME compute a reference value to filter zeros
275       for (typename AmbiVector<Scalar,StorageIndex>::Iterator it(tempVector/*,1e-12*/); it; ++it)
276       {
277         ++ count;
278 //         std::cerr << "fill " << it.index() << ", " << col << "\n";
279 //         std::cout << it.value() << "  ";
280         // FIXME use insertBack
281         res.insert(it.index(), col) = it.value();
282       }
283 //       std::cout << "tempVector.nonZeros() == " << int(count) << " / " << (other.rows()) << "\n";
284     }
285     res.finalize();
286     other = res.markAsRValue();
287   }
288 };
289 
290 } // end namespace internal
291 
292 #ifndef EIGEN_PARSED_BY_DOXYGEN
293 template<typename ExpressionType,unsigned int Mode>
294 template<typename OtherDerived>
295 void TriangularViewImpl<ExpressionType,Mode,Sparse>::solveInPlace(SparseMatrixBase<OtherDerived>& other) const
296 {
297   eigen_assert(derived().cols() == derived().rows() && derived().cols() == other.rows());
298   eigen_assert( (!(Mode & ZeroDiag)) && bool(Mode & (Upper|Lower)));
299 
300 //   enum { copy = internal::traits<OtherDerived>::Flags & RowMajorBit };
301 
302 //   typedef typename internal::conditional<copy,
303 //     typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&>::type OtherCopy;
304 //   OtherCopy otherCopy(other.derived());
305 
306   internal::sparse_solve_triangular_sparse_selector<ExpressionType, OtherDerived, Mode>::run(derived().nestedExpression(), other.derived());
307 
308 //   if (copy)
309 //     other = otherCopy;
310 }
311 #endif
312 
313 } // end namespace Eigen
314 
315 #endif // EIGEN_SPARSETRIANGULARSOLVER_H
316