// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_SOLVE_H
#define EIGEN_SOLVE_H

namespace Eigen {

// Forward declaration of the storage-kind-dependent base class of Solve.
// It is defined further down: a Dense specialization plus a generic primary
// template used for other storage kinds.
template<typename Decomposition, typename RhsType, typename StorageKind> class SolveImpl;

/** \class Solve
  * \ingroup Core_Module
  *
  * \brief Pseudo expression representing a solving operation
  *
  * \tparam Decomposition the type of the matrix or decomposition object
  * \tparam RhsType the type of the right-hand side
  *
  * This class represents an expression of A.solve(B)
  * and most of the time this is the only way it is used.
  *
  * The expression only references the decomposition and the right-hand side;
  * the actual solve is performed when the expression is evaluated or assigned.
  *
  */
namespace internal {

// solve_traits determines the plain (evaluated) result type of a Solve
// expression for a given storage kind (Dense vs Sparse). traits<Solve> below
// instantiates it with the storage kind of the right-hand side.
template<typename Decomposition, typename RhsType,typename StorageKind> struct solve_traits;

// Dense case: the result uses the right-hand side's scalar type and options.
// Decomposition::ColsAtCompileTime and RhsType::ColsAtCompileTime are passed as
// the two dimension arguments, mirroring Solve::rows() == dec.cols() and
// Solve::cols() == rhs.cols().
template<typename Decomposition, typename RhsType>
struct solve_traits<Decomposition,RhsType,Dense>
{
  typedef typename make_proper_matrix_type<typename RhsType::Scalar,
                 Decomposition::ColsAtCompileTime,
                 RhsType::ColsAtCompileTime,
                 RhsType::PlainObject::Options,
                 Decomposition::MaxColsAtCompileTime,
                 RhsType::MaxColsAtCompileTime>::type PlainObject;
};

// Traits of a Solve expression: inherited from the traits of its plain result
// type, with the index type, flags and read cost overridden.
template<typename Decomposition, typename RhsType>
struct traits<Solve<Decomposition, RhsType> >
  : traits<typename solve_traits<Decomposition,RhsType,typename internal::traits<RhsType>::StorageKind>::PlainObject>
{
  typedef typename solve_traits<Decomposition,RhsType,typename internal::traits<RhsType>::StorageKind>::PlainObject PlainObject;
  // Index type promoted from the decomposition's and the right-hand side's StorageIndex.
  typedef typename promote_index_type<typename Decomposition::StorageIndex, typename RhsType::StorageIndex>::type StorageIndex;
  typedef traits<PlainObject> BaseTraits;
  enum {
    // Only the storage-order bit of the plain result's flags is kept.
    Flags = BaseTraits::Flags & RowMajorBit,
    // Reading individual coefficients is treated as prohibitively expensive;
    // the expression is evaluated as a whole instead.
    CoeffReadCost = HugeCost
  };
};

} // end namespace internal


// Expression object representing A.solve(B).
template<typename Decomposition, typename RhsType>
class Solve : public SolveImpl<Decomposition,RhsType,typename internal::traits<RhsType>::StorageKind>
{
public:
  typedef typename internal::traits<Solve>::PlainObject PlainObject;
  typedef typename internal::traits<Solve>::StorageIndex StorageIndex;

  // Both arguments are stored by reference (see m_dec / m_rhs), so they must
  // outlive this expression object.
  Solve(const Decomposition &dec, const RhsType &rhs)
    : m_dec(dec), m_rhs(rhs)
  {}

  // Result size: rows() is the decomposition's column count, cols() is the
  // right-hand side's column count.
  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index rows() const EIGEN_NOEXCEPT { return m_dec.cols(); }
  EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT { return m_rhs.cols(); }

  EIGEN_DEVICE_FUNC const Decomposition& dec() const { return m_dec; }
  EIGEN_DEVICE_FUNC const RhsType&       rhs() const { return m_rhs; }

protected:
  const Decomposition &m_dec;
  const RhsType       &m_rhs;
};


// Specialization of the Solve expression for dense results: Solve is exposed
// as a MatrixBase expression.
template<typename Decomposition, typename RhsType>
class SolveImpl<Decomposition,RhsType,Dense>
  : public MatrixBase<Solve<Decomposition,RhsType> >
{
  typedef Solve<Decomposition,RhsType> Derived;

public:

  typedef MatrixBase<Solve<Decomposition,RhsType> > Base;
  // NOTE(review): EIGEN_DENSE_PUBLIC_INTERFACE is defined elsewhere; it
  // presumably relies on the Base typedef above, so keep that typedef first.
  EIGEN_DENSE_PUBLIC_INTERFACE(Derived)

private:

  // Declared private (and not defined in this file), so coefficient-wise
  // access on a Solve expression is unavailable; results are obtained through
  // evaluator<Solve> or the Assignment specializations instead.
  Scalar coeff(Index row, Index col) const;
  Scalar coeff(Index i) const;
};

// Generic API dispatcher: primary template used for storage kinds without a
// dedicated specialization (Dense has one above). The base class is selected
// by internal::generic_xpr_base for the given StorageKind.
template<typename Decomposition, typename RhsType, typename StorageKind>
class SolveImpl : public internal::generic_xpr_base<Solve<Decomposition,RhsType>, MatrixXpr, StorageKind>::type
{
public:
  typedef typename internal::generic_xpr_base<Solve<Decomposition,RhsType>, MatrixXpr, StorageKind>::type Base;
};

namespace internal {

// Evaluator of Solve -> eval into a temporary.
// The solution is computed once, when the evaluator is constructed, into the
// m_result member; the inherited plain-object evaluator then reads from it.
template<typename Decomposition, typename RhsType>
struct evaluator<Solve<Decomposition,RhsType> >
  : public evaluator<typename Solve<Decomposition,RhsType>::PlainObject>
{
  typedef Solve<Decomposition,RhsType> SolveType;
  typedef typename SolveType::PlainObject PlainObject;
  typedef evaluator<PlainObject> Base;

  // EvalBeforeNestingBit is added on top of the plain object's evaluator flags.
  enum { Flags = Base::Flags | EvalBeforeNestingBit };

  EIGEN_DEVICE_FUNC explicit evaluator(const SolveType& solve)
    : m_result(solve.rows(), solve.cols())
  {
    // Base subobjects are constructed before data members, so the base
    // evaluator cannot be bound to m_result in the initializer list.
    // Re-construct it in place now that m_result exists.
    ::new (static_cast<Base*>(this)) Base(m_result);
    // Write the solution into the temporary the base evaluator now refers to.
    solve.dec()._solve_impl(solve.rhs(), m_result);
  }

protected:
  PlainObject m_result;
};

// Specialization for "dst = dec.solve(rhs)".
// NOTE: this is specialized for Dense2Dense to avoid an ambiguous
// specialization error; a Sparse2Sparse specialization must exist elsewhere.
// dst is resized if needed and passed straight to _solve_impl, bypassing the
// evaluator's temporary.
template<typename DstXprType, typename DecType, typename RhsType, typename Scalar>
struct Assignment<DstXprType, Solve<DecType,RhsType>, internal::assign_op<Scalar,Scalar>, Dense2Dense>
{
  typedef Solve<DecType,RhsType> SrcXprType;
  static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<Scalar,Scalar> &)
  {
    Index dstRows = src.rows();
    Index dstCols = src.cols();
    // Only resize when the destination shape differs from the solution's.
    if((dst.rows()!=dstRows) || (dst.cols()!=dstCols))
      dst.resize(dstRows, dstCols);

    src.dec()._solve_impl(src.rhs(), dst);
  }
};

// Specialization for "dst = dec.transpose().solve(rhs)".
// Unwraps the Transpose and calls _solve_impl_transposed<false> on the
// underlying decomposition, writing directly into dst.
template<typename DstXprType, typename DecType, typename RhsType, typename Scalar>
struct Assignment<DstXprType, Solve<Transpose<const DecType>,RhsType>, internal::assign_op<Scalar,Scalar>, Dense2Dense>
{
  typedef Solve<Transpose<const DecType>,RhsType> SrcXprType;
  static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<Scalar,Scalar> &)
  {
    Index dstRows = src.rows();
    Index dstCols = src.cols();
    // Only resize when the destination shape differs from the solution's.
    if((dst.rows()!=dstRows) || (dst.cols()!=dstCols))
      dst.resize(dstRows, dstCols);

    src.dec().nestedExpression().template _solve_impl_transposed<false>(src.rhs(), dst);
  }
};

// Specialization for "dst = dec.adjoint().solve(rhs)".
// Matches the conjugate-of-transpose expression type, unwraps both layers and
// calls _solve_impl_transposed<true> on the underlying decomposition.
// NOTE(review): the bool template argument presumably selects conjugation
// (false for transpose(), true for adjoint()) — confirm in the decompositions.
template<typename DstXprType, typename DecType, typename RhsType, typename Scalar>
struct Assignment<DstXprType, Solve<CwiseUnaryOp<internal::scalar_conjugate_op<typename DecType::Scalar>, const Transpose<const DecType> >,RhsType>,
                  internal::assign_op<Scalar,Scalar>, Dense2Dense>
{
  typedef Solve<CwiseUnaryOp<internal::scalar_conjugate_op<typename DecType::Scalar>, const Transpose<const DecType> >,RhsType> SrcXprType;
  static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<Scalar,Scalar> &)
  {
    Index dstRows = src.rows();
    Index dstCols = src.cols();
    // Only resize when the destination shape differs from the solution's.
    if((dst.rows()!=dstRows) || (dst.cols()!=dstCols))
      dst.resize(dstRows, dstCols);

    src.dec().nestedExpression().nestedExpression().template _solve_impl_transposed<true>(src.rhs(), dst);
  }
};

} // end namespace internal

} // end namespace Eigen

#endif // EIGEN_SOLVE_H