1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_REF_H 11 #define EIGEN_CXX11_TENSOR_TENSOR_REF_H 12 13 namespace Eigen { 14 15 namespace internal { 16 17 template <typename Dimensions, typename Scalar> 18 class TensorLazyBaseEvaluator { 19 public: TensorLazyBaseEvaluator()20 TensorLazyBaseEvaluator() : m_refcount(0) { } ~TensorLazyBaseEvaluator()21 virtual ~TensorLazyBaseEvaluator() { } 22 23 EIGEN_DEVICE_FUNC virtual const Dimensions& dimensions() const = 0; 24 EIGEN_DEVICE_FUNC virtual const Scalar* data() const = 0; 25 26 EIGEN_DEVICE_FUNC virtual const Scalar coeff(DenseIndex index) const = 0; 27 EIGEN_DEVICE_FUNC virtual Scalar& coeffRef(DenseIndex index) = 0; 28 incrRefCount()29 void incrRefCount() { ++m_refcount; } decrRefCount()30 void decrRefCount() { --m_refcount; } refCount()31 int refCount() const { return m_refcount; } 32 33 private: 34 // No copy, no assigment; 35 TensorLazyBaseEvaluator(const TensorLazyBaseEvaluator& other); 36 TensorLazyBaseEvaluator& operator = (const TensorLazyBaseEvaluator& other); 37 38 int m_refcount; 39 }; 40 41 42 template <typename Dimensions, typename Expr, typename Device> 43 class TensorLazyEvaluatorReadOnly : public TensorLazyBaseEvaluator<Dimensions, typename TensorEvaluator<Expr, Device>::Scalar> { 44 public: 45 // typedef typename TensorEvaluator<Expr, Device>::Dimensions Dimensions; 46 typedef typename TensorEvaluator<Expr, Device>::Scalar Scalar; 47 TensorLazyEvaluatorReadOnly(const Expr & expr,const Device & device)48 TensorLazyEvaluatorReadOnly(const Expr& expr, const Device& device) : m_impl(expr, device), m_dummy(Scalar(0)) { 49 m_dims = m_impl.dimensions(); 50 m_impl.evalSubExprsIfNeeded(NULL); 51 } ~TensorLazyEvaluatorReadOnly()52 virtual ~TensorLazyEvaluatorReadOnly() { 53 m_impl.cleanup(); 54 } 55 dimensions()56 EIGEN_DEVICE_FUNC virtual const Dimensions& dimensions() const { 57 return m_dims; 58 } data()59 EIGEN_DEVICE_FUNC virtual const Scalar* data() const { 60 return m_impl.data(); 61 } 62 coeff(DenseIndex index)63 EIGEN_DEVICE_FUNC virtual const Scalar coeff(DenseIndex index) const { 64 return m_impl.coeff(index); 65 } coeffRef(DenseIndex)66 EIGEN_DEVICE_FUNC virtual Scalar& coeffRef(DenseIndex /*index*/) { 67 eigen_assert(false && "can't reference the coefficient of a rvalue"); 68 return m_dummy; 69 }; 70 71 protected: 72 TensorEvaluator<Expr, Device> m_impl; 73 Dimensions m_dims; 74 Scalar m_dummy; 75 }; 76 77 template <typename Dimensions, typename Expr, typename Device> 78 class TensorLazyEvaluatorWritable : public TensorLazyEvaluatorReadOnly<Dimensions, Expr, Device> { 79 public: 80 typedef TensorLazyEvaluatorReadOnly<Dimensions, Expr, Device> Base; 81 typedef typename Base::Scalar Scalar; 82 TensorLazyEvaluatorWritable(const Expr & expr,const Device & device)83 TensorLazyEvaluatorWritable(const Expr& expr, const Device& device) : Base(expr, device) { 84 } ~TensorLazyEvaluatorWritable()85 virtual ~TensorLazyEvaluatorWritable() { 86 } 87 coeffRef(DenseIndex index)88 EIGEN_DEVICE_FUNC virtual Scalar& coeffRef(DenseIndex index) { 89 return this->m_impl.coeffRef(index); 90 } 91 }; 92 93 template <typename Dimensions, typename Expr, typename Device> 94 class TensorLazyEvaluator : public internal::conditional<bool(internal::is_lvalue<Expr>::value), 95 TensorLazyEvaluatorWritable<Dimensions, Expr, Device>, 96 TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device> >::type { 97 public: 98 typedef typename internal::conditional<bool(internal::is_lvalue<Expr>::value), 99 TensorLazyEvaluatorWritable<Dimensions, Expr, Device>, 100 TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device> >::type Base; 101 typedef typename Base::Scalar Scalar; 102 TensorLazyEvaluator(const Expr & expr,const Device & device)103 TensorLazyEvaluator(const Expr& expr, const Device& device) : Base(expr, device) { 104 } ~TensorLazyEvaluator()105 virtual ~TensorLazyEvaluator() { 106 } 107 }; 108 109 } // namespace internal 110 111 112 /** \class TensorRef 113 * \ingroup CXX11_Tensor_Module 114 * 115 * \brief A reference to a tensor expression 116 * The expression will be evaluated lazily (as much as possible). 117 * 118 */ 119 template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef<PlainObjectType> > 120 { 121 public: 122 typedef TensorRef<PlainObjectType> Self; 123 typedef typename PlainObjectType::Base Base; 124 typedef typename Eigen::internal::nested<Self>::type Nested; 125 typedef typename internal::traits<PlainObjectType>::StorageKind StorageKind; 126 typedef typename internal::traits<PlainObjectType>::Index Index; 127 typedef typename internal::traits<PlainObjectType>::Scalar Scalar; 128 typedef typename NumTraits<Scalar>::Real RealScalar; 129 typedef typename Base::CoeffReturnType CoeffReturnType; 130 typedef Scalar* PointerType; 131 typedef PointerType PointerArgType; 132 133 static const Index NumIndices = PlainObjectType::NumIndices; 134 typedef typename PlainObjectType::Dimensions Dimensions; 135 136 enum { 137 IsAligned = false, 138 PacketAccess = false, 139 Layout = PlainObjectType::Layout, 140 CoordAccess = false, // to be implemented 141 RawAccess = false 142 }; 143 TensorRef()144 EIGEN_STRONG_INLINE TensorRef() : m_evaluator(NULL) { 145 } 146 147 template <typename Expression> TensorRef(const Expression & expr)148 EIGEN_STRONG_INLINE TensorRef(const Expression& expr) : m_evaluator(new internal::TensorLazyEvaluator<Dimensions, Expression, DefaultDevice>(expr, DefaultDevice())) { 149 m_evaluator->incrRefCount(); 150 } 151 152 template <typename Expression> 153 EIGEN_STRONG_INLINE TensorRef& operator = (const Expression& expr) { 154 unrefEvaluator(); 155 m_evaluator = new internal::TensorLazyEvaluator<Dimensions, Expression, DefaultDevice>(expr, DefaultDevice()); 156 m_evaluator->incrRefCount(); 157 return *this; 158 } 159 ~TensorRef()160 ~TensorRef() { 161 unrefEvaluator(); 162 } 163 TensorRef(const TensorRef & other)164 TensorRef(const TensorRef& other) : m_evaluator(other.m_evaluator) { 165 eigen_assert(m_evaluator->refCount() > 0); 166 m_evaluator->incrRefCount(); 167 } 168 169 TensorRef& operator = (const TensorRef& other) { 170 if (this != &other) { 171 unrefEvaluator(); 172 m_evaluator = other.m_evaluator; 173 eigen_assert(m_evaluator->refCount() > 0); 174 m_evaluator->incrRefCount(); 175 } 176 return *this; 177 } 178 179 EIGEN_DEVICE_FUNC rank()180 EIGEN_STRONG_INLINE Index rank() const { return m_evaluator->dimensions().size(); } 181 EIGEN_DEVICE_FUNC dimension(Index n)182 EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_evaluator->dimensions()[n]; } 183 EIGEN_DEVICE_FUNC dimensions()184 EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_evaluator->dimensions(); } 185 EIGEN_DEVICE_FUNC size()186 EIGEN_STRONG_INLINE Index size() const { return m_evaluator->dimensions().TotalSize(); } 187 EIGEN_DEVICE_FUNC data()188 EIGEN_STRONG_INLINE const Scalar* data() const { return m_evaluator->data(); } 189 190 EIGEN_DEVICE_FUNC operator()191 EIGEN_STRONG_INLINE const Scalar operator()(Index index) const 192 { 193 return m_evaluator->coeff(index); 194 } 195 196 #if EIGEN_HAS_VARIADIC_TEMPLATES 197 template<typename... IndexTypes> EIGEN_DEVICE_FUNC operator()198 EIGEN_STRONG_INLINE const Scalar operator()(Index firstIndex, IndexTypes... otherIndices) const 199 { 200 const std::size_t num_indices = (sizeof...(otherIndices) + 1); 201 const array<Index, num_indices> indices{{firstIndex, otherIndices...}}; 202 return coeff(indices); 203 } 204 template<typename... IndexTypes> EIGEN_DEVICE_FUNC coeffRef(Index firstIndex,IndexTypes...otherIndices)205 EIGEN_STRONG_INLINE Scalar& coeffRef(Index firstIndex, IndexTypes... otherIndices) 206 { 207 const std::size_t num_indices = (sizeof...(otherIndices) + 1); 208 const array<Index, num_indices> indices{{firstIndex, otherIndices...}}; 209 return coeffRef(indices); 210 } 211 #else 212 213 EIGEN_DEVICE_FUNC operator()214 EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1) const 215 { 216 array<Index, 2> indices; 217 indices[0] = i0; 218 indices[1] = i1; 219 return coeff(indices); 220 } 221 EIGEN_DEVICE_FUNC operator()222 EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1, Index i2) const 223 { 224 array<Index, 3> indices; 225 indices[0] = i0; 226 indices[1] = i1; 227 indices[2] = i2; 228 return coeff(indices); 229 } 230 EIGEN_DEVICE_FUNC operator()231 EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1, Index i2, Index i3) const 232 { 233 array<Index, 4> indices; 234 indices[0] = i0; 235 indices[1] = i1; 236 indices[2] = i2; 237 indices[3] = i3; 238 return coeff(indices); 239 } 240 EIGEN_DEVICE_FUNC operator()241 EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1, Index i2, Index i3, Index i4) const 242 { 243 array<Index, 5> indices; 244 indices[0] = i0; 245 indices[1] = i1; 246 indices[2] = i2; 247 indices[3] = i3; 248 indices[4] = i4; 249 return coeff(indices); 250 } 251 EIGEN_DEVICE_FUNC coeffRef(Index i0,Index i1)252 EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1) 253 { 254 array<Index, 2> indices; 255 indices[0] = i0; 256 indices[1] = i1; 257 return coeffRef(indices); 258 } 259 EIGEN_DEVICE_FUNC coeffRef(Index i0,Index i1,Index i2)260 EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1, Index i2) 261 { 262 array<Index, 3> indices; 263 indices[0] = i0; 264 indices[1] = i1; 265 indices[2] = i2; 266 return coeffRef(indices); 267 } 268 EIGEN_DEVICE_FUNC operator()269 EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2, Index i3) 270 { 271 array<Index, 4> indices; 272 indices[0] = i0; 273 indices[1] = i1; 274 indices[2] = i2; 275 indices[3] = i3; 276 return coeffRef(indices); 277 } 278 EIGEN_DEVICE_FUNC coeffRef(Index i0,Index i1,Index i2,Index i3,Index i4)279 EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1, Index i2, Index i3, Index i4) 280 { 281 array<Index, 5> indices; 282 indices[0] = i0; 283 indices[1] = i1; 284 indices[2] = i2; 285 indices[3] = i3; 286 indices[4] = i4; 287 return coeffRef(indices); 288 } 289 #endif 290 291 template <std::size_t NumIndices> EIGEN_DEVICE_FUNC coeff(const array<Index,NumIndices> & indices)292 EIGEN_STRONG_INLINE const Scalar coeff(const array<Index, NumIndices>& indices) const 293 { 294 const Dimensions& dims = this->dimensions(); 295 Index index = 0; 296 if (PlainObjectType::Options & RowMajor) { 297 index += indices[0]; 298 for (size_t i = 1; i < NumIndices; ++i) { 299 index = index * dims[i] + indices[i]; 300 } 301 } else { 302 index += indices[NumIndices-1]; 303 for (int i = NumIndices-2; i >= 0; --i) { 304 index = index * dims[i] + indices[i]; 305 } 306 } 307 return m_evaluator->coeff(index); 308 } 309 template <std::size_t NumIndices> EIGEN_DEVICE_FUNC coeffRef(const array<Index,NumIndices> & indices)310 EIGEN_STRONG_INLINE Scalar& coeffRef(const array<Index, NumIndices>& indices) 311 { 312 const Dimensions& dims = this->dimensions(); 313 Index index = 0; 314 if (PlainObjectType::Options & RowMajor) { 315 index += indices[0]; 316 for (size_t i = 1; i < NumIndices; ++i) { 317 index = index * dims[i] + indices[i]; 318 } 319 } else { 320 index += indices[NumIndices-1]; 321 for (int i = NumIndices-2; i >= 0; --i) { 322 index = index * dims[i] + indices[i]; 323 } 324 } 325 return m_evaluator->coeffRef(index); 326 } 327 328 EIGEN_DEVICE_FUNC coeff(Index index)329 EIGEN_STRONG_INLINE const Scalar coeff(Index index) const 330 { 331 return m_evaluator->coeff(index); 332 } 333 334 EIGEN_DEVICE_FUNC coeffRef(Index index)335 EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) 336 { 337 return m_evaluator->coeffRef(index); 338 } 339 340 private: unrefEvaluator()341 EIGEN_STRONG_INLINE void unrefEvaluator() { 342 if (m_evaluator) { 343 m_evaluator->decrRefCount(); 344 if (m_evaluator->refCount() == 0) { 345 delete m_evaluator; 346 } 347 } 348 } 349 350 internal::TensorLazyBaseEvaluator<Dimensions, Scalar>* m_evaluator; 351 }; 352 353 354 // evaluator for rvalues 355 template<typename Derived, typename Device> 356 struct TensorEvaluator<const TensorRef<Derived>, Device> 357 { 358 typedef typename Derived::Index Index; 359 typedef typename Derived::Scalar Scalar; 360 typedef typename Derived::Scalar CoeffReturnType; 361 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType; 362 typedef typename Derived::Dimensions Dimensions; 363 364 enum { 365 IsAligned = false, 366 PacketAccess = false, 367 Layout = TensorRef<Derived>::Layout, 368 CoordAccess = false, // to be implemented 369 RawAccess = false 370 }; 371 372 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const TensorRef<Derived>& m, const Device&) 373 : m_ref(m) 374 { } 375 376 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_ref.dimensions(); } 377 378 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar*) { 379 return true; 380 } 381 382 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { } 383 384 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { 385 return m_ref.coeff(index); 386 } 387 388 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { 389 return m_ref.coeffRef(index); 390 } 391 392 EIGEN_DEVICE_FUNC Scalar* data() const { return m_ref.data(); } 393 394 protected: 395 TensorRef<Derived> m_ref; 396 }; 397 398 399 // evaluator for lvalues 400 template<typename Derived, typename Device> 401 struct TensorEvaluator<TensorRef<Derived>, Device> : public TensorEvaluator<const TensorRef<Derived>, Device> 402 { 403 typedef typename Derived::Index Index; 404 typedef typename Derived::Scalar Scalar; 405 typedef typename Derived::Scalar CoeffReturnType; 406 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType; 407 typedef typename Derived::Dimensions Dimensions; 408 409 typedef TensorEvaluator<const TensorRef<Derived>, Device> Base; 410 411 enum { 412 IsAligned = false, 413 PacketAccess = false, 414 RawAccess = false 415 }; 416 417 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(TensorRef<Derived>& m, const Device& d) : Base(m, d) 418 { } 419 420 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { 421 return this->m_ref.coeffRef(index); 422 } 423 }; 424 425 426 427 } // end namespace Eigen 428 429 #endif // EIGEN_CXX11_TENSOR_TENSOR_REF_H 430