1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2008 Gael Guennebaud <gael.guennebaud@inria.fr> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_SPARSE_CWISE_BINARY_OP_H 11 #define EIGEN_SPARSE_CWISE_BINARY_OP_H 12 13 namespace Eigen { 14 15 // Here we have to handle 3 cases: 16 // 1 - sparse op dense 17 // 2 - dense op sparse 18 // 3 - sparse op sparse 19 // We also need to implement a 4th iterator for: 20 // 4 - dense op dense 21 // Finally, we also need to distinguish between the product and other operations : 22 // configuration returned mode 23 // 1 - sparse op dense product sparse 24 // generic dense 25 // 2 - dense op sparse product sparse 26 // generic dense 27 // 3 - sparse op sparse product sparse 28 // generic sparse 29 // 4 - dense op dense product dense 30 // generic dense 31 32 namespace internal { 33 34 template<> struct promote_storage_type<Dense,Sparse> 35 { typedef Sparse ret; }; 36 37 template<> struct promote_storage_type<Sparse,Dense> 38 { typedef Sparse ret; }; 39 40 template<typename BinaryOp, typename Lhs, typename Rhs, typename Derived, 41 typename _LhsStorageMode = typename traits<Lhs>::StorageKind, 42 typename _RhsStorageMode = typename traits<Rhs>::StorageKind> 43 class sparse_cwise_binary_op_inner_iterator_selector; 44 45 } // end namespace internal 46 47 template<typename BinaryOp, typename Lhs, typename Rhs> 48 class CwiseBinaryOpImpl<BinaryOp, Lhs, Rhs, Sparse> 49 : public SparseMatrixBase<CwiseBinaryOp<BinaryOp, Lhs, Rhs> > 50 { 51 public: 52 class InnerIterator; 53 class ReverseInnerIterator; 54 typedef CwiseBinaryOp<BinaryOp, Lhs, Rhs> Derived; 55 EIGEN_SPARSE_PUBLIC_INTERFACE(Derived) 56 CwiseBinaryOpImpl() 57 { 58 typedef typename internal::traits<Lhs>::StorageKind LhsStorageKind; 59 typedef typename internal::traits<Rhs>::StorageKind RhsStorageKind; 60 EIGEN_STATIC_ASSERT(( 61 (!internal::is_same<LhsStorageKind,RhsStorageKind>::value) 62 || ((Lhs::Flags&RowMajorBit) == (Rhs::Flags&RowMajorBit))), 63 THE_STORAGE_ORDER_OF_BOTH_SIDES_MUST_MATCH); 64 } 65 }; 66 67 template<typename BinaryOp, typename Lhs, typename Rhs> 68 class CwiseBinaryOpImpl<BinaryOp,Lhs,Rhs,Sparse>::InnerIterator 69 : public internal::sparse_cwise_binary_op_inner_iterator_selector<BinaryOp,Lhs,Rhs,typename CwiseBinaryOpImpl<BinaryOp,Lhs,Rhs,Sparse>::InnerIterator> 70 { 71 public: 72 typedef typename Lhs::Index Index; 73 typedef internal::sparse_cwise_binary_op_inner_iterator_selector< 74 BinaryOp,Lhs,Rhs, InnerIterator> Base; 75 76 // NOTE: we have to prefix Index by "typename Lhs::" to avoid an ICE with VC11 77 EIGEN_STRONG_INLINE InnerIterator(const CwiseBinaryOpImpl& binOp, typename Lhs::Index outer) 78 : Base(binOp.derived(),outer) 79 {} 80 }; 81 82 /*************************************************************************** 83 * Implementation of inner-iterators 84 ***************************************************************************/ 85 86 // template<typename T> struct internal::func_is_conjunction { enum { ret = false }; }; 87 // template<typename T> struct internal::func_is_conjunction<internal::scalar_product_op<T> > { enum { ret = true }; }; 88 89 // TODO generalize the internal::scalar_product_op specialization to all conjunctions if any ! 90 91 namespace internal { 92 93 // sparse - sparse (generic) 94 template<typename BinaryOp, typename Lhs, typename Rhs, typename Derived> 95 class sparse_cwise_binary_op_inner_iterator_selector<BinaryOp, Lhs, Rhs, Derived, Sparse, Sparse> 96 { 97 typedef CwiseBinaryOp<BinaryOp, Lhs, Rhs> CwiseBinaryXpr; 98 typedef typename traits<CwiseBinaryXpr>::Scalar Scalar; 99 typedef typename traits<CwiseBinaryXpr>::_LhsNested _LhsNested; 100 typedef typename traits<CwiseBinaryXpr>::_RhsNested _RhsNested; 101 typedef typename _LhsNested::InnerIterator LhsIterator; 102 typedef typename _RhsNested::InnerIterator RhsIterator; 103 typedef typename Lhs::Index Index; 104 105 public: 106 107 EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer) 108 : m_lhsIter(xpr.lhs(),outer), m_rhsIter(xpr.rhs(),outer), m_functor(xpr.functor()) 109 { 110 this->operator++(); 111 } 112 113 EIGEN_STRONG_INLINE Derived& operator++() 114 { 115 if (m_lhsIter && m_rhsIter && (m_lhsIter.index() == m_rhsIter.index())) 116 { 117 m_id = m_lhsIter.index(); 118 m_value = m_functor(m_lhsIter.value(), m_rhsIter.value()); 119 ++m_lhsIter; 120 ++m_rhsIter; 121 } 122 else if (m_lhsIter && (!m_rhsIter || (m_lhsIter.index() < m_rhsIter.index()))) 123 { 124 m_id = m_lhsIter.index(); 125 m_value = m_functor(m_lhsIter.value(), Scalar(0)); 126 ++m_lhsIter; 127 } 128 else if (m_rhsIter && (!m_lhsIter || (m_lhsIter.index() > m_rhsIter.index()))) 129 { 130 m_id = m_rhsIter.index(); 131 m_value = m_functor(Scalar(0), m_rhsIter.value()); 132 ++m_rhsIter; 133 } 134 else 135 { 136 m_value = 0; // this is to avoid a compilation warning 137 m_id = -1; 138 } 139 return *static_cast<Derived*>(this); 140 } 141 142 EIGEN_STRONG_INLINE Scalar value() const { return m_value; } 143 144 EIGEN_STRONG_INLINE Index index() const { return m_id; } 145 EIGEN_STRONG_INLINE Index row() const { return Lhs::IsRowMajor ? m_lhsIter.row() : index(); } 146 EIGEN_STRONG_INLINE Index col() const { return Lhs::IsRowMajor ? index() : m_lhsIter.col(); } 147 148 EIGEN_STRONG_INLINE operator bool() const { return m_id>=0; } 149 150 protected: 151 LhsIterator m_lhsIter; 152 RhsIterator m_rhsIter; 153 const BinaryOp& m_functor; 154 Scalar m_value; 155 Index m_id; 156 }; 157 158 // sparse - sparse (product) 159 template<typename T, typename Lhs, typename Rhs, typename Derived> 160 class sparse_cwise_binary_op_inner_iterator_selector<scalar_product_op<T>, Lhs, Rhs, Derived, Sparse, Sparse> 161 { 162 typedef scalar_product_op<T> BinaryFunc; 163 typedef CwiseBinaryOp<BinaryFunc, Lhs, Rhs> CwiseBinaryXpr; 164 typedef typename CwiseBinaryXpr::Scalar Scalar; 165 typedef typename traits<CwiseBinaryXpr>::_LhsNested _LhsNested; 166 typedef typename _LhsNested::InnerIterator LhsIterator; 167 typedef typename traits<CwiseBinaryXpr>::_RhsNested _RhsNested; 168 typedef typename _RhsNested::InnerIterator RhsIterator; 169 typedef typename Lhs::Index Index; 170 public: 171 172 EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer) 173 : m_lhsIter(xpr.lhs(),outer), m_rhsIter(xpr.rhs(),outer), m_functor(xpr.functor()) 174 { 175 while (m_lhsIter && m_rhsIter && (m_lhsIter.index() != m_rhsIter.index())) 176 { 177 if (m_lhsIter.index() < m_rhsIter.index()) 178 ++m_lhsIter; 179 else 180 ++m_rhsIter; 181 } 182 } 183 184 EIGEN_STRONG_INLINE Derived& operator++() 185 { 186 ++m_lhsIter; 187 ++m_rhsIter; 188 while (m_lhsIter && m_rhsIter && (m_lhsIter.index() != m_rhsIter.index())) 189 { 190 if (m_lhsIter.index() < m_rhsIter.index()) 191 ++m_lhsIter; 192 else 193 ++m_rhsIter; 194 } 195 return *static_cast<Derived*>(this); 196 } 197 198 EIGEN_STRONG_INLINE Scalar value() const { return m_functor(m_lhsIter.value(), m_rhsIter.value()); } 199 200 EIGEN_STRONG_INLINE Index index() const { return m_lhsIter.index(); } 201 EIGEN_STRONG_INLINE Index row() const { return m_lhsIter.row(); } 202 EIGEN_STRONG_INLINE Index col() const { return m_lhsIter.col(); } 203 204 EIGEN_STRONG_INLINE operator bool() const { return (m_lhsIter && m_rhsIter); } 205 206 protected: 207 LhsIterator m_lhsIter; 208 RhsIterator m_rhsIter; 209 const BinaryFunc& m_functor; 210 }; 211 212 // sparse - dense (product) 213 template<typename T, typename Lhs, typename Rhs, typename Derived> 214 class sparse_cwise_binary_op_inner_iterator_selector<scalar_product_op<T>, Lhs, Rhs, Derived, Sparse, Dense> 215 { 216 typedef scalar_product_op<T> BinaryFunc; 217 typedef CwiseBinaryOp<BinaryFunc, Lhs, Rhs> CwiseBinaryXpr; 218 typedef typename CwiseBinaryXpr::Scalar Scalar; 219 typedef typename traits<CwiseBinaryXpr>::_LhsNested _LhsNested; 220 typedef typename traits<CwiseBinaryXpr>::RhsNested RhsNested; 221 typedef typename _LhsNested::InnerIterator LhsIterator; 222 typedef typename Lhs::Index Index; 223 enum { IsRowMajor = (int(Lhs::Flags)&RowMajorBit)==RowMajorBit }; 224 public: 225 226 EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer) 227 : m_rhs(xpr.rhs()), m_lhsIter(xpr.lhs(),outer), m_functor(xpr.functor()), m_outer(outer) 228 {} 229 230 EIGEN_STRONG_INLINE Derived& operator++() 231 { 232 ++m_lhsIter; 233 return *static_cast<Derived*>(this); 234 } 235 236 EIGEN_STRONG_INLINE Scalar value() const 237 { return m_functor(m_lhsIter.value(), 238 m_rhs.coeff(IsRowMajor?m_outer:m_lhsIter.index(),IsRowMajor?m_lhsIter.index():m_outer)); } 239 240 EIGEN_STRONG_INLINE Index index() const { return m_lhsIter.index(); } 241 EIGEN_STRONG_INLINE Index row() const { return m_lhsIter.row(); } 242 EIGEN_STRONG_INLINE Index col() const { return m_lhsIter.col(); } 243 244 EIGEN_STRONG_INLINE operator bool() const { return m_lhsIter; } 245 246 protected: 247 RhsNested m_rhs; 248 LhsIterator m_lhsIter; 249 const BinaryFunc m_functor; 250 const Index m_outer; 251 }; 252 253 // sparse - dense (product) 254 template<typename T, typename Lhs, typename Rhs, typename Derived> 255 class sparse_cwise_binary_op_inner_iterator_selector<scalar_product_op<T>, Lhs, Rhs, Derived, Dense, Sparse> 256 { 257 typedef scalar_product_op<T> BinaryFunc; 258 typedef CwiseBinaryOp<BinaryFunc, Lhs, Rhs> CwiseBinaryXpr; 259 typedef typename CwiseBinaryXpr::Scalar Scalar; 260 typedef typename traits<CwiseBinaryXpr>::_RhsNested _RhsNested; 261 typedef typename _RhsNested::InnerIterator RhsIterator; 262 typedef typename Lhs::Index Index; 263 264 enum { IsRowMajor = (int(Rhs::Flags)&RowMajorBit)==RowMajorBit }; 265 public: 266 267 EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer) 268 : m_xpr(xpr), m_rhsIter(xpr.rhs(),outer), m_functor(xpr.functor()), m_outer(outer) 269 {} 270 271 EIGEN_STRONG_INLINE Derived& operator++() 272 { 273 ++m_rhsIter; 274 return *static_cast<Derived*>(this); 275 } 276 277 EIGEN_STRONG_INLINE Scalar value() const 278 { return m_functor(m_xpr.lhs().coeff(IsRowMajor?m_outer:m_rhsIter.index(),IsRowMajor?m_rhsIter.index():m_outer), m_rhsIter.value()); } 279 280 EIGEN_STRONG_INLINE Index index() const { return m_rhsIter.index(); } 281 EIGEN_STRONG_INLINE Index row() const { return m_rhsIter.row(); } 282 EIGEN_STRONG_INLINE Index col() const { return m_rhsIter.col(); } 283 284 EIGEN_STRONG_INLINE operator bool() const { return m_rhsIter; } 285 286 protected: 287 const CwiseBinaryXpr& m_xpr; 288 RhsIterator m_rhsIter; 289 const BinaryFunc& m_functor; 290 const Index m_outer; 291 }; 292 293 } // end namespace internal 294 295 /*************************************************************************** 296 * Implementation of SparseMatrixBase and SparseCwise functions/operators 297 ***************************************************************************/ 298 299 template<typename Derived> 300 template<typename OtherDerived> 301 EIGEN_STRONG_INLINE Derived & 302 SparseMatrixBase<Derived>::operator-=(const SparseMatrixBase<OtherDerived> &other) 303 { 304 return derived() = derived() - other.derived(); 305 } 306 307 template<typename Derived> 308 template<typename OtherDerived> 309 EIGEN_STRONG_INLINE Derived & 310 SparseMatrixBase<Derived>::operator+=(const SparseMatrixBase<OtherDerived>& other) 311 { 312 return derived() = derived() + other.derived(); 313 } 314 315 template<typename Derived> 316 template<typename OtherDerived> 317 EIGEN_STRONG_INLINE const EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE 318 SparseMatrixBase<Derived>::cwiseProduct(const MatrixBase<OtherDerived> &other) const 319 { 320 return EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE(derived(), other.derived()); 321 } 322 323 } // end namespace Eigen 324 325 #endif // EIGEN_SPARSE_CWISE_BINARY_OP_H 326