• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2011-2014 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_ITERATIVE_SOLVER_BASE_H
11 #define EIGEN_ITERATIVE_SOLVER_BASE_H
12 
13 namespace Eigen {
14 
15 namespace internal {
16 
17 template<typename MatrixType>
18 struct is_ref_compatible_impl
19 {
20 private:
21   template <typename T0>
22   struct any_conversion
23   {
24     template <typename T> any_conversion(const volatile T&);
25     template <typename T> any_conversion(T&);
26   };
27   struct yes {int a[1];};
28   struct no  {int a[2];};
29 
30   template<typename T>
31   static yes test(const Ref<const T>&, int);
32   template<typename T>
33   static no  test(any_conversion<T>, ...);
34 
35 public:
36   static MatrixType ms_from;
37   enum { value = sizeof(test<MatrixType>(ms_from, 0))==sizeof(yes) };
38 };
39 
40 template<typename MatrixType>
41 struct is_ref_compatible
42 {
43   enum { value = is_ref_compatible_impl<typename remove_all<MatrixType>::type>::value };
44 };
45 
46 template<typename MatrixType, bool MatrixFree = !internal::is_ref_compatible<MatrixType>::value>
47 class generic_matrix_wrapper;
48 
49 // We have an explicit matrix at hand, compatible with Ref<>
50 template<typename MatrixType>
51 class generic_matrix_wrapper<MatrixType,false>
52 {
53 public:
54   typedef Ref<const MatrixType> ActualMatrixType;
55   template<int UpLo> struct ConstSelfAdjointViewReturnType {
56     typedef typename ActualMatrixType::template ConstSelfAdjointViewReturnType<UpLo>::Type Type;
57   };
58 
59   enum {
60     MatrixFree = false
61   };
62 
generic_matrix_wrapper()63   generic_matrix_wrapper()
64     : m_dummy(0,0), m_matrix(m_dummy)
65   {}
66 
67   template<typename InputType>
generic_matrix_wrapper(const InputType & mat)68   generic_matrix_wrapper(const InputType &mat)
69     : m_matrix(mat)
70   {}
71 
matrix()72   const ActualMatrixType& matrix() const
73   {
74     return m_matrix;
75   }
76 
77   template<typename MatrixDerived>
grab(const EigenBase<MatrixDerived> & mat)78   void grab(const EigenBase<MatrixDerived> &mat)
79   {
80     m_matrix.~Ref<const MatrixType>();
81     ::new (&m_matrix) Ref<const MatrixType>(mat.derived());
82   }
83 
grab(const Ref<const MatrixType> & mat)84   void grab(const Ref<const MatrixType> &mat)
85   {
86     if(&(mat.derived()) != &m_matrix)
87     {
88       m_matrix.~Ref<const MatrixType>();
89       ::new (&m_matrix) Ref<const MatrixType>(mat);
90     }
91   }
92 
93 protected:
94   MatrixType m_dummy; // used to default initialize the Ref<> object
95   ActualMatrixType m_matrix;
96 };
97 
98 // MatrixType is not compatible with Ref<> -> matrix-free wrapper
99 template<typename MatrixType>
100 class generic_matrix_wrapper<MatrixType,true>
101 {
102 public:
103   typedef MatrixType ActualMatrixType;
104   template<int UpLo> struct ConstSelfAdjointViewReturnType
105   {
106     typedef ActualMatrixType Type;
107   };
108 
109   enum {
110     MatrixFree = true
111   };
112 
generic_matrix_wrapper()113   generic_matrix_wrapper()
114     : mp_matrix(0)
115   {}
116 
generic_matrix_wrapper(const MatrixType & mat)117   generic_matrix_wrapper(const MatrixType &mat)
118     : mp_matrix(&mat)
119   {}
120 
matrix()121   const ActualMatrixType& matrix() const
122   {
123     return *mp_matrix;
124   }
125 
grab(const MatrixType & mat)126   void grab(const MatrixType &mat)
127   {
128     mp_matrix = &mat;
129   }
130 
131 protected:
132   const ActualMatrixType *mp_matrix;
133 };
134 
135 }
136 
137 /** \ingroup IterativeLinearSolvers_Module
138   * \brief Base class for linear iterative solvers
139   *
140   * \sa class SimplicialCholesky, DiagonalPreconditioner, IdentityPreconditioner
141   */
142 template< typename Derived>
143 class IterativeSolverBase : public SparseSolverBase<Derived>
144 {
145 protected:
146   typedef SparseSolverBase<Derived> Base;
147   using Base::m_isInitialized;
148 
149 public:
150   typedef typename internal::traits<Derived>::MatrixType MatrixType;
151   typedef typename internal::traits<Derived>::Preconditioner Preconditioner;
152   typedef typename MatrixType::Scalar Scalar;
153   typedef typename MatrixType::StorageIndex StorageIndex;
154   typedef typename MatrixType::RealScalar RealScalar;
155 
156   enum {
157     ColsAtCompileTime = MatrixType::ColsAtCompileTime,
158     MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime
159   };
160 
161 public:
162 
163   using Base::derived;
164 
165   /** Default constructor. */
IterativeSolverBase()166   IterativeSolverBase()
167   {
168     init();
169   }
170 
171   /** Initialize the solver with matrix \a A for further \c Ax=b solving.
172     *
173     * This constructor is a shortcut for the default constructor followed
174     * by a call to compute().
175     *
176     * \warning this class stores a reference to the matrix A as well as some
177     * precomputed values that depend on it. Therefore, if \a A is changed
178     * this class becomes invalid. Call compute() to update it with the new
179     * matrix A, or modify a copy of A.
180     */
181   template<typename MatrixDerived>
IterativeSolverBase(const EigenBase<MatrixDerived> & A)182   explicit IterativeSolverBase(const EigenBase<MatrixDerived>& A)
183     : m_matrixWrapper(A.derived())
184   {
185     init();
186     compute(matrix());
187   }
188 
~IterativeSolverBase()189   ~IterativeSolverBase() {}
190 
191   /** Initializes the iterative solver for the sparsity pattern of the matrix \a A for further solving \c Ax=b problems.
192     *
193     * Currently, this function mostly calls analyzePattern on the preconditioner. In the future
194     * we might, for instance, implement column reordering for faster matrix vector products.
195     */
196   template<typename MatrixDerived>
analyzePattern(const EigenBase<MatrixDerived> & A)197   Derived& analyzePattern(const EigenBase<MatrixDerived>& A)
198   {
199     grab(A.derived());
200     m_preconditioner.analyzePattern(matrix());
201     m_isInitialized = true;
202     m_analysisIsOk = true;
203     m_info = m_preconditioner.info();
204     return derived();
205   }
206 
207   /** Initializes the iterative solver with the numerical values of the matrix \a A for further solving \c Ax=b problems.
208     *
209     * Currently, this function mostly calls factorize on the preconditioner.
210     *
211     * \warning this class stores a reference to the matrix A as well as some
212     * precomputed values that depend on it. Therefore, if \a A is changed
213     * this class becomes invalid. Call compute() to update it with the new
214     * matrix A, or modify a copy of A.
215     */
216   template<typename MatrixDerived>
factorize(const EigenBase<MatrixDerived> & A)217   Derived& factorize(const EigenBase<MatrixDerived>& A)
218   {
219     eigen_assert(m_analysisIsOk && "You must first call analyzePattern()");
220     grab(A.derived());
221     m_preconditioner.factorize(matrix());
222     m_factorizationIsOk = true;
223     m_info = m_preconditioner.info();
224     return derived();
225   }
226 
227   /** Initializes the iterative solver with the matrix \a A for further solving \c Ax=b problems.
228     *
229     * Currently, this function mostly initializes/computes the preconditioner. In the future
230     * we might, for instance, implement column reordering for faster matrix vector products.
231     *
232     * \warning this class stores a reference to the matrix A as well as some
233     * precomputed values that depend on it. Therefore, if \a A is changed
234     * this class becomes invalid. Call compute() to update it with the new
235     * matrix A, or modify a copy of A.
236     */
237   template<typename MatrixDerived>
compute(const EigenBase<MatrixDerived> & A)238   Derived& compute(const EigenBase<MatrixDerived>& A)
239   {
240     grab(A.derived());
241     m_preconditioner.compute(matrix());
242     m_isInitialized = true;
243     m_analysisIsOk = true;
244     m_factorizationIsOk = true;
245     m_info = m_preconditioner.info();
246     return derived();
247   }
248 
249   /** \internal */
rows()250   EIGEN_CONSTEXPR Index rows() const EIGEN_NOEXCEPT { return matrix().rows(); }
251 
252   /** \internal */
cols()253   EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT { return matrix().cols(); }
254 
255   /** \returns the tolerance threshold used by the stopping criteria.
256     * \sa setTolerance()
257     */
tolerance()258   RealScalar tolerance() const { return m_tolerance; }
259 
260   /** Sets the tolerance threshold used by the stopping criteria.
261     *
262     * This value is used as an upper bound to the relative residual error: |Ax-b|/|b|.
263     * The default value is the machine precision given by NumTraits<Scalar>::epsilon()
264     */
setTolerance(const RealScalar & tolerance)265   Derived& setTolerance(const RealScalar& tolerance)
266   {
267     m_tolerance = tolerance;
268     return derived();
269   }
270 
271   /** \returns a read-write reference to the preconditioner for custom configuration. */
preconditioner()272   Preconditioner& preconditioner() { return m_preconditioner; }
273 
274   /** \returns a read-only reference to the preconditioner. */
preconditioner()275   const Preconditioner& preconditioner() const { return m_preconditioner; }
276 
277   /** \returns the max number of iterations.
278     * It is either the value set by setMaxIterations or, by default,
279     * twice the number of columns of the matrix.
280     */
maxIterations()281   Index maxIterations() const
282   {
283     return (m_maxIterations<0) ? 2*matrix().cols() : m_maxIterations;
284   }
285 
286   /** Sets the max number of iterations.
287     * Default is twice the number of columns of the matrix.
288     */
setMaxIterations(Index maxIters)289   Derived& setMaxIterations(Index maxIters)
290   {
291     m_maxIterations = maxIters;
292     return derived();
293   }
294 
295   /** \returns the number of iterations performed during the last solve */
iterations()296   Index iterations() const
297   {
298     eigen_assert(m_isInitialized && "ConjugateGradient is not initialized.");
299     return m_iterations;
300   }
301 
302   /** \returns the tolerance error reached during the last solve.
303     * It is a close approximation of the true relative residual error |Ax-b|/|b|.
304     */
error()305   RealScalar error() const
306   {
307     eigen_assert(m_isInitialized && "ConjugateGradient is not initialized.");
308     return m_error;
309   }
310 
311   /** \returns the solution x of \f$ A x = b \f$ using the current decomposition of A
312     * and \a x0 as an initial solution.
313     *
314     * \sa solve(), compute()
315     */
316   template<typename Rhs,typename Guess>
317   inline const SolveWithGuess<Derived, Rhs, Guess>
solveWithGuess(const MatrixBase<Rhs> & b,const Guess & x0)318   solveWithGuess(const MatrixBase<Rhs>& b, const Guess& x0) const
319   {
320     eigen_assert(m_isInitialized && "Solver is not initialized.");
321     eigen_assert(derived().rows()==b.rows() && "solve(): invalid number of rows of the right hand side matrix b");
322     return SolveWithGuess<Derived, Rhs, Guess>(derived(), b.derived(), x0);
323   }
324 
325   /** \returns Success if the iterations converged, and NoConvergence otherwise. */
info()326   ComputationInfo info() const
327   {
328     eigen_assert(m_isInitialized && "IterativeSolverBase is not initialized.");
329     return m_info;
330   }
331 
332   /** \internal */
333   template<typename Rhs, typename DestDerived>
_solve_with_guess_impl(const Rhs & b,SparseMatrixBase<DestDerived> & aDest)334   void _solve_with_guess_impl(const Rhs& b, SparseMatrixBase<DestDerived> &aDest) const
335   {
336     eigen_assert(rows()==b.rows());
337 
338     Index rhsCols = b.cols();
339     Index size = b.rows();
340     DestDerived& dest(aDest.derived());
341     typedef typename DestDerived::Scalar DestScalar;
342     Eigen::Matrix<DestScalar,Dynamic,1> tb(size);
343     Eigen::Matrix<DestScalar,Dynamic,1> tx(cols());
344     // We do not directly fill dest because sparse expressions have to be free of aliasing issue.
345     // For non square least-square problems, b and dest might not have the same size whereas they might alias each-other.
346     typename DestDerived::PlainObject tmp(cols(),rhsCols);
347     ComputationInfo global_info = Success;
348     for(Index k=0; k<rhsCols; ++k)
349     {
350       tb = b.col(k);
351       tx = dest.col(k);
352       derived()._solve_vector_with_guess_impl(tb,tx);
353       tmp.col(k) = tx.sparseView(0);
354 
355       // The call to _solve_vector_with_guess_impl updates m_info, so if it failed for a previous column
356       // we need to restore it to the worst value.
357       if(m_info==NumericalIssue)
358         global_info = NumericalIssue;
359       else if(m_info==NoConvergence)
360         global_info = NoConvergence;
361     }
362     m_info = global_info;
363     dest.swap(tmp);
364   }
365 
366   template<typename Rhs, typename DestDerived>
367   typename internal::enable_if<Rhs::ColsAtCompileTime!=1 && DestDerived::ColsAtCompileTime!=1>::type
_solve_with_guess_impl(const Rhs & b,MatrixBase<DestDerived> & aDest)368   _solve_with_guess_impl(const Rhs& b, MatrixBase<DestDerived> &aDest) const
369   {
370     eigen_assert(rows()==b.rows());
371 
372     Index rhsCols = b.cols();
373     DestDerived& dest(aDest.derived());
374     ComputationInfo global_info = Success;
375     for(Index k=0; k<rhsCols; ++k)
376     {
377       typename DestDerived::ColXpr xk(dest,k);
378       typename Rhs::ConstColXpr bk(b,k);
379       derived()._solve_vector_with_guess_impl(bk,xk);
380 
381       // The call to _solve_vector_with_guess updates m_info, so if it failed for a previous column
382       // we need to restore it to the worst value.
383       if(m_info==NumericalIssue)
384         global_info = NumericalIssue;
385       else if(m_info==NoConvergence)
386         global_info = NoConvergence;
387     }
388     m_info = global_info;
389   }
390 
391   template<typename Rhs, typename DestDerived>
392   typename internal::enable_if<Rhs::ColsAtCompileTime==1 || DestDerived::ColsAtCompileTime==1>::type
_solve_with_guess_impl(const Rhs & b,MatrixBase<DestDerived> & dest)393   _solve_with_guess_impl(const Rhs& b, MatrixBase<DestDerived> &dest) const
394   {
395     derived()._solve_vector_with_guess_impl(b,dest.derived());
396   }
397 
398   /** \internal default initial guess = 0 */
399   template<typename Rhs,typename Dest>
_solve_impl(const Rhs & b,Dest & x)400   void _solve_impl(const Rhs& b, Dest& x) const
401   {
402     x.setZero();
403     derived()._solve_with_guess_impl(b,x);
404   }
405 
406 protected:
init()407   void init()
408   {
409     m_isInitialized = false;
410     m_analysisIsOk = false;
411     m_factorizationIsOk = false;
412     m_maxIterations = -1;
413     m_tolerance = NumTraits<Scalar>::epsilon();
414   }
415 
416   typedef internal::generic_matrix_wrapper<MatrixType> MatrixWrapper;
417   typedef typename MatrixWrapper::ActualMatrixType ActualMatrixType;
418 
matrix()419   const ActualMatrixType& matrix() const
420   {
421     return m_matrixWrapper.matrix();
422   }
423 
424   template<typename InputType>
grab(const InputType & A)425   void grab(const InputType &A)
426   {
427     m_matrixWrapper.grab(A);
428   }
429 
430   MatrixWrapper m_matrixWrapper;
431   Preconditioner m_preconditioner;
432 
433   Index m_maxIterations;
434   RealScalar m_tolerance;
435 
436   mutable RealScalar m_error;
437   mutable Index m_iterations;
438   mutable ComputationInfo m_info;
439   mutable bool m_analysisIsOk, m_factorizationIsOk;
440 };
441 
442 } // end namespace Eigen
443 
444 #endif // EIGEN_ITERATIVE_SOLVER_BASE_H
445