1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_SELFADJOINT_PRODUCT_H 11 #define EIGEN_SELFADJOINT_PRODUCT_H 12 13 /********************************************************************** 14 * This file implements a self adjoint product: C += A A^T updating only 15 * half of the selfadjoint matrix C. 16 * It corresponds to the level 3 SYRK and level 2 SYR Blas routines. 17 **********************************************************************/ 18 19 namespace Eigen { 20 21 22 template<typename Scalar, typename Index, int UpLo, bool ConjLhs, bool ConjRhs> 23 struct selfadjoint_rank1_update<Scalar,Index,ColMajor,UpLo,ConjLhs,ConjRhs> 24 { 25 static void run(Index size, Scalar* mat, Index stride, const Scalar* vecX, const Scalar* vecY, const Scalar& alpha) 26 { 27 internal::conj_if<ConjRhs> cj; 28 typedef Map<const Matrix<Scalar,Dynamic,1> > OtherMap; 29 typedef typename internal::conditional<ConjLhs,typename OtherMap::ConjugateReturnType,const OtherMap&>::type ConjLhsType; 30 for (Index i=0; i<size; ++i) 31 { 32 Map<Matrix<Scalar,Dynamic,1> >(mat+stride*i+(UpLo==Lower ? i : 0), (UpLo==Lower ? size-i : (i+1))) 33 += (alpha * cj(vecY[i])) * ConjLhsType(OtherMap(vecX+(UpLo==Lower ? i : 0),UpLo==Lower ? size-i : (i+1))); 34 } 35 } 36 }; 37 38 template<typename Scalar, typename Index, int UpLo, bool ConjLhs, bool ConjRhs> 39 struct selfadjoint_rank1_update<Scalar,Index,RowMajor,UpLo,ConjLhs,ConjRhs> 40 { 41 static void run(Index size, Scalar* mat, Index stride, const Scalar* vecX, const Scalar* vecY, const Scalar& alpha) 42 { 43 selfadjoint_rank1_update<Scalar,Index,ColMajor,UpLo==Lower?Upper:Lower,ConjRhs,ConjLhs>::run(size,mat,stride,vecY,vecX,alpha); 44 } 45 }; 46 47 template<typename MatrixType, typename OtherType, int UpLo, bool OtherIsVector = OtherType::IsVectorAtCompileTime> 48 struct selfadjoint_product_selector; 49 50 template<typename MatrixType, typename OtherType, int UpLo> 51 struct selfadjoint_product_selector<MatrixType,OtherType,UpLo,true> 52 { 53 static void run(MatrixType& mat, const OtherType& other, const typename MatrixType::Scalar& alpha) 54 { 55 typedef typename MatrixType::Scalar Scalar; 56 typedef internal::blas_traits<OtherType> OtherBlasTraits; 57 typedef typename OtherBlasTraits::DirectLinearAccessType ActualOtherType; 58 typedef typename internal::remove_all<ActualOtherType>::type _ActualOtherType; 59 typename internal::add_const_on_value_type<ActualOtherType>::type actualOther = OtherBlasTraits::extract(other.derived()); 60 61 Scalar actualAlpha = alpha * OtherBlasTraits::extractScalarFactor(other.derived()); 62 63 enum { 64 StorageOrder = (internal::traits<MatrixType>::Flags&RowMajorBit) ? RowMajor : ColMajor, 65 UseOtherDirectly = _ActualOtherType::InnerStrideAtCompileTime==1 66 }; 67 internal::gemv_static_vector_if<Scalar,OtherType::SizeAtCompileTime,OtherType::MaxSizeAtCompileTime,!UseOtherDirectly> static_other; 68 69 ei_declare_aligned_stack_constructed_variable(Scalar, actualOtherPtr, other.size(), 70 (UseOtherDirectly ? const_cast<Scalar*>(actualOther.data()) : static_other.data())); 71 72 if(!UseOtherDirectly) 73 Map<typename _ActualOtherType::PlainObject>(actualOtherPtr, actualOther.size()) = actualOther; 74 75 selfadjoint_rank1_update<Scalar,Index,StorageOrder,UpLo, 76 OtherBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex, 77 (!OtherBlasTraits::NeedToConjugate) && NumTraits<Scalar>::IsComplex> 78 ::run(other.size(), mat.data(), mat.outerStride(), actualOtherPtr, actualOtherPtr, actualAlpha); 79 } 80 }; 81 82 template<typename MatrixType, typename OtherType, int UpLo> 83 struct selfadjoint_product_selector<MatrixType,OtherType,UpLo,false> 84 { 85 static void run(MatrixType& mat, const OtherType& other, const typename MatrixType::Scalar& alpha) 86 { 87 typedef typename MatrixType::Scalar Scalar; 88 typedef internal::blas_traits<OtherType> OtherBlasTraits; 89 typedef typename OtherBlasTraits::DirectLinearAccessType ActualOtherType; 90 typedef typename internal::remove_all<ActualOtherType>::type _ActualOtherType; 91 typename internal::add_const_on_value_type<ActualOtherType>::type actualOther = OtherBlasTraits::extract(other.derived()); 92 93 Scalar actualAlpha = alpha * OtherBlasTraits::extractScalarFactor(other.derived()); 94 95 enum { 96 IsRowMajor = (internal::traits<MatrixType>::Flags&RowMajorBit) ? 1 : 0, 97 OtherIsRowMajor = _ActualOtherType::Flags&RowMajorBit ? 1 : 0 98 }; 99 100 Index size = mat.cols(); 101 Index depth = actualOther.cols(); 102 103 typedef internal::gemm_blocking_space<IsRowMajor ? RowMajor : ColMajor,Scalar,Scalar, 104 MatrixType::MaxColsAtCompileTime, MatrixType::MaxColsAtCompileTime, _ActualOtherType::MaxColsAtCompileTime> BlockingType; 105 106 BlockingType blocking(size, size, depth, 1, false); 107 108 109 internal::general_matrix_matrix_triangular_product<Index, 110 Scalar, OtherIsRowMajor ? RowMajor : ColMajor, OtherBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex, 111 Scalar, OtherIsRowMajor ? ColMajor : RowMajor, (!OtherBlasTraits::NeedToConjugate) && NumTraits<Scalar>::IsComplex, 112 IsRowMajor ? RowMajor : ColMajor, MatrixType::InnerStrideAtCompileTime, UpLo> 113 ::run(size, depth, 114 actualOther.data(), actualOther.outerStride(), actualOther.data(), actualOther.outerStride(), 115 mat.data(), mat.innerStride(), mat.outerStride(), actualAlpha, blocking); 116 } 117 }; 118 119 // high level API 120 121 template<typename MatrixType, unsigned int UpLo> 122 template<typename DerivedU> 123 EIGEN_DEVICE_FUNC SelfAdjointView<MatrixType,UpLo>& SelfAdjointView<MatrixType,UpLo> 124 ::rankUpdate(const MatrixBase<DerivedU>& u, const Scalar& alpha) 125 { 126 selfadjoint_product_selector<MatrixType,DerivedU,UpLo>::run(_expression().const_cast_derived(), u.derived(), alpha); 127 128 return *this; 129 } 130 131 } // end namespace Eigen 132 133 #endif // EIGEN_SELFADJOINT_PRODUCT_H 134