1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2011-2014 Gael Guennebaud <gael.guennebaud@inria.fr> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_ITERATIVE_SOLVER_BASE_H 11 #define EIGEN_ITERATIVE_SOLVER_BASE_H 12 13 namespace Eigen { 14 15 namespace internal { 16 17 template<typename MatrixType> 18 struct is_ref_compatible_impl 19 { 20 private: 21 template <typename T0> 22 struct any_conversion 23 { 24 template <typename T> any_conversion(const volatile T&); 25 template <typename T> any_conversion(T&); 26 }; 27 struct yes {int a[1];}; 28 struct no {int a[2];}; 29 30 template<typename T> 31 static yes test(const Ref<const T>&, int); 32 template<typename T> 33 static no test(any_conversion<T>, ...); 34 35 public: 36 static MatrixType ms_from; 37 enum { value = sizeof(test<MatrixType>(ms_from, 0))==sizeof(yes) }; 38 }; 39 40 template<typename MatrixType> 41 struct is_ref_compatible 42 { 43 enum { value = is_ref_compatible_impl<typename remove_all<MatrixType>::type>::value }; 44 }; 45 46 template<typename MatrixType, bool MatrixFree = !internal::is_ref_compatible<MatrixType>::value> 47 class generic_matrix_wrapper; 48 49 // We have an explicit matrix at hand, compatible with Ref<> 50 template<typename MatrixType> 51 class generic_matrix_wrapper<MatrixType,false> 52 { 53 public: 54 typedef Ref<const MatrixType> ActualMatrixType; 55 template<int UpLo> struct ConstSelfAdjointViewReturnType { 56 typedef typename ActualMatrixType::template ConstSelfAdjointViewReturnType<UpLo>::Type Type; 57 }; 58 59 enum { 60 MatrixFree = false 61 }; 62 generic_matrix_wrapper()63 generic_matrix_wrapper() 64 : m_dummy(0,0), m_matrix(m_dummy) 65 {} 66 67 template<typename InputType> generic_matrix_wrapper(const InputType & mat)68 generic_matrix_wrapper(const InputType &mat) 69 : m_matrix(mat) 70 {} 71 matrix()72 const ActualMatrixType& matrix() const 73 { 74 return m_matrix; 75 } 76 77 template<typename MatrixDerived> grab(const EigenBase<MatrixDerived> & mat)78 void grab(const EigenBase<MatrixDerived> &mat) 79 { 80 m_matrix.~Ref<const MatrixType>(); 81 ::new (&m_matrix) Ref<const MatrixType>(mat.derived()); 82 } 83 grab(const Ref<const MatrixType> & mat)84 void grab(const Ref<const MatrixType> &mat) 85 { 86 if(&(mat.derived()) != &m_matrix) 87 { 88 m_matrix.~Ref<const MatrixType>(); 89 ::new (&m_matrix) Ref<const MatrixType>(mat); 90 } 91 } 92 93 protected: 94 MatrixType m_dummy; // used to default initialize the Ref<> object 95 ActualMatrixType m_matrix; 96 }; 97 98 // MatrixType is not compatible with Ref<> -> matrix-free wrapper 99 template<typename MatrixType> 100 class generic_matrix_wrapper<MatrixType,true> 101 { 102 public: 103 typedef MatrixType ActualMatrixType; 104 template<int UpLo> struct ConstSelfAdjointViewReturnType 105 { 106 typedef ActualMatrixType Type; 107 }; 108 109 enum { 110 MatrixFree = true 111 }; 112 generic_matrix_wrapper()113 generic_matrix_wrapper() 114 : mp_matrix(0) 115 {} 116 generic_matrix_wrapper(const MatrixType & mat)117 generic_matrix_wrapper(const MatrixType &mat) 118 : mp_matrix(&mat) 119 {} 120 matrix()121 const ActualMatrixType& matrix() const 122 { 123 return *mp_matrix; 124 } 125 grab(const MatrixType & mat)126 void grab(const MatrixType &mat) 127 { 128 mp_matrix = &mat; 129 } 130 131 protected: 132 const ActualMatrixType *mp_matrix; 133 }; 134 135 } 136 137 /** \ingroup IterativeLinearSolvers_Module 138 * \brief Base class for linear iterative solvers 139 * 140 * \sa class SimplicialCholesky, DiagonalPreconditioner, IdentityPreconditioner 141 */ 142 template< typename Derived> 143 class IterativeSolverBase : public SparseSolverBase<Derived> 144 { 145 protected: 146 typedef SparseSolverBase<Derived> Base; 147 using Base::m_isInitialized; 148 149 public: 150 typedef typename internal::traits<Derived>::MatrixType MatrixType; 151 typedef typename internal::traits<Derived>::Preconditioner Preconditioner; 152 typedef typename MatrixType::Scalar Scalar; 153 typedef typename MatrixType::StorageIndex StorageIndex; 154 typedef typename MatrixType::RealScalar RealScalar; 155 156 enum { 157 ColsAtCompileTime = MatrixType::ColsAtCompileTime, 158 MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime 159 }; 160 161 public: 162 163 using Base::derived; 164 165 /** Default constructor. */ IterativeSolverBase()166 IterativeSolverBase() 167 { 168 init(); 169 } 170 171 /** Initialize the solver with matrix \a A for further \c Ax=b solving. 172 * 173 * This constructor is a shortcut for the default constructor followed 174 * by a call to compute(). 175 * 176 * \warning this class stores a reference to the matrix A as well as some 177 * precomputed values that depend on it. Therefore, if \a A is changed 178 * this class becomes invalid. Call compute() to update it with the new 179 * matrix A, or modify a copy of A. 180 */ 181 template<typename MatrixDerived> IterativeSolverBase(const EigenBase<MatrixDerived> & A)182 explicit IterativeSolverBase(const EigenBase<MatrixDerived>& A) 183 : m_matrixWrapper(A.derived()) 184 { 185 init(); 186 compute(matrix()); 187 } 188 ~IterativeSolverBase()189 ~IterativeSolverBase() {} 190 191 /** Initializes the iterative solver for the sparsity pattern of the matrix \a A for further solving \c Ax=b problems. 192 * 193 * Currently, this function mostly calls analyzePattern on the preconditioner. In the future 194 * we might, for instance, implement column reordering for faster matrix vector products. 195 */ 196 template<typename MatrixDerived> analyzePattern(const EigenBase<MatrixDerived> & A)197 Derived& analyzePattern(const EigenBase<MatrixDerived>& A) 198 { 199 grab(A.derived()); 200 m_preconditioner.analyzePattern(matrix()); 201 m_isInitialized = true; 202 m_analysisIsOk = true; 203 m_info = m_preconditioner.info(); 204 return derived(); 205 } 206 207 /** Initializes the iterative solver with the numerical values of the matrix \a A for further solving \c Ax=b problems. 208 * 209 * Currently, this function mostly calls factorize on the preconditioner. 210 * 211 * \warning this class stores a reference to the matrix A as well as some 212 * precomputed values that depend on it. Therefore, if \a A is changed 213 * this class becomes invalid. Call compute() to update it with the new 214 * matrix A, or modify a copy of A. 215 */ 216 template<typename MatrixDerived> factorize(const EigenBase<MatrixDerived> & A)217 Derived& factorize(const EigenBase<MatrixDerived>& A) 218 { 219 eigen_assert(m_analysisIsOk && "You must first call analyzePattern()"); 220 grab(A.derived()); 221 m_preconditioner.factorize(matrix()); 222 m_factorizationIsOk = true; 223 m_info = m_preconditioner.info(); 224 return derived(); 225 } 226 227 /** Initializes the iterative solver with the matrix \a A for further solving \c Ax=b problems. 228 * 229 * Currently, this function mostly initializes/computes the preconditioner. In the future 230 * we might, for instance, implement column reordering for faster matrix vector products. 231 * 232 * \warning this class stores a reference to the matrix A as well as some 233 * precomputed values that depend on it. Therefore, if \a A is changed 234 * this class becomes invalid. Call compute() to update it with the new 235 * matrix A, or modify a copy of A. 236 */ 237 template<typename MatrixDerived> compute(const EigenBase<MatrixDerived> & A)238 Derived& compute(const EigenBase<MatrixDerived>& A) 239 { 240 grab(A.derived()); 241 m_preconditioner.compute(matrix()); 242 m_isInitialized = true; 243 m_analysisIsOk = true; 244 m_factorizationIsOk = true; 245 m_info = m_preconditioner.info(); 246 return derived(); 247 } 248 249 /** \internal */ rows()250 EIGEN_CONSTEXPR Index rows() const EIGEN_NOEXCEPT { return matrix().rows(); } 251 252 /** \internal */ cols()253 EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT { return matrix().cols(); } 254 255 /** \returns the tolerance threshold used by the stopping criteria. 256 * \sa setTolerance() 257 */ tolerance()258 RealScalar tolerance() const { return m_tolerance; } 259 260 /** Sets the tolerance threshold used by the stopping criteria. 261 * 262 * This value is used as an upper bound to the relative residual error: |Ax-b|/|b|. 263 * The default value is the machine precision given by NumTraits<Scalar>::epsilon() 264 */ setTolerance(const RealScalar & tolerance)265 Derived& setTolerance(const RealScalar& tolerance) 266 { 267 m_tolerance = tolerance; 268 return derived(); 269 } 270 271 /** \returns a read-write reference to the preconditioner for custom configuration. */ preconditioner()272 Preconditioner& preconditioner() { return m_preconditioner; } 273 274 /** \returns a read-only reference to the preconditioner. */ preconditioner()275 const Preconditioner& preconditioner() const { return m_preconditioner; } 276 277 /** \returns the max number of iterations. 278 * It is either the value set by setMaxIterations or, by default, 279 * twice the number of columns of the matrix. 280 */ maxIterations()281 Index maxIterations() const 282 { 283 return (m_maxIterations<0) ? 2*matrix().cols() : m_maxIterations; 284 } 285 286 /** Sets the max number of iterations. 287 * Default is twice the number of columns of the matrix. 288 */ setMaxIterations(Index maxIters)289 Derived& setMaxIterations(Index maxIters) 290 { 291 m_maxIterations = maxIters; 292 return derived(); 293 } 294 295 /** \returns the number of iterations performed during the last solve */ iterations()296 Index iterations() const 297 { 298 eigen_assert(m_isInitialized && "ConjugateGradient is not initialized."); 299 return m_iterations; 300 } 301 302 /** \returns the tolerance error reached during the last solve. 303 * It is a close approximation of the true relative residual error |Ax-b|/|b|. 304 */ error()305 RealScalar error() const 306 { 307 eigen_assert(m_isInitialized && "ConjugateGradient is not initialized."); 308 return m_error; 309 } 310 311 /** \returns the solution x of \f$ A x = b \f$ using the current decomposition of A 312 * and \a x0 as an initial solution. 313 * 314 * \sa solve(), compute() 315 */ 316 template<typename Rhs,typename Guess> 317 inline const SolveWithGuess<Derived, Rhs, Guess> solveWithGuess(const MatrixBase<Rhs> & b,const Guess & x0)318 solveWithGuess(const MatrixBase<Rhs>& b, const Guess& x0) const 319 { 320 eigen_assert(m_isInitialized && "Solver is not initialized."); 321 eigen_assert(derived().rows()==b.rows() && "solve(): invalid number of rows of the right hand side matrix b"); 322 return SolveWithGuess<Derived, Rhs, Guess>(derived(), b.derived(), x0); 323 } 324 325 /** \returns Success if the iterations converged, and NoConvergence otherwise. */ info()326 ComputationInfo info() const 327 { 328 eigen_assert(m_isInitialized && "IterativeSolverBase is not initialized."); 329 return m_info; 330 } 331 332 /** \internal */ 333 template<typename Rhs, typename DestDerived> _solve_with_guess_impl(const Rhs & b,SparseMatrixBase<DestDerived> & aDest)334 void _solve_with_guess_impl(const Rhs& b, SparseMatrixBase<DestDerived> &aDest) const 335 { 336 eigen_assert(rows()==b.rows()); 337 338 Index rhsCols = b.cols(); 339 Index size = b.rows(); 340 DestDerived& dest(aDest.derived()); 341 typedef typename DestDerived::Scalar DestScalar; 342 Eigen::Matrix<DestScalar,Dynamic,1> tb(size); 343 Eigen::Matrix<DestScalar,Dynamic,1> tx(cols()); 344 // We do not directly fill dest because sparse expressions have to be free of aliasing issue. 345 // For non square least-square problems, b and dest might not have the same size whereas they might alias each-other. 346 typename DestDerived::PlainObject tmp(cols(),rhsCols); 347 ComputationInfo global_info = Success; 348 for(Index k=0; k<rhsCols; ++k) 349 { 350 tb = b.col(k); 351 tx = dest.col(k); 352 derived()._solve_vector_with_guess_impl(tb,tx); 353 tmp.col(k) = tx.sparseView(0); 354 355 // The call to _solve_vector_with_guess_impl updates m_info, so if it failed for a previous column 356 // we need to restore it to the worst value. 357 if(m_info==NumericalIssue) 358 global_info = NumericalIssue; 359 else if(m_info==NoConvergence) 360 global_info = NoConvergence; 361 } 362 m_info = global_info; 363 dest.swap(tmp); 364 } 365 366 template<typename Rhs, typename DestDerived> 367 typename internal::enable_if<Rhs::ColsAtCompileTime!=1 && DestDerived::ColsAtCompileTime!=1>::type _solve_with_guess_impl(const Rhs & b,MatrixBase<DestDerived> & aDest)368 _solve_with_guess_impl(const Rhs& b, MatrixBase<DestDerived> &aDest) const 369 { 370 eigen_assert(rows()==b.rows()); 371 372 Index rhsCols = b.cols(); 373 DestDerived& dest(aDest.derived()); 374 ComputationInfo global_info = Success; 375 for(Index k=0; k<rhsCols; ++k) 376 { 377 typename DestDerived::ColXpr xk(dest,k); 378 typename Rhs::ConstColXpr bk(b,k); 379 derived()._solve_vector_with_guess_impl(bk,xk); 380 381 // The call to _solve_vector_with_guess updates m_info, so if it failed for a previous column 382 // we need to restore it to the worst value. 383 if(m_info==NumericalIssue) 384 global_info = NumericalIssue; 385 else if(m_info==NoConvergence) 386 global_info = NoConvergence; 387 } 388 m_info = global_info; 389 } 390 391 template<typename Rhs, typename DestDerived> 392 typename internal::enable_if<Rhs::ColsAtCompileTime==1 || DestDerived::ColsAtCompileTime==1>::type _solve_with_guess_impl(const Rhs & b,MatrixBase<DestDerived> & dest)393 _solve_with_guess_impl(const Rhs& b, MatrixBase<DestDerived> &dest) const 394 { 395 derived()._solve_vector_with_guess_impl(b,dest.derived()); 396 } 397 398 /** \internal default initial guess = 0 */ 399 template<typename Rhs,typename Dest> _solve_impl(const Rhs & b,Dest & x)400 void _solve_impl(const Rhs& b, Dest& x) const 401 { 402 x.setZero(); 403 derived()._solve_with_guess_impl(b,x); 404 } 405 406 protected: init()407 void init() 408 { 409 m_isInitialized = false; 410 m_analysisIsOk = false; 411 m_factorizationIsOk = false; 412 m_maxIterations = -1; 413 m_tolerance = NumTraits<Scalar>::epsilon(); 414 } 415 416 typedef internal::generic_matrix_wrapper<MatrixType> MatrixWrapper; 417 typedef typename MatrixWrapper::ActualMatrixType ActualMatrixType; 418 matrix()419 const ActualMatrixType& matrix() const 420 { 421 return m_matrixWrapper.matrix(); 422 } 423 424 template<typename InputType> grab(const InputType & A)425 void grab(const InputType &A) 426 { 427 m_matrixWrapper.grab(A); 428 } 429 430 MatrixWrapper m_matrixWrapper; 431 Preconditioner m_preconditioner; 432 433 Index m_maxIterations; 434 RealScalar m_tolerance; 435 436 mutable RealScalar m_error; 437 mutable Index m_iterations; 438 mutable ComputationInfo m_info; 439 mutable bool m_analysisIsOk, m_factorizationIsOk; 440 }; 441 442 } // end namespace Eigen 443 444 #endif // EIGEN_ITERATIVE_SOLVER_BASE_H 445