// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_TRIANGULARMATRIXVECTOR_H
#define EIGEN_TRIANGULARMATRIXVECTOR_H

namespace Eigen {

namespace internal {

// Low-level kernel accumulating  res += alpha * triangular(lhs) * rhs  from raw
// pointers and strides. It is specialized below on the storage order of the lhs
// (ColMajor and RowMajor).
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int StorageOrder, int Version=Specialized>
struct triangular_matrix_vector_product;

// Column-major lhs: the product is built column by column, i.e. each column of
// the triangle scaled by the matching rhs coefficient is added to a segment of res.
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version>
{
  typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
  enum {
    // Compile-time decoding of the Mode bit mask.
    IsLower = ((Mode&Lower)==Lower),
    HasUnitDiag = (Mode & UnitDiag)==UnitDiag,
    HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag
  };
  // _rows, _cols : dimensions of the lhs matrix
  // _lhs         : lhs coefficients, mapped column-major with outer stride lhsStride
  // _rhs         : rhs vector, mapped with inner stride rhsIncr
  // _res         : result vector, accumulated into (+=); mapped with unit stride,
  //                resIncr is only forwarded to general_matrix_vector_product
  // alpha        : scalar factor applied to every contribution (note: RhsScalar here,
  //                whereas the RowMajor specialization takes a ResScalar)
  static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
                                    const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const RhsScalar& alpha);
};

// Definition of the column-major kernel declared above.
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version>
  ::run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
        const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const RhsScalar& alpha)
{
  static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
  // Length of the diagonal, i.e. side of the square part that is really triangular.
  Index size = (std::min)(_rows,_cols);
  // A lower triangle has no coefficients right of the square part, an upper one has
  // none below it, so the corresponding dimension is clipped to the diagonal length.
  Index rows = IsLower ? _rows : (std::min)(_rows,_cols);
  Index cols = IsLower ? (std::min)(_rows,_cols) : _cols;

  typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
  const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
  // NOTE(review): conj_expr_if is defined elsewhere; judging by its name it yields a
  // conjugating view when the flag is set and the plain map otherwise -- confirm.
  typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);

  typedef Map<const Matrix<RhsScalar,Dynamic,1>, 0, InnerStride<> > RhsMap;
  const RhsMap rhs(_rhs,cols,InnerStride<>(rhsIncr));
  typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);

  // The result is mapped without any stride (resIncr is not used by this map).
  // The only caller visible in this file, trmv_selector<Mode,ColMajor>, passes resIncr==1.
  typedef Map<Matrix<ResScalar,Dynamic,1> > ResMap;
  ResMap res(_res,rows);

  typedef const_blas_data_mapper<LhsScalar,Index,ColMajor> LhsMapper;
  typedef const_blas_data_mapper<RhsScalar,Index,RowMajor> RhsMapper;

  // Walk along the diagonal by panels of at most PanelWidth columns.
  for (Index pi=0; pi<size; pi+=PanelWidth)
  {
    Index actualPanelWidth = (std::min)(PanelWidth, size-pi);
    // Triangular part of the panel: one column at a time.
    for (Index k=0; k<actualPanelWidth; ++k)
    {
      Index i = pi + k;
      // [s, s+r) is the row range of column i that lies inside the panel and inside
      // the triangle:
      //  - lower: from the diagonal (or just below it for a unit/zero diagonal) down
      //           to the last row of the panel,
      //  - upper: from the first row of the panel down to the diagonal.
      Index s = IsLower ? ((HasUnitDiag||HasZeroDiag) ? i+1 : i ) : pi;
      Index r = IsLower ? actualPanelWidth-k : k+1;
      // The --r is only evaluated for a unit/zero diagonal (short-circuit): it removes
      // the diagonal entry from the range, and the decremented r is the one used below.
      if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
        res.segment(s,r) += (alpha * cjRhs.coeff(i)) * cjLhs.col(i).segment(s,r);
      // Implicit unit diagonal: the diagonal coefficient is taken as 1, not read from lhs.
      if (HasUnitDiag)
        res.coeffRef(i) += alpha * cjRhs.coeff(i);
    }
    // Rectangular part attached to the panel: the rows below it (lower) or above it
    // (upper), restricted to the panel's columns, are delegated to the general
    // matrix-vector kernel.
    Index r = IsLower ? rows - pi - actualPanelWidth : pi;
    if (r>0)
    {
      Index s = IsLower ? pi+actualPanelWidth : 0;
      general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs,BuiltIn>::run(
          r, actualPanelWidth,
          LhsMapper(&lhs.coeffRef(s,pi), lhsStride),
          RhsMapper(&rhs.coeffRef(pi), rhsIncr),
          &res.coeffRef(s), resIncr, alpha);
    }
  }
  // Upper triangle of a wide matrix: the columns right of the square part are fully
  // populated and handled as a plain matrix-vector product.
  if((!IsLower) && cols>size)
  {
    general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs>::run(
        rows, cols-size,
        LhsMapper(&lhs.coeffRef(0,size), lhsStride),
        RhsMapper(&rhs.coeffRef(size), rhsIncr),
        _res, resIncr, alpha);
  }
}

// Row-major lhs: the product is built row by row, i.e. each result coefficient
// receives the dot product of a row segment of the triangle with a segment of rhs.
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version>
struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version>
{
  typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
  enum {
    // Compile-time decoding of the Mode bit mask.
    IsLower = ((Mode&Lower)==Lower),
    HasUnitDiag = (Mode & UnitDiag)==UnitDiag,
    HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag
  };
  // _rows, _cols : dimensions of the lhs matrix
  // _lhs         : lhs coefficients, mapped row-major with outer stride lhsStride
  // _rhs         : rhs vector, mapped with unit stride; rhsIncr is only forwarded to
  //                general_matrix_vector_product
  // _res         : result vector, accumulated into (+=), mapped with inner stride resIncr
  // alpha        : scalar factor applied to every contribution (a ResScalar here)
  static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
                                    const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha);
};

// Definition of the row-major kernel declared above.
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version>
EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version>
  ::run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
        const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha)
{
  static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
  // Length of the diagonal; rows/cols are clipped exactly as in the column-major kernel.
  Index diagSize = (std::min)(_rows,_cols);
  Index rows = IsLower ? _rows : diagSize;
  Index cols = IsLower ? diagSize : _cols;

  typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap;
  const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
  // NOTE(review): conj_expr_if is defined elsewhere; presumably a conjugating view
  // when the flag is set -- confirm.
  typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);

  // The rhs is mapped without any stride (rhsIncr is not used by this map).
  // The only caller visible in this file, trmv_selector<Mode,RowMajor>, passes rhsIncr==1.
  typedef Map<const Matrix<RhsScalar,Dynamic,1> > RhsMap;
  const RhsMap rhs(_rhs,cols);
  typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);

  typedef Map<Matrix<ResScalar,Dynamic,1>, 0, InnerStride<> > ResMap;
  ResMap res(_res,rows,InnerStride<>(resIncr));

  typedef const_blas_data_mapper<LhsScalar,Index,RowMajor> LhsMapper;
  typedef const_blas_data_mapper<RhsScalar,Index,RowMajor> RhsMapper;

  // Walk along the diagonal by panels of at most PanelWidth rows.
  for (Index pi=0; pi<diagSize; pi+=PanelWidth)
  {
    Index actualPanelWidth = (std::min)(PanelWidth, diagSize-pi);
    // Triangular part of the panel: one row at a time.
    for (Index k=0; k<actualPanelWidth; ++k)
    {
      Index i = pi + k;
      // [s, s+r) is the column range of row i that lies inside the panel and inside
      // the triangle:
      //  - lower: from the first column of the panel up to the diagonal,
      //  - upper: from the diagonal (or just right of it for a unit/zero diagonal)
      //           up to the last column of the panel.
      Index s = IsLower ? pi : ((HasUnitDiag||HasZeroDiag) ? i+1 : i);
      Index r = IsLower ? k+1 : actualPanelWidth-k;
      // The --r is only evaluated for a unit/zero diagonal (short-circuit): it removes
      // the diagonal entry from the range, and the decremented r is the one used below.
      if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
        res.coeffRef(i) += alpha * (cjLhs.row(i).segment(s,r).cwiseProduct(cjRhs.segment(s,r).transpose())).sum();
      // Implicit unit diagonal: the diagonal coefficient is taken as 1, not read from lhs.
      if (HasUnitDiag)
        res.coeffRef(i) += alpha * cjRhs.coeff(i);
    }
    // Rectangular part attached to the panel: the columns left of it (lower) or right
    // of it (upper), restricted to the panel's rows, are delegated to the general
    // matrix-vector kernel.
    Index r = IsLower ? pi : cols - pi - actualPanelWidth;
    if (r>0)
    {
      Index s = IsLower ? 0 : pi + actualPanelWidth;
      general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs,BuiltIn>::run(
          actualPanelWidth, r,
          LhsMapper(&lhs.coeffRef(pi,s), lhsStride),
          RhsMapper(&rhs.coeffRef(s), rhsIncr),
          &res.coeffRef(pi), resIncr, alpha);
    }
  }
  // Lower triangle of a tall matrix: the rows below the square part are fully
  // populated and handled as a plain matrix-vector product.
  if(IsLower && rows>diagSize)
  {
    general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs>::run(
        rows-diagSize, cols,
        LhsMapper(&lhs.coeffRef(diagSize,0), lhsStride),
        RhsMapper(&rhs.coeffRef(0), rhsIncr),
        &res.coeffRef(diagSize), resIncr, alpha);
  }
}

/***************************************************************************
* Wrapper to product_triangular_vector
***************************************************************************/

// Expression-level front end of the kernels above, selected on the storage order
// of the triangular operand. The two specializations are defined at the end of
// this file.
template<int Mode,int StorageOrder>
struct trmv_selector;

} // end namespace internal

namespace internal {

// Product evaluated as  dst += alpha * triangular(lhs) * rhs : lhs and rhs are
// forwarded in that order to trmv_selector, with the storage order read from the
// RowMajorBit of Lhs's flags.
// NOTE(review): the meaning of the boolean template arguments is declared with the
// primary template, which is not in this file -- confirm there.
template<int Mode, typename Lhs, typename Rhs>
struct triangular_product_impl<Mode,true,Lhs,false,Rhs,true>
{
  template<typename Dest> static void run(Dest& dst, const Lhs &lhs, const Rhs &rhs, const typename Dest::Scalar& alpha)
  {
    eigen_assert(dst.rows()==lhs.rows() && dst.cols()==rhs.cols());

    internal::trmv_selector<Mode,(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(lhs, rhs, dst, alpha);
  }
};

// Same product with the operands in the other order, evaluated through the
// transposed problem: rhs^T takes the place of the triangular operand, lhs^T the
// place of the vector, and the result is written through a transposed view of dst.
template<int Mode, typename Lhs, typename Rhs>
struct triangular_product_impl<Mode,false,Lhs,true,Rhs,false>
{
  template<typename Dest> static void run(Dest& dst, const Lhs &lhs, const Rhs &rhs, const typename Dest::Scalar& alpha)
  {
    eigen_assert(dst.rows()==lhs.rows() && dst.cols()==rhs.cols());

    Transpose<Dest> dstT(dst);
    // Transposing swaps Lower and Upper while the UnitDiag/ZeroDiag bits are kept,
    // and it also swaps the storage order (hence ColMajor when Rhs is row-major).
    internal::trmv_selector<(Mode & (UnitDiag|ZeroDiag)) | ((Mode & Lower) ? Upper : Lower),
                            (int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>
            ::run(rhs.transpose(),lhs.transpose(), dstT, alpha);
  }
};

} // end namespace internal

namespace internal {

// TODO: find a way to factorize this piece of code with gemv_selector since the logic is exactly the same.

// Column-major triangular operand. The kernel writes to a unit-stride result
// pointer and takes its factor as a RhsScalar, so the destination is replaced by
// a temporary buffer whenever it cannot be used directly.
template<int Mode> struct trmv_selector<Mode,ColMajor>
{
  // Accumulates  alpha * triangular(lhs) * rhs  into dest.
  template<typename Lhs, typename Rhs, typename Dest>
  static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
  {
    typedef typename Lhs::Scalar      LhsScalar;
    typedef typename Rhs::Scalar      RhsScalar;
    typedef typename Dest::Scalar     ResScalar;
    typedef typename Dest::RealScalar RealScalar;

    typedef internal::blas_traits<Lhs> LhsBlasTraits;
    typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
    typedef internal::blas_traits<Rhs> RhsBlasTraits;
    typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;

    typedef Map<Matrix<ResScalar,Dynamic,1>, EIGEN_PLAIN_ENUM_MIN(AlignedMax,internal::packet_traits<ResScalar>::size)> MappedDest;

    // Operands as extracted by blas_traits; the scalar factors it reports are folded
    // into actualAlpha below.
    typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
    typename internal::add_const_on_value_type<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);

    ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs)
                                  * RhsBlasTraits::extractScalarFactor(rhs);

    enum {
      // FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
      // on the other hand, it is good for the cache to pack the vector anyways...
      EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1,
      // Complex lhs combined with a real rhs: the factor handed to the kernel is a
      // (real) RhsScalar.
      ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex),
      MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal
    };

    // NOTE(review): gemv_static_vector_if is defined elsewhere; presumably it provides
    // static storage for the temporary when the last flag is set and a null data()
    // otherwise -- confirm.
    gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;

    // In the complex-by-real case actualAlpha can only be handed to the kernel when
    // its imaginary part is zero.
    bool alphaIsCompatible = (!ComplexByReal) || (numext::imag(actualAlpha)==RealScalar(0));
    bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;

    RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);

    // actualDestPtr is dest's own storage when it can be written directly, otherwise
    // the buffer obtained from static_dest or from the macro.
    // NOTE(review): presumably the macro allocates a temporary of dest.size() elements
    // when the supplied pointer is null -- confirm against its definition.
    ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(),
                                                  evalToDest ? dest.data() : static_dest.data());

    if(!evalToDest)
    {
      #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
      // `size` is declared immediately before the plugin is expanded.
      // NOTE(review): presumably the plugin refers to it by name -- do not rename or
      // remove without checking.
      Index size = dest.size();
      EIGEN_DENSE_STORAGE_CTOR_PLUGIN
      #endif
      if(!alphaIsCompatible)
      {
        // Run the kernel with a factor of 1 on a zeroed buffer; the real actualAlpha
        // is applied when the buffer is added to dest after the kernel call.
        MappedDest(actualDestPtr, dest.size()).setZero();
        compatibleAlpha = RhsScalar(1);
      }
      else
        // The kernel accumulates, so the buffer starts from the current content of dest.
        MappedDest(actualDestPtr, dest.size()) = dest;
    }

    internal::triangular_matrix_vector_product
      <Index,Mode,
       LhsScalar, LhsBlasTraits::NeedToConjugate,
       RhsScalar, RhsBlasTraits::NeedToConjugate,
       ColMajor>
      ::run(actualLhs.rows(),actualLhs.cols(),
            actualLhs.data(),actualLhs.outerStride(),
            actualRhs.data(),actualRhs.innerStride(),
            actualDestPtr,1,compatibleAlpha);

    // Transfer the temporary back into dest when one was used.
    if (!evalToDest)
    {
      if(!alphaIsCompatible)
        dest += actualAlpha * MappedDest(actualDestPtr, dest.size());
      else
        dest = MappedDest(actualDestPtr, dest.size());
    }
  }
};

// Row-major triangular operand. The kernel reads the rhs with unit stride, so the
// rhs is copied into a contiguous temporary when its inner stride is not 1 at
// compile time; the destination is written in place through its inner stride.
template<int Mode> struct trmv_selector<Mode,RowMajor>
{
  // Accumulates  alpha * triangular(lhs) * rhs  into dest.
  template<typename Lhs, typename Rhs, typename Dest>
  static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
  {
    typedef typename Lhs::Scalar      LhsScalar;
    typedef typename Rhs::Scalar      RhsScalar;
    typedef typename Dest::Scalar     ResScalar;

    typedef internal::blas_traits<Lhs> LhsBlasTraits;
    typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
    typedef internal::blas_traits<Rhs> RhsBlasTraits;
    typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
    typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;

    // Operands as extracted by blas_traits; the scalar factors it reports are folded
    // into actualAlpha below.
    typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
    typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);

    ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs)
                                  * RhsBlasTraits::extractScalarFactor(rhs);

    enum {
      // The rhs storage can be handed to the kernel as is only with a unit inner stride.
      DirectlyUseRhs = ActualRhsTypeCleaned::InnerStrideAtCompileTime==1
    };

    // NOTE(review): gemv_static_vector_if is defined elsewhere; presumably it provides
    // static storage for the temporary when the last flag is set and a null data()
    // otherwise -- confirm.
    gemv_static_vector_if<RhsScalar,ActualRhsTypeCleaned::SizeAtCompileTime,ActualRhsTypeCleaned::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;

    // actualRhsPtr is the rhs's own storage when it is usable directly (the const_cast
    // only adapts the pointer type, the kernel takes a const pointer), otherwise the
    // buffer obtained from static_rhs or from the macro.
    ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,actualRhs.size(),
        DirectlyUseRhs ? const_cast<RhsScalar*>(actualRhs.data()) : static_rhs.data());

    if(!DirectlyUseRhs)
    {
      #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
      // `size` is declared immediately before the plugin is expanded.
      // NOTE(review): presumably the plugin refers to it by name -- do not rename or
      // remove without checking.
      Index size = actualRhs.size();
      EIGEN_DENSE_STORAGE_CTOR_PLUGIN
      #endif
      // Pack the strided rhs into the contiguous temporary.
      Map<typename ActualRhsTypeCleaned::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
    }

    internal::triangular_matrix_vector_product
      <Index,Mode,
       LhsScalar, LhsBlasTraits::NeedToConjugate,
       RhsScalar, RhsBlasTraits::NeedToConjugate,
       RowMajor>
      ::run(actualLhs.rows(),actualLhs.cols(),
            actualLhs.data(),actualLhs.outerStride(),
            actualRhsPtr,1,
            dest.data(),dest.innerStride(),
            actualAlpha);
  }
};

} // end namespace internal

} // end namespace Eigen

#endif // EIGEN_TRIANGULARMATRIXVECTOR_H