1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2015 Gael Guennebaud <gael.guennebaud@inria.fr> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_SOLVERBASE_H 11 #define EIGEN_SOLVERBASE_H 12 13 namespace Eigen { 14 15 namespace internal { 16 17 template<typename Derived> 18 struct solve_assertion { 19 template<bool Transpose_, typename Rhs> runsolve_assertion20 static void run(const Derived& solver, const Rhs& b) { solver.template _check_solve_assertion<Transpose_>(b); } 21 }; 22 23 template<typename Derived> 24 struct solve_assertion<Transpose<Derived> > 25 { 26 typedef Transpose<Derived> type; 27 28 template<bool Transpose_, typename Rhs> 29 static void run(const type& transpose, const Rhs& b) 30 { 31 internal::solve_assertion<typename internal::remove_all<Derived>::type>::template run<true>(transpose.nestedExpression(), b); 32 } 33 }; 34 35 template<typename Scalar, typename Derived> 36 struct solve_assertion<CwiseUnaryOp<Eigen::internal::scalar_conjugate_op<Scalar>, const Transpose<Derived> > > 37 { 38 typedef CwiseUnaryOp<Eigen::internal::scalar_conjugate_op<Scalar>, const Transpose<Derived> > type; 39 40 template<bool Transpose_, typename Rhs> 41 static void run(const type& adjoint, const Rhs& b) 42 { 43 internal::solve_assertion<typename internal::remove_all<Transpose<Derived> >::type>::template run<true>(adjoint.nestedExpression(), b); 44 } 45 }; 46 } // end namespace internal 47 48 /** \class SolverBase 49 * \brief A base class for matrix decomposition and solvers 50 * 51 * \tparam Derived the actual type of the decomposition/solver. 52 * 53 * Any matrix decomposition inheriting this base class provide the following API: 54 * 55 * \code 56 * MatrixType A, b, x; 57 * DecompositionType dec(A); 58 * x = dec.solve(b); // solve A * x = b 59 * x = dec.transpose().solve(b); // solve A^T * x = b 60 * x = dec.adjoint().solve(b); // solve A' * x = b 61 * \endcode 62 * 63 * \warning Currently, any other usage of transpose() and adjoint() are not supported and will produce compilation errors. 64 * 65 * \sa class PartialPivLU, class FullPivLU, class HouseholderQR, class ColPivHouseholderQR, class FullPivHouseholderQR, class CompleteOrthogonalDecomposition, class LLT, class LDLT, class SVDBase 66 */ 67 template<typename Derived> 68 class SolverBase : public EigenBase<Derived> 69 { 70 public: 71 72 typedef EigenBase<Derived> Base; 73 typedef typename internal::traits<Derived>::Scalar Scalar; 74 typedef Scalar CoeffReturnType; 75 76 template<typename Derived_> 77 friend struct internal::solve_assertion; 78 79 enum { 80 RowsAtCompileTime = internal::traits<Derived>::RowsAtCompileTime, 81 ColsAtCompileTime = internal::traits<Derived>::ColsAtCompileTime, 82 SizeAtCompileTime = (internal::size_at_compile_time<internal::traits<Derived>::RowsAtCompileTime, 83 internal::traits<Derived>::ColsAtCompileTime>::ret), 84 MaxRowsAtCompileTime = internal::traits<Derived>::MaxRowsAtCompileTime, 85 MaxColsAtCompileTime = internal::traits<Derived>::MaxColsAtCompileTime, 86 MaxSizeAtCompileTime = (internal::size_at_compile_time<internal::traits<Derived>::MaxRowsAtCompileTime, 87 internal::traits<Derived>::MaxColsAtCompileTime>::ret), 88 IsVectorAtCompileTime = internal::traits<Derived>::MaxRowsAtCompileTime == 1 89 || internal::traits<Derived>::MaxColsAtCompileTime == 1, 90 NumDimensions = int(MaxSizeAtCompileTime) == 1 ? 0 : bool(IsVectorAtCompileTime) ? 1 : 2 91 }; 92 93 /** Default constructor */ 94 SolverBase() 95 {} 96 97 ~SolverBase() 98 {} 99 100 using Base::derived; 101 102 /** \returns an expression of the solution x of \f$ A x = b \f$ using the current decomposition of A. 103 */ 104 template<typename Rhs> 105 inline const Solve<Derived, Rhs> 106 solve(const MatrixBase<Rhs>& b) const 107 { 108 internal::solve_assertion<typename internal::remove_all<Derived>::type>::template run<false>(derived(), b); 109 return Solve<Derived, Rhs>(derived(), b.derived()); 110 } 111 112 /** \internal the return type of transpose() */ 113 typedef typename internal::add_const<Transpose<const Derived> >::type ConstTransposeReturnType; 114 /** \returns an expression of the transposed of the factored matrix. 115 * 116 * A typical usage is to solve for the transposed problem A^T x = b: 117 * \code x = dec.transpose().solve(b); \endcode 118 * 119 * \sa adjoint(), solve() 120 */ 121 inline ConstTransposeReturnType transpose() const 122 { 123 return ConstTransposeReturnType(derived()); 124 } 125 126 /** \internal the return type of adjoint() */ 127 typedef typename internal::conditional<NumTraits<Scalar>::IsComplex, 128 CwiseUnaryOp<internal::scalar_conjugate_op<Scalar>, ConstTransposeReturnType>, 129 ConstTransposeReturnType 130 >::type AdjointReturnType; 131 /** \returns an expression of the adjoint of the factored matrix 132 * 133 * A typical usage is to solve for the adjoint problem A' x = b: 134 * \code x = dec.adjoint().solve(b); \endcode 135 * 136 * For real scalar types, this function is equivalent to transpose(). 137 * 138 * \sa transpose(), solve() 139 */ 140 inline AdjointReturnType adjoint() const 141 { 142 return AdjointReturnType(derived().transpose()); 143 } 144 145 protected: 146 147 template<bool Transpose_, typename Rhs> 148 void _check_solve_assertion(const Rhs& b) const { 149 EIGEN_ONLY_USED_FOR_DEBUG(b); 150 eigen_assert(derived().m_isInitialized && "Solver is not initialized."); 151 eigen_assert((Transpose_?derived().cols():derived().rows())==b.rows() && "SolverBase::solve(): invalid number of rows of the right hand side matrix b"); 152 } 153 }; 154 155 namespace internal { 156 157 template<typename Derived> 158 struct generic_xpr_base<Derived, MatrixXpr, SolverStorage> 159 { 160 typedef SolverBase<Derived> type; 161 162 }; 163 164 } // end namespace internal 165 166 } // end namespace Eigen 167 168 #endif // EIGEN_SOLVERBASE_H 169