• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2008-2009 Guillaume Saupin <guillaume.saupin@cea.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_SKYLINEPRODUCT_H
11 #define EIGEN_SKYLINEPRODUCT_H
12 
13 namespace Eigen {
14 
15 template<typename Lhs, typename Rhs, int ProductMode>
16 struct SkylineProductReturnType {
17     typedef const typename internal::nested_eval<Lhs, Rhs::RowsAtCompileTime>::type LhsNested;
18     typedef const typename internal::nested_eval<Rhs, Lhs::RowsAtCompileTime>::type RhsNested;
19 
20     typedef SkylineProduct<LhsNested, RhsNested, ProductMode> Type;
21 };
22 
23 template<typename LhsNested, typename RhsNested, int ProductMode>
24 struct internal::traits<SkylineProduct<LhsNested, RhsNested, ProductMode> > {
25     // clean the nested types:
26     typedef typename internal::remove_all<LhsNested>::type _LhsNested;
27     typedef typename internal::remove_all<RhsNested>::type _RhsNested;
28     typedef typename _LhsNested::Scalar Scalar;
29 
30     enum {
31         LhsCoeffReadCost = _LhsNested::CoeffReadCost,
32         RhsCoeffReadCost = _RhsNested::CoeffReadCost,
33         LhsFlags = _LhsNested::Flags,
34         RhsFlags = _RhsNested::Flags,
35 
36         RowsAtCompileTime = _LhsNested::RowsAtCompileTime,
37         ColsAtCompileTime = _RhsNested::ColsAtCompileTime,
38         InnerSize = EIGEN_SIZE_MIN_PREFER_FIXED(_LhsNested::ColsAtCompileTime, _RhsNested::RowsAtCompileTime),
39 
40         MaxRowsAtCompileTime = _LhsNested::MaxRowsAtCompileTime,
41         MaxColsAtCompileTime = _RhsNested::MaxColsAtCompileTime,
42 
43         EvalToRowMajor = (RhsFlags & LhsFlags & RowMajorBit),
44         ResultIsSkyline = ProductMode == SkylineTimeSkylineProduct,
45 
46         RemovedBits = ~((EvalToRowMajor ? 0 : RowMajorBit) | (ResultIsSkyline ? 0 : SkylineBit)),
47 
48         Flags = (int(LhsFlags | RhsFlags) & HereditaryBits & RemovedBits)
49         | EvalBeforeAssigningBit
50         | EvalBeforeNestingBit,
51 
52         CoeffReadCost = HugeCost
53     };
54 
55     typedef typename internal::conditional<ResultIsSkyline,
56             SkylineMatrixBase<SkylineProduct<LhsNested, RhsNested, ProductMode> >,
57             MatrixBase<SkylineProduct<LhsNested, RhsNested, ProductMode> > >::type Base;
58 };
59 
60 namespace internal {
61 template<typename LhsNested, typename RhsNested, int ProductMode>
62 class SkylineProduct : no_assignment_operator,
63 public traits<SkylineProduct<LhsNested, RhsNested, ProductMode> >::Base {
64 public:
65 
66     EIGEN_GENERIC_PUBLIC_INTERFACE(SkylineProduct)
67 
68 private:
69 
70     typedef typename traits<SkylineProduct>::_LhsNested _LhsNested;
71     typedef typename traits<SkylineProduct>::_RhsNested _RhsNested;
72 
73 public:
74 
75     template<typename Lhs, typename Rhs>
76     EIGEN_STRONG_INLINE SkylineProduct(const Lhs& lhs, const Rhs& rhs)
77     : m_lhs(lhs), m_rhs(rhs) {
78         eigen_assert(lhs.cols() == rhs.rows());
79 
80         enum {
81             ProductIsValid = _LhsNested::ColsAtCompileTime == Dynamic
82             || _RhsNested::RowsAtCompileTime == Dynamic
83             || int(_LhsNested::ColsAtCompileTime) == int(_RhsNested::RowsAtCompileTime),
84             AreVectors = _LhsNested::IsVectorAtCompileTime && _RhsNested::IsVectorAtCompileTime,
85             SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(_LhsNested, _RhsNested)
86         };
87         // note to the lost user:
88         //    * for a dot product use: v1.dot(v2)
89         //    * for a coeff-wise product use: v1.cwise()*v2
90         EIGEN_STATIC_ASSERT(ProductIsValid || !(AreVectors && SameSizes),
91                 INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS)
92                 EIGEN_STATIC_ASSERT(ProductIsValid || !(SameSizes && !AreVectors),
93                 INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION)
94                 EIGEN_STATIC_ASSERT(ProductIsValid || SameSizes, INVALID_MATRIX_PRODUCT)
95     }
96 
97     EIGEN_STRONG_INLINE Index rows() const {
98         return m_lhs.rows();
99     }
100 
101     EIGEN_STRONG_INLINE Index cols() const {
102         return m_rhs.cols();
103     }
104 
105     EIGEN_STRONG_INLINE const _LhsNested& lhs() const {
106         return m_lhs;
107     }
108 
109     EIGEN_STRONG_INLINE const _RhsNested& rhs() const {
110         return m_rhs;
111     }
112 
113 protected:
114     LhsNested m_lhs;
115     RhsNested m_rhs;
116 };
117 
118 // dense = skyline * dense
119 // Note that here we force no inlining and separate the setZero() because GCC messes up otherwise
120 
121 template<typename Lhs, typename Rhs, typename Dest>
122 EIGEN_DONT_INLINE void skyline_row_major_time_dense_product(const Lhs& lhs, const Rhs& rhs, Dest& dst) {
123     typedef typename remove_all<Lhs>::type _Lhs;
124     typedef typename remove_all<Rhs>::type _Rhs;
125     typedef typename traits<Lhs>::Scalar Scalar;
126 
127     enum {
128         LhsIsRowMajor = (_Lhs::Flags & RowMajorBit) == RowMajorBit,
129         LhsIsSelfAdjoint = (_Lhs::Flags & SelfAdjointBit) == SelfAdjointBit,
130         ProcessFirstHalf = LhsIsSelfAdjoint
131         && (((_Lhs::Flags & (UpperTriangularBit | LowerTriangularBit)) == 0)
132         || ((_Lhs::Flags & UpperTriangularBit) && !LhsIsRowMajor)
133         || ((_Lhs::Flags & LowerTriangularBit) && LhsIsRowMajor)),
134         ProcessSecondHalf = LhsIsSelfAdjoint && (!ProcessFirstHalf)
135     };
136 
137     //Use matrix diagonal part <- Improvement : use inner iterator on dense matrix.
138     for (Index col = 0; col < rhs.cols(); col++) {
139         for (Index row = 0; row < lhs.rows(); row++) {
140             dst(row, col) = lhs.coeffDiag(row) * rhs(row, col);
141         }
142     }
143     //Use matrix lower triangular part
144     for (Index row = 0; row < lhs.rows(); row++) {
145         typename _Lhs::InnerLowerIterator lIt(lhs, row);
146         const Index stop = lIt.col() + lIt.size();
147         for (Index col = 0; col < rhs.cols(); col++) {
148 
149             Index k = lIt.col();
150             Scalar tmp = 0;
151             while (k < stop) {
152                 tmp +=
153                         lIt.value() *
154                         rhs(k++, col);
155                 ++lIt;
156             }
157             dst(row, col) += tmp;
158             lIt += -lIt.size();
159         }
160 
161     }
162 
163     //Use matrix upper triangular part
164     for (Index lhscol = 0; lhscol < lhs.cols(); lhscol++) {
165         typename _Lhs::InnerUpperIterator uIt(lhs, lhscol);
166         const Index stop = uIt.size() + uIt.row();
167         for (Index rhscol = 0; rhscol < rhs.cols(); rhscol++) {
168 
169 
170             const Scalar rhsCoeff = rhs.coeff(lhscol, rhscol);
171             Index k = uIt.row();
172             while (k < stop) {
173                 dst(k++, rhscol) +=
174                         uIt.value() *
175                         rhsCoeff;
176                 ++uIt;
177             }
178             uIt += -uIt.size();
179         }
180     }
181 
182 }
183 
184 template<typename Lhs, typename Rhs, typename Dest>
185 EIGEN_DONT_INLINE void skyline_col_major_time_dense_product(const Lhs& lhs, const Rhs& rhs, Dest& dst) {
186     typedef typename remove_all<Lhs>::type _Lhs;
187     typedef typename remove_all<Rhs>::type _Rhs;
188     typedef typename traits<Lhs>::Scalar Scalar;
189 
190     enum {
191         LhsIsRowMajor = (_Lhs::Flags & RowMajorBit) == RowMajorBit,
192         LhsIsSelfAdjoint = (_Lhs::Flags & SelfAdjointBit) == SelfAdjointBit,
193         ProcessFirstHalf = LhsIsSelfAdjoint
194         && (((_Lhs::Flags & (UpperTriangularBit | LowerTriangularBit)) == 0)
195         || ((_Lhs::Flags & UpperTriangularBit) && !LhsIsRowMajor)
196         || ((_Lhs::Flags & LowerTriangularBit) && LhsIsRowMajor)),
197         ProcessSecondHalf = LhsIsSelfAdjoint && (!ProcessFirstHalf)
198     };
199 
200     //Use matrix diagonal part <- Improvement : use inner iterator on dense matrix.
201     for (Index col = 0; col < rhs.cols(); col++) {
202         for (Index row = 0; row < lhs.rows(); row++) {
203             dst(row, col) = lhs.coeffDiag(row) * rhs(row, col);
204         }
205     }
206 
207     //Use matrix upper triangular part
208     for (Index row = 0; row < lhs.rows(); row++) {
209         typename _Lhs::InnerUpperIterator uIt(lhs, row);
210         const Index stop = uIt.col() + uIt.size();
211         for (Index col = 0; col < rhs.cols(); col++) {
212 
213             Index k = uIt.col();
214             Scalar tmp = 0;
215             while (k < stop) {
216                 tmp +=
217                         uIt.value() *
218                         rhs(k++, col);
219                 ++uIt;
220             }
221 
222 
223             dst(row, col) += tmp;
224             uIt += -uIt.size();
225         }
226     }
227 
228     //Use matrix lower triangular part
229     for (Index lhscol = 0; lhscol < lhs.cols(); lhscol++) {
230         typename _Lhs::InnerLowerIterator lIt(lhs, lhscol);
231         const Index stop = lIt.size() + lIt.row();
232         for (Index rhscol = 0; rhscol < rhs.cols(); rhscol++) {
233 
234             const Scalar rhsCoeff = rhs.coeff(lhscol, rhscol);
235             Index k = lIt.row();
236             while (k < stop) {
237                 dst(k++, rhscol) +=
238                         lIt.value() *
239                         rhsCoeff;
240                 ++lIt;
241             }
242             lIt += -lIt.size();
243         }
244     }
245 
246 }
247 
248 template<typename Lhs, typename Rhs, typename ResultType,
249         int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit>
250         struct skyline_product_selector;
251 
252 template<typename Lhs, typename Rhs, typename ResultType>
253 struct skyline_product_selector<Lhs, Rhs, ResultType, RowMajor> {
254     typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar;
255 
256     static void run(const Lhs& lhs, const Rhs& rhs, ResultType & res) {
257         skyline_row_major_time_dense_product<Lhs, Rhs, ResultType > (lhs, rhs, res);
258     }
259 };
260 
261 template<typename Lhs, typename Rhs, typename ResultType>
262 struct skyline_product_selector<Lhs, Rhs, ResultType, ColMajor> {
263     typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar;
264 
265     static void run(const Lhs& lhs, const Rhs& rhs, ResultType & res) {
266         skyline_col_major_time_dense_product<Lhs, Rhs, ResultType > (lhs, rhs, res);
267     }
268 };
269 
270 } // end namespace internal
271 
272 // template<typename Derived>
273 // template<typename Lhs, typename Rhs >
274 // Derived & MatrixBase<Derived>::lazyAssign(const SkylineProduct<Lhs, Rhs, SkylineTimeDenseProduct>& product) {
275 //     typedef typename internal::remove_all<Lhs>::type _Lhs;
276 //     internal::skyline_product_selector<typename internal::remove_all<Lhs>::type,
277 //             typename internal::remove_all<Rhs>::type,
278 //             Derived>::run(product.lhs(), product.rhs(), derived());
279 //
280 //     return derived();
281 // }
282 
283 // skyline * dense
284 
285 template<typename Derived>
286 template<typename OtherDerived >
287 EIGEN_STRONG_INLINE const typename SkylineProductReturnType<Derived, OtherDerived>::Type
288 SkylineMatrixBase<Derived>::operator*(const MatrixBase<OtherDerived> &other) const {
289 
290     return typename SkylineProductReturnType<Derived, OtherDerived>::Type(derived(), other.derived());
291 }
292 
293 } // end namespace Eigen
294 
295 #endif // EIGEN_SKYLINEPRODUCT_H
296