• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2008-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_TRIANGULAR_SOLVER_VECTOR_H
11 #define EIGEN_TRIANGULAR_SOLVER_VECTOR_H
12 
13 namespace Eigen {
14 
15 namespace internal {
16 
17 template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate, int StorageOrder>
18 struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheRight, Mode, Conjugate, StorageOrder>
19 {
20   static void run(Index size, const LhsScalar* _lhs, Index lhsStride, RhsScalar* rhs)
21   {
22     triangular_solve_vector<LhsScalar,RhsScalar,Index,OnTheLeft,
23         ((Mode&Upper)==Upper ? Lower : Upper) | (Mode&UnitDiag),
24         Conjugate,StorageOrder==RowMajor?ColMajor:RowMajor
25       >::run(size, _lhs, lhsStride, rhs);
26   }
27 };
28 
29 // forward and backward substitution, row-major, rhs is a vector
30 template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate>
31 struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheLeft, Mode, Conjugate, RowMajor>
32 {
33   enum {
34     IsLower = ((Mode&Lower)==Lower)
35   };
36   static void run(Index size, const LhsScalar* _lhs, Index lhsStride, RhsScalar* rhs)
37   {
38     typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap;
39     const LhsMap lhs(_lhs,size,size,OuterStride<>(lhsStride));
40 
41     typedef const_blas_data_mapper<LhsScalar,Index,RowMajor> LhsMapper;
42     typedef const_blas_data_mapper<RhsScalar,Index,ColMajor> RhsMapper;
43 
44     typename internal::conditional<
45                           Conjugate,
46                           const CwiseUnaryOp<typename internal::scalar_conjugate_op<LhsScalar>,LhsMap>,
47                           const LhsMap&>
48                         ::type cjLhs(lhs);
49     static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
50     for(Index pi=IsLower ? 0 : size;
51         IsLower ? pi<size : pi>0;
52         IsLower ? pi+=PanelWidth : pi-=PanelWidth)
53     {
54       Index actualPanelWidth = (std::min)(IsLower ? size - pi : pi, PanelWidth);
55 
56       Index r = IsLower ? pi : size - pi; // remaining size
57       if (r > 0)
58       {
59         // let's directly call the low level product function because:
60         // 1 - it is faster to compile
61         // 2 - it is slighlty faster at runtime
62         Index startRow = IsLower ? pi : pi-actualPanelWidth;
63         Index startCol = IsLower ? 0 : pi;
64 
65         general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,Conjugate,RhsScalar,RhsMapper,false>::run(
66           actualPanelWidth, r,
67           LhsMapper(&lhs.coeffRef(startRow,startCol), lhsStride),
68           RhsMapper(rhs + startCol, 1),
69           rhs + startRow, 1,
70           RhsScalar(-1));
71       }
72 
73       for(Index k=0; k<actualPanelWidth; ++k)
74       {
75         Index i = IsLower ? pi+k : pi-k-1;
76         Index s = IsLower ? pi   : i+1;
77         if (k>0)
78           rhs[i] -= (cjLhs.row(i).segment(s,k).transpose().cwiseProduct(Map<const Matrix<RhsScalar,Dynamic,1> >(rhs+s,k))).sum();
79 
80         if(!(Mode & UnitDiag))
81           rhs[i] /= cjLhs(i,i);
82       }
83     }
84   }
85 };
86 
87 // forward and backward substitution, column-major, rhs is a vector
88 template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate>
89 struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheLeft, Mode, Conjugate, ColMajor>
90 {
91   enum {
92     IsLower = ((Mode&Lower)==Lower)
93   };
94   static void run(Index size, const LhsScalar* _lhs, Index lhsStride, RhsScalar* rhs)
95   {
96     typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
97     const LhsMap lhs(_lhs,size,size,OuterStride<>(lhsStride));
98     typedef const_blas_data_mapper<LhsScalar,Index,ColMajor> LhsMapper;
99     typedef const_blas_data_mapper<RhsScalar,Index,ColMajor> RhsMapper;
100     typename internal::conditional<Conjugate,
101                                    const CwiseUnaryOp<typename internal::scalar_conjugate_op<LhsScalar>,LhsMap>,
102                                    const LhsMap&
103                                   >::type cjLhs(lhs);
104     static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
105 
106     for(Index pi=IsLower ? 0 : size;
107         IsLower ? pi<size : pi>0;
108         IsLower ? pi+=PanelWidth : pi-=PanelWidth)
109     {
110       Index actualPanelWidth = (std::min)(IsLower ? size - pi : pi, PanelWidth);
111       Index startBlock = IsLower ? pi : pi-actualPanelWidth;
112       Index endBlock = IsLower ? pi + actualPanelWidth : 0;
113 
114       for(Index k=0; k<actualPanelWidth; ++k)
115       {
116         Index i = IsLower ? pi+k : pi-k-1;
117         if(!(Mode & UnitDiag))
118           rhs[i] /= cjLhs.coeff(i,i);
119 
120         Index r = actualPanelWidth - k - 1; // remaining size
121         Index s = IsLower ? i+1 : i-r;
122         if (r>0)
123           Map<Matrix<RhsScalar,Dynamic,1> >(rhs+s,r) -= rhs[i] * cjLhs.col(i).segment(s,r);
124       }
125       Index r = IsLower ? size - endBlock : startBlock; // remaining size
126       if (r > 0)
127       {
128         // let's directly call the low level product function because:
129         // 1 - it is faster to compile
130         // 2 - it is slighlty faster at runtime
131         general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,Conjugate,RhsScalar,RhsMapper,false>::run(
132             r, actualPanelWidth,
133             LhsMapper(&lhs.coeffRef(endBlock,startBlock), lhsStride),
134             RhsMapper(rhs+startBlock, 1),
135             rhs+endBlock, 1, RhsScalar(-1));
136       }
137     }
138   }
139 };
140 
141 } // end namespace internal
142 
143 } // end namespace Eigen
144 
145 #endif // EIGEN_TRIANGULAR_SOLVER_VECTOR_H
146