• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2008 Gael Guennebaud <gael.guennebaud@inria.fr>
5 // Copyright (C) 2007-2009 Benoit Jacob <jacob.benoit.1@gmail.com>
6 //
7 // This Source Code Form is subject to the terms of the Mozilla
8 // Public License v. 2.0. If a copy of the MPL was not distributed
9 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
10 
11 #ifndef EIGEN_DIAGONALPRODUCT_H
12 #define EIGEN_DIAGONALPRODUCT_H
13 
14 namespace Eigen {
15 
16 namespace internal {
17 template<typename MatrixType, typename DiagonalType, int ProductOrder>
18 struct traits<DiagonalProduct<MatrixType, DiagonalType, ProductOrder> >
19  : traits<MatrixType>
20 {
21   typedef typename scalar_product_traits<typename MatrixType::Scalar, typename DiagonalType::Scalar>::ReturnType Scalar;
22   enum {
23     RowsAtCompileTime = MatrixType::RowsAtCompileTime,
24     ColsAtCompileTime = MatrixType::ColsAtCompileTime,
25     MaxRowsAtCompileTime = MatrixType::MaxRowsAtCompileTime,
26     MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime,
27 
28     _StorageOrder = MatrixType::Flags & RowMajorBit ? RowMajor : ColMajor,
29     _PacketOnDiag = !((int(_StorageOrder) == RowMajor && int(ProductOrder) == OnTheLeft)
30                     ||(int(_StorageOrder) == ColMajor && int(ProductOrder) == OnTheRight)),
31     _SameTypes = is_same<typename MatrixType::Scalar, typename DiagonalType::Scalar>::value,
32     // FIXME currently we need same types, but in the future the next rule should be the one
33     //_Vectorizable = bool(int(MatrixType::Flags)&PacketAccessBit) && ((!_PacketOnDiag) || (_SameTypes && bool(int(DiagonalType::Flags)&PacketAccessBit))),
34     _Vectorizable = bool(int(MatrixType::Flags)&PacketAccessBit) && _SameTypes && ((!_PacketOnDiag) || (bool(int(DiagonalType::Flags)&PacketAccessBit))),
35 
36     Flags = (HereditaryBits & (unsigned int)(MatrixType::Flags)) | (_Vectorizable ? PacketAccessBit : 0),
37     CoeffReadCost = NumTraits<Scalar>::MulCost + MatrixType::CoeffReadCost + DiagonalType::DiagonalVectorType::CoeffReadCost
38   };
39 };
40 }
41 
42 template<typename MatrixType, typename DiagonalType, int ProductOrder>
43 class DiagonalProduct : internal::no_assignment_operator,
44                         public MatrixBase<DiagonalProduct<MatrixType, DiagonalType, ProductOrder> >
45 {
46   public:
47 
48     typedef MatrixBase<DiagonalProduct> Base;
49     EIGEN_DENSE_PUBLIC_INTERFACE(DiagonalProduct)
50 
51     inline DiagonalProduct(const MatrixType& matrix, const DiagonalType& diagonal)
52       : m_matrix(matrix), m_diagonal(diagonal)
53     {
54       eigen_assert(diagonal.diagonal().size() == (ProductOrder == OnTheLeft ? matrix.rows() : matrix.cols()));
55     }
56 
57     inline Index rows() const { return m_matrix.rows(); }
58     inline Index cols() const { return m_matrix.cols(); }
59 
60     const Scalar coeff(Index row, Index col) const
61     {
62       return m_diagonal.diagonal().coeff(ProductOrder == OnTheLeft ? row : col) * m_matrix.coeff(row, col);
63     }
64 
65     template<int LoadMode>
66     EIGEN_STRONG_INLINE PacketScalar packet(Index row, Index col) const
67     {
68       enum {
69         StorageOrder = Flags & RowMajorBit ? RowMajor : ColMajor
70       };
71       const Index indexInDiagonalVector = ProductOrder == OnTheLeft ? row : col;
72 
73       return packet_impl<LoadMode>(row,col,indexInDiagonalVector,typename internal::conditional<
74         ((int(StorageOrder) == RowMajor && int(ProductOrder) == OnTheLeft)
75        ||(int(StorageOrder) == ColMajor && int(ProductOrder) == OnTheRight)), internal::true_type, internal::false_type>::type());
76     }
77 
78   protected:
79     template<int LoadMode>
80     EIGEN_STRONG_INLINE PacketScalar packet_impl(Index row, Index col, Index id, internal::true_type) const
81     {
82       return internal::pmul(m_matrix.template packet<LoadMode>(row, col),
83                      internal::pset1<PacketScalar>(m_diagonal.diagonal().coeff(id)));
84     }
85 
86     template<int LoadMode>
87     EIGEN_STRONG_INLINE PacketScalar packet_impl(Index row, Index col, Index id, internal::false_type) const
88     {
89       enum {
90         InnerSize = (MatrixType::Flags & RowMajorBit) ? MatrixType::ColsAtCompileTime : MatrixType::RowsAtCompileTime,
91         DiagonalVectorPacketLoadMode = (LoadMode == Aligned && ((InnerSize%16) == 0)) ? Aligned : Unaligned
92       };
93       return internal::pmul(m_matrix.template packet<LoadMode>(row, col),
94                      m_diagonal.diagonal().template packet<DiagonalVectorPacketLoadMode>(id));
95     }
96 
97     typename MatrixType::Nested m_matrix;
98     typename DiagonalType::Nested m_diagonal;
99 };
100 
101 /** \returns the diagonal matrix product of \c *this by the diagonal matrix \a diagonal.
102   */
103 template<typename Derived>
104 template<typename DiagonalDerived>
105 inline const DiagonalProduct<Derived, DiagonalDerived, OnTheRight>
106 MatrixBase<Derived>::operator*(const DiagonalBase<DiagonalDerived> &diagonal) const
107 {
108   return DiagonalProduct<Derived, DiagonalDerived, OnTheRight>(derived(), diagonal.derived());
109 }
110 
111 /** \returns the diagonal matrix product of \c *this by the matrix \a matrix.
112   */
113 template<typename DiagonalDerived>
114 template<typename MatrixDerived>
115 inline const DiagonalProduct<MatrixDerived, DiagonalDerived, OnTheLeft>
116 DiagonalBase<DiagonalDerived>::operator*(const MatrixBase<MatrixDerived> &matrix) const
117 {
118   return DiagonalProduct<MatrixDerived, DiagonalDerived, OnTheLeft>(matrix.derived(), derived());
119 }
120 
121 } // end namespace Eigen
122 
123 #endif // EIGEN_DIAGONALPRODUCT_H
124