• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_TRIANGULARMATRIXVECTOR_H
11 #define EIGEN_TRIANGULARMATRIXVECTOR_H
12 
13 namespace Eigen {
14 
15 namespace internal {
16 
// Primary template of the triangular-matrix * vector kernel. Only declared here;
// the partial specializations on StorageOrder (ColMajor / RowMajor) provide the
// actual run() implementations. Version defaults to Specialized.
17 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int StorageOrder, int Version=Specialized>
18 struct triangular_matrix_vector_product;
19 
// Column-major kernel: accumulates  res += alpha * op(lhs) * op(rhs)  where only
// the triangular part of lhs selected by Mode is read. op() is conjugation when
// ConjLhs / ConjRhs is set (applied through conj_expr_if).
//
// The diagonal is walked in panels of at most PanelWidth columns. Inside a panel
// each column contributes a scaled-column update restricted to the triangle; the
// rectangular block above (upper) or below (lower) the panel is delegated to
// general_matrix_vector_product.
20 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
21 struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version>
22 {
23   typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
     // Compile-time decoding of the Mode bit-field.
24   enum {
25     IsLower = ((Mode&Lower)==Lower),
26     HasUnitDiag = (Mode & UnitDiag)==UnitDiag,
27     HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag
28   };
     // _rows/_cols : dimensions of the lhs matrix as passed by the caller.
     // _lhs        : lhs coefficients, column-major with outer stride lhsStride.
     // _rhs        : rhs vector, read with inner stride rhsIncr.
     // _res        : destination vector, updated in place (+=).
     // resIncr     : only forwarded to general_matrix_vector_product; the local
     //               view of _res below is built without a stride.
     // alpha       : scalar factor applied to every contribution.
29   static EIGEN_DONT_INLINE  void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
30                                      const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, ResScalar alpha)
31   {
32     static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
       // size is the length of the diagonal. The working rows/cols are trimmed so
       // that a lower triangle has no columns past the diagonal and an upper
       // triangle has no rows past it.
33     Index size = (std::min)(_rows,_cols);
34     Index rows = IsLower ? _rows : (std::min)(_rows,_cols);
35     Index cols = IsLower ? (std::min)(_rows,_cols) : _cols;
36 
37     typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
38     const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
39     typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
40 
41     typedef Map<const Matrix<RhsScalar,Dynamic,1>, 0, InnerStride<> > RhsMap;
42     const RhsMap rhs(_rhs,cols,InnerStride<>(rhsIncr));
43     typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
44 
45     typedef Map<Matrix<ResScalar,Dynamic,1> > ResMap;
46     ResMap res(_res,rows);
47 
       // Walk the diagonal panel by panel; pi is the first column of the panel.
48     for (Index pi=0; pi<size; pi+=PanelWidth)
49     {
50       Index actualPanelWidth = (std::min)(PanelWidth, size-pi);
         // Triangular part of the panel, one column at a time.
51       for (Index k=0; k<actualPanelWidth; ++k)
52       {
53         Index i = pi + k;
           // s = first row touched by column i inside the panel, r = number of rows.
           // Lower: rows i..end of panel (starting at i+1 when the diagonal is implicit).
           // Upper: rows pi..i.
54         Index s = IsLower ? ((HasUnitDiag||HasZeroDiag) ? i+1 : i ) : pi;
55         Index r = IsLower ? actualPanelWidth-k : k+1;
           // With an implicit (unit or zero) diagonal the diagonal coefficient is
           // dropped from the segment (--r); the update is skipped if nothing is left.
56         if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
57           res.segment(s,r) += (alpha * cjRhs.coeff(i)) * cjLhs.col(i).segment(s,r);
           // Unit diagonal: add the contribution of the implicit 1 on the diagonal.
58         if (HasUnitDiag)
59           res.coeffRef(i) += alpha * cjRhs.coeff(i);
60       }
         // Rectangular part attached to the panel: the rows below it (lower) or
         // the pi rows above it (upper).
61       Index r = IsLower ? rows - pi - actualPanelWidth : pi;
62       if (r>0)
63       {
64         Index s = IsLower ? pi+actualPanelWidth : 0;
65         general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjLhs,RhsScalar,ConjRhs,BuiltIn>::run(
66             r, actualPanelWidth,
67             &lhs.coeffRef(s,pi), lhsStride,
68             &rhs.coeffRef(pi), rhsIncr,
69             &res.coeffRef(s), resIncr, alpha);
70       }
71     }
       // Upper triangle with more columns than rows: the columns to the right of
       // the square part form a plain rectangular block, handled by one gemv call.
72     if((!IsLower) && cols>size)
73     {
74       general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjLhs,RhsScalar,ConjRhs>::run(
75           rows, cols-size,
76           &lhs.coeffRef(0,size), lhsStride,
77           &rhs.coeffRef(size), rhsIncr,
78           _res, resIncr, alpha);
79     }
80   }
81 };
82 
// Row-major kernel: accumulates  res += alpha * op(lhs) * op(rhs)  where only
// the triangular part of lhs selected by Mode is read. op() is conjugation when
// ConjLhs / ConjRhs is set (applied through conj_expr_if).
//
// The diagonal is walked in panels of at most PanelWidth rows. Inside a panel
// each row contributes one dot product restricted to the triangle; the
// rectangular block to the left (lower) or right (upper) of the panel is
// delegated to general_matrix_vector_product.
83 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version>
84 struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version>
85 {
86   typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
     // Compile-time decoding of the Mode bit-field.
87   enum {
88     IsLower = ((Mode&Lower)==Lower),
89     HasUnitDiag = (Mode & UnitDiag)==UnitDiag,
90     HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag
91   };
     // _rows/_cols : dimensions of the lhs matrix as passed by the caller.
     // _lhs        : lhs coefficients, row-major with outer stride lhsStride.
     // _rhs        : rhs vector; the local view below is built without a stride,
     //               rhsIncr is only forwarded to general_matrix_vector_product.
     // _res        : destination vector, updated in place (+=) with inner stride resIncr.
     // alpha       : scalar factor applied to every contribution.
92   static void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
93                   const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, ResScalar alpha)
94   {
95     static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
       // diagSize is the length of the diagonal. The working rows/cols are trimmed
       // so that a lower triangle has no columns past the diagonal and an upper
       // triangle has no rows past it.
96     Index diagSize = (std::min)(_rows,_cols);
97     Index rows = IsLower ? _rows : diagSize;
98     Index cols = IsLower ? diagSize : _cols;
99 
100     typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap;
101     const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
102     typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
103 
104     typedef Map<const Matrix<RhsScalar,Dynamic,1> > RhsMap;
105     const RhsMap rhs(_rhs,cols);
106     typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
107 
108     typedef Map<Matrix<ResScalar,Dynamic,1>, 0, InnerStride<> > ResMap;
109     ResMap res(_res,rows,InnerStride<>(resIncr));
110 
        // Walk the diagonal panel by panel; pi is the first row of the panel.
111     for (Index pi=0; pi<diagSize; pi+=PanelWidth)
112     {
113       Index actualPanelWidth = (std::min)(PanelWidth, diagSize-pi);
          // Triangular part of the panel, one row at a time.
114       for (Index k=0; k<actualPanelWidth; ++k)
115       {
116         Index i = pi + k;
            // s = first column read from row i inside the panel, r = number of columns.
            // Lower: columns pi..i.
            // Upper: columns i..end of panel (starting at i+1 when the diagonal is implicit).
117         Index s = IsLower ? pi  : ((HasUnitDiag||HasZeroDiag) ? i+1 : i);
118         Index r = IsLower ? k+1 : actualPanelWidth-k;
            // With an implicit (unit or zero) diagonal the diagonal coefficient is
            // dropped from the segment (--r); the dot product is skipped if nothing is left.
119         if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
120           res.coeffRef(i) += alpha * (cjLhs.row(i).segment(s,r).cwiseProduct(cjRhs.segment(s,r).transpose())).sum();
            // Unit diagonal: add the contribution of the implicit 1 on the diagonal.
121         if (HasUnitDiag)
122           res.coeffRef(i) += alpha * cjRhs.coeff(i);
123       }
          // Rectangular part attached to the panel: the pi columns to its left
          // (lower) or the columns to its right (upper).
124       Index r = IsLower ? pi : cols - pi - actualPanelWidth;
125       if (r>0)
126       {
127         Index s = IsLower ? 0 : pi + actualPanelWidth;
128         general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjLhs,RhsScalar,ConjRhs,BuiltIn>::run(
129             actualPanelWidth, r,
130             &lhs.coeffRef(pi,s), lhsStride,
131             &rhs.coeffRef(s), rhsIncr,
132             &res.coeffRef(pi), resIncr, alpha);
133       }
134     }
        // Lower triangle with more rows than columns: the rows below the square
        // part form a plain rectangular block, handled by one gemv call.
135     if(IsLower && rows>diagSize)
136     {
137       general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjLhs,RhsScalar,ConjRhs>::run(
138             rows-diagSize, cols,
139             &lhs.coeffRef(diagSize,0), lhsStride,
140             &rhs.coeffRef(0), rhsIncr,
141             &res.coeffRef(diagSize), resIncr, alpha);
142     }
143   }
144 };
145 
146 /***************************************************************************
147 * Wrapper to product_triangular_vector
148 ***************************************************************************/
149 
// traits for TriangularProduct<..., false, Rhs, true>: everything is inherited
// from the traits of the corresponding ProductBase.
// NOTE(review): the false/true arguments presumably mean "lhs is not a vector,
// rhs is a vector" -- confirm against the TriangularProduct declaration.
150 template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
151 struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true> >
152  : traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true>, Lhs, Rhs> >
153 {};
154 
// traits for TriangularProduct<..., true, Rhs, false>: everything is inherited
// from the traits of the corresponding ProductBase.
// NOTE(review): the true/false arguments presumably mean "lhs is a vector,
// rhs is not a vector" -- confirm against the TriangularProduct declaration.
155 template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
156 struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false> >
157  : traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false>, Lhs, Rhs> >
158 {};
159 
160 
// Forward declaration of the dispatcher that prepares operands and calls
// triangular_matrix_vector_product. It is explicitly specialized for ColMajor
// and RowMajor further down in this file.
161 template<int StorageOrder>
162 struct trmv_selector;
163 
164 } // end namespace internal
165 
// Product expression with a triangular left-hand side (second template argument
// is true) and a right-hand side flagged by the trailing `true`.
// NOTE(review): presumably this is the "triangular matrix * column vector"
// case -- confirm against the TriangularProduct declaration.
166 template<int Mode, typename Lhs, typename Rhs>
167 struct TriangularProduct<Mode,true,Lhs,false,Rhs,true>
168   : public ProductBase<TriangularProduct<Mode,true,Lhs,false,Rhs,true>, Lhs, Rhs >
169 {
170   EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct)
171 
172   TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
173 
      // Adds alpha * (lhs * rhs) to dst. dst must already have the dimensions of
      // the product (checked by the assertion). The work is dispatched on the
      // storage order of Lhs, read from the RowMajorBit of its traits flags.
174   template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
175   {
176     eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
177 
178     internal::trmv_selector<(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(*this, dst, alpha);
179   }
180 };
181 
// Product expression with a triangular right-hand side (second template argument
// is false) and a left-hand side flagged by the `true` in fourth position.
// NOTE(review): presumably this is the "row vector * triangular matrix"
// case -- confirm against the TriangularProduct declaration.
182 template<int Mode, typename Lhs, typename Rhs>
183 struct TriangularProduct<Mode,false,Lhs,true,Rhs,false>
184   : public ProductBase<TriangularProduct<Mode,false,Lhs,true,Rhs,false>, Lhs, Rhs >
185 {
186   EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct)
187 
188   TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
189 
      // Adds alpha * (lhs * rhs) to dst by evaluating the transposed problem
      // rhs^T * lhs^T into dst^T, which has the triangular factor on the left.
190   template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
191   {
192     eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
193 
        // Mode of the transposed triangular factor: the UnitDiag/ZeroDiag bits are
        // kept, and Lower/Upper are swapped.
194     typedef TriangularProduct<(Mode & (UnitDiag|ZeroDiag)) | ((Mode & Lower) ? Upper : Lower),true,Transpose<const Rhs>,false,Transpose<const Lhs>,true> TriangularProductTranspose;
195     Transpose<Dest> dstT(dst);
        // The selector is chosen with the opposite storage order of Rhs, since the
        // triangular operand handed over is Rhs transposed.
196     internal::trmv_selector<(int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>::run(
197       TriangularProductTranspose(m_rhs.transpose(),m_lhs.transpose()), dstT, alpha);
198   }
199 };
200 
201 namespace internal {
202 
203 // TODO: find a way to factorize this piece of code with gemv_selector since the logic is exactly the same.
204 
// Dispatcher for a column-major triangular lhs: extracts the actual operands
// and scalar factors, makes sure the kernel gets a destination it can write
// to with increment 1, and calls triangular_matrix_vector_product<..., ColMajor>.
205 template<> struct trmv_selector<ColMajor>
206 {
      // prod  : the product expression (triangular lhs, rhs operand).
      // dest  : destination, updated as dest += alpha * lhs * rhs.
      // alpha : user-supplied scalar factor.
207   template<int Mode, typename Lhs, typename Rhs, typename Dest>
208   static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar alpha)
209   {
210     typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
211     typedef typename ProductType::Index Index;
212     typedef typename ProductType::LhsScalar   LhsScalar;
213     typedef typename ProductType::RhsScalar   RhsScalar;
214     typedef typename ProductType::Scalar      ResScalar;
215     typedef typename ProductType::RealScalar  RealScalar;
216     typedef typename ProductType::ActualLhsType ActualLhsType;
217     typedef typename ProductType::ActualRhsType ActualRhsType;
218     typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
219     typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
220     typedef Map<Matrix<ResScalar,Dynamic,1>, Aligned> MappedDest;
221 
        // Operands as extracted by the blas traits; their scalar factors are
        // folded into actualAlpha just below.
222     typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
223     typename internal::add_const_on_value_type<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
224 
225     ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
226                                   * RhsBlasTraits::extractScalarFactor(prod.rhs());
227 
228     enum {
229       // FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
230       // on, the other hand it is good for the cache to pack the vector anyways...
231       EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1,
232       ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex),
233       MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal
234     };
235 
        // Static storage that may back the temporary destination when dest cannot
        // be used directly.
236     gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;
237 
        // For a complex lhs with a real rhs, alpha is passed to the kernel as an
        // RhsScalar, which is only possible when its imaginary part is zero.
238     bool alphaIsCompatible = (!ComplexByReal) || (imag(actualAlpha)==RealScalar(0));
239     bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
240 
241     RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);
242 
        // actualDestPtr is the buffer handed to the kernel: dest itself when
        // evalToDest, otherwise a temporary.
        // NOTE(review): the macro presumably allocates the buffer itself when the
        // pointer given as last argument is null -- confirm in its definition.
243     ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(),
244                                                   evalToDest ? dest.data() : static_dest.data());
245 
246     if(!evalToDest)
247     {
248       #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
249       int size = dest.size();
250       EIGEN_DENSE_STORAGE_CTOR_PLUGIN
251       #endif
          // Incompatible alpha: compute the unscaled product into a zeroed
          // temporary (kernel alpha = 1) and apply actualAlpha afterwards.
          // Otherwise seed the temporary with dest so the kernel accumulates on it.
252       if(!alphaIsCompatible)
253       {
254         MappedDest(actualDestPtr, dest.size()).setZero();
255         compatibleAlpha = RhsScalar(1);
256       }
257       else
258         MappedDest(actualDestPtr, dest.size()) = dest;
259     }
260 
261     internal::triangular_matrix_vector_product
262       <Index,Mode,
263        LhsScalar, LhsBlasTraits::NeedToConjugate,
264        RhsScalar, RhsBlasTraits::NeedToConjugate,
265        ColMajor>
266       ::run(actualLhs.rows(),actualLhs.cols(),
267             actualLhs.data(),actualLhs.outerStride(),
268             actualRhs.data(),actualRhs.innerStride(),
269             actualDestPtr,1,compatibleAlpha);
270 
        // Transfer the temporary back into dest: scaled accumulation when alpha
        // was deferred, plain copy when the temporary already held dest + product.
271     if (!evalToDest)
272     {
273       if(!alphaIsCompatible)
274         dest += actualAlpha * MappedDest(actualDestPtr, dest.size());
275       else
276         dest = MappedDest(actualDestPtr, dest.size());
277     }
278   }
279 };
280 
// Dispatcher for a row-major triangular lhs: extracts the actual operands and
// scalar factors, makes sure the kernel gets a rhs it can read with increment 1,
// and calls triangular_matrix_vector_product<..., RowMajor>.
281 template<> struct trmv_selector<RowMajor>
282 {
      // prod  : the product expression (triangular lhs, rhs operand).
      // dest  : destination, updated as dest += alpha * lhs * rhs; it is passed
      //         to the kernel together with its inner stride.
      // alpha : user-supplied scalar factor.
283   template<int Mode, typename Lhs, typename Rhs, typename Dest>
284   static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar alpha)
285   {
286     typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
287     typedef typename ProductType::LhsScalar LhsScalar;
288     typedef typename ProductType::RhsScalar RhsScalar;
289     typedef typename ProductType::Scalar    ResScalar;
290     typedef typename ProductType::Index Index;
291     typedef typename ProductType::ActualLhsType ActualLhsType;
292     typedef typename ProductType::ActualRhsType ActualRhsType;
293     typedef typename ProductType::_ActualRhsType _ActualRhsType;
294     typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
295     typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
296 
        // Operands as extracted by the blas traits; their scalar factors are
        // folded into actualAlpha just below.
297     typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
298     typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
299 
300     ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
301                                   * RhsBlasTraits::extractScalarFactor(prod.rhs());
302 
        // The rhs can be read in place only if its inner stride is known to be 1
        // at compile time.
303     enum {
304       DirectlyUseRhs = _ActualRhsType::InnerStrideAtCompileTime==1
305     };
306 
        // Static storage that may back the packed copy of the rhs.
307     gemv_static_vector_if<RhsScalar,_ActualRhsType::SizeAtCompileTime,_ActualRhsType::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;
308 
        // actualRhsPtr is the buffer handed to the kernel: the rhs data itself
        // when DirectlyUseRhs, otherwise a temporary.
        // NOTE(review): the macro presumably allocates the buffer itself when the
        // pointer given as last argument is null -- confirm in its definition.
309     ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,actualRhs.size(),
310         DirectlyUseRhs ? const_cast<RhsScalar*>(actualRhs.data()) : static_rhs.data());
311 
312     if(!DirectlyUseRhs)
313     {
314       #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
315       int size = actualRhs.size();
316       EIGEN_DENSE_STORAGE_CTOR_PLUGIN
317       #endif
          // Pack the strided rhs into the contiguous temporary.
318       Map<typename _ActualRhsType::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
319     }
320 
321     internal::triangular_matrix_vector_product
322       <Index,Mode,
323        LhsScalar, LhsBlasTraits::NeedToConjugate,
324        RhsScalar, RhsBlasTraits::NeedToConjugate,
325        RowMajor>
326       ::run(actualLhs.rows(),actualLhs.cols(),
327             actualLhs.data(),actualLhs.outerStride(),
328             actualRhsPtr,1,
329             dest.data(),dest.innerStride(),
330             actualAlpha);
331   }
332 };
333 
334 } // end namespace internal
335 
336 } // end namespace Eigen
337 
338 #endif // EIGEN_TRIANGULARMATRIXVECTOR_H
339