• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2015 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_SOLVERBASE_H
11 #define EIGEN_SOLVERBASE_H
12 
13 namespace Eigen {
14 
15 namespace internal {
16 
/** \internal
  * Generic dispatcher used by SolverBase::solve() to validate a solve request.
  *
  * It forwards to the solver's protected \c _check_solve_assertion member; this is
  * possible because SolverBase declares internal::solve_assertion as a friend.
  * Specializations below unwrap the expressions returned by SolverBase::transpose()
  * and SolverBase::adjoint() before reaching this primary template.
  *
  * \tparam Derived the (cv/ref-free) solver type being checked.
  */
template<typename Derived>
struct solve_assertion {
    /** Forwards \a b to \c solver._check_solve_assertion<Transpose_>().
      *
      * \tparam Transpose_ \c true when the request is for the transposed system;
      *         SolverBase's implementation then compares \c b.rows() with the
      *         solver's column count instead of its row count.
      * \param solver the decomposition/solver to validate.
      * \param b      the right-hand side of the solve request.
      */
    template<bool Transpose_, typename Rhs>
    static void run(const Derived& solver, const Rhs& b) { solver.template _check_solve_assertion<Transpose_>(b); }
};
22 
23 template<typename Derived>
24 struct solve_assertion<Transpose<Derived> >
25 {
26     typedef Transpose<Derived> type;
27 
28     template<bool Transpose_, typename Rhs>
29     static void run(const type& transpose, const Rhs& b)
30     {
31         internal::solve_assertion<typename internal::remove_all<Derived>::type>::template run<true>(transpose.nestedExpression(), b);
32     }
33 };
34 
35 template<typename Scalar, typename Derived>
36 struct solve_assertion<CwiseUnaryOp<Eigen::internal::scalar_conjugate_op<Scalar>, const Transpose<Derived> > >
37 {
38     typedef CwiseUnaryOp<Eigen::internal::scalar_conjugate_op<Scalar>, const Transpose<Derived> > type;
39 
40     template<bool Transpose_, typename Rhs>
41     static void run(const type& adjoint, const Rhs& b)
42     {
43         internal::solve_assertion<typename internal::remove_all<Transpose<Derived> >::type>::template run<true>(adjoint.nestedExpression(), b);
44     }
45 };
46 } // end namespace internal
47 
48 /** \class SolverBase
49   * \brief A base class for matrix decomposition and solvers
50   *
51   * \tparam Derived the actual type of the decomposition/solver.
52   *
53   * Any matrix decomposition inheriting this base class provide the following API:
54   *
55   * \code
56   * MatrixType A, b, x;
57   * DecompositionType dec(A);
58   * x = dec.solve(b);             // solve A   * x = b
59   * x = dec.transpose().solve(b); // solve A^T * x = b
60   * x = dec.adjoint().solve(b);   // solve A'  * x = b
61   * \endcode
62   *
63   * \warning Currently, any other usage of transpose() and adjoint() are not supported and will produce compilation errors.
64   *
65   * \sa class PartialPivLU, class FullPivLU, class HouseholderQR, class ColPivHouseholderQR, class FullPivHouseholderQR, class CompleteOrthogonalDecomposition, class LLT, class LDLT, class SVDBase
66   */
67 template<typename Derived>
68 class SolverBase : public EigenBase<Derived>
69 {
70   public:
71 
72     typedef EigenBase<Derived> Base;
73     typedef typename internal::traits<Derived>::Scalar Scalar;
74     typedef Scalar CoeffReturnType;
75 
76     template<typename Derived_>
77     friend struct internal::solve_assertion;
78 
79     enum {
80       RowsAtCompileTime = internal::traits<Derived>::RowsAtCompileTime,
81       ColsAtCompileTime = internal::traits<Derived>::ColsAtCompileTime,
82       SizeAtCompileTime = (internal::size_at_compile_time<internal::traits<Derived>::RowsAtCompileTime,
83                                                           internal::traits<Derived>::ColsAtCompileTime>::ret),
84       MaxRowsAtCompileTime = internal::traits<Derived>::MaxRowsAtCompileTime,
85       MaxColsAtCompileTime = internal::traits<Derived>::MaxColsAtCompileTime,
86       MaxSizeAtCompileTime = (internal::size_at_compile_time<internal::traits<Derived>::MaxRowsAtCompileTime,
87                                                              internal::traits<Derived>::MaxColsAtCompileTime>::ret),
88       IsVectorAtCompileTime = internal::traits<Derived>::MaxRowsAtCompileTime == 1
89                            || internal::traits<Derived>::MaxColsAtCompileTime == 1,
90       NumDimensions = int(MaxSizeAtCompileTime) == 1 ? 0 : bool(IsVectorAtCompileTime) ? 1 : 2
91     };
92 
93     /** Default constructor */
94     SolverBase()
95     {}
96 
97     ~SolverBase()
98     {}
99 
100     using Base::derived;
101 
102     /** \returns an expression of the solution x of \f$ A x = b \f$ using the current decomposition of A.
103       */
104     template<typename Rhs>
105     inline const Solve<Derived, Rhs>
106     solve(const MatrixBase<Rhs>& b) const
107     {
108       internal::solve_assertion<typename internal::remove_all<Derived>::type>::template run<false>(derived(), b);
109       return Solve<Derived, Rhs>(derived(), b.derived());
110     }
111 
112     /** \internal the return type of transpose() */
113     typedef typename internal::add_const<Transpose<const Derived> >::type ConstTransposeReturnType;
114     /** \returns an expression of the transposed of the factored matrix.
115       *
116       * A typical usage is to solve for the transposed problem A^T x = b:
117       * \code x = dec.transpose().solve(b); \endcode
118       *
119       * \sa adjoint(), solve()
120       */
121     inline ConstTransposeReturnType transpose() const
122     {
123       return ConstTransposeReturnType(derived());
124     }
125 
126     /** \internal the return type of adjoint() */
127     typedef typename internal::conditional<NumTraits<Scalar>::IsComplex,
128                         CwiseUnaryOp<internal::scalar_conjugate_op<Scalar>, ConstTransposeReturnType>,
129                         ConstTransposeReturnType
130                      >::type AdjointReturnType;
131     /** \returns an expression of the adjoint of the factored matrix
132       *
133       * A typical usage is to solve for the adjoint problem A' x = b:
134       * \code x = dec.adjoint().solve(b); \endcode
135       *
136       * For real scalar types, this function is equivalent to transpose().
137       *
138       * \sa transpose(), solve()
139       */
140     inline AdjointReturnType adjoint() const
141     {
142       return AdjointReturnType(derived().transpose());
143     }
144 
145   protected:
146 
147     template<bool Transpose_, typename Rhs>
148     void _check_solve_assertion(const Rhs& b) const {
149         EIGEN_ONLY_USED_FOR_DEBUG(b);
150         eigen_assert(derived().m_isInitialized && "Solver is not initialized.");
151         eigen_assert((Transpose_?derived().cols():derived().rows())==b.rows() && "SolverBase::solve(): invalid number of rows of the right hand side matrix b");
152     }
153 };
154 
155 namespace internal {
156 
157 template<typename Derived>
158 struct generic_xpr_base<Derived, MatrixXpr, SolverStorage>
159 {
160   typedef SolverBase<Derived> type;
161 
162 };
163 
164 } // end namespace internal
165 
166 } // end namespace Eigen
167 
168 #endif // EIGEN_SOLVERBASE_H
169