1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2015 Benoit Steiner <benoit.steiner.goog@gmail.com> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_GENERATOR_H 11 #define EIGEN_CXX11_TENSOR_TENSOR_GENERATOR_H 12 13 namespace Eigen { 14 15 /** \class TensorGenerator 16 * \ingroup CXX11_Tensor_Module 17 * 18 * \brief Tensor generator class. 19 * 20 * 21 */ 22 namespace internal { 23 template<typename Generator, typename XprType> 24 struct traits<TensorGeneratorOp<Generator, XprType> > : public traits<XprType> 25 { 26 typedef typename XprType::Scalar Scalar; 27 typedef traits<XprType> XprTraits; 28 typedef typename XprTraits::StorageKind StorageKind; 29 typedef typename XprTraits::Index Index; 30 typedef typename XprType::Nested Nested; 31 typedef typename remove_reference<Nested>::type _Nested; 32 static const int NumDimensions = XprTraits::NumDimensions; 33 static const int Layout = XprTraits::Layout; 34 }; 35 36 template<typename Generator, typename XprType> 37 struct eval<TensorGeneratorOp<Generator, XprType>, Eigen::Dense> 38 { 39 typedef const TensorGeneratorOp<Generator, XprType>& type; 40 }; 41 42 template<typename Generator, typename XprType> 43 struct nested<TensorGeneratorOp<Generator, XprType>, 1, typename eval<TensorGeneratorOp<Generator, XprType> >::type> 44 { 45 typedef TensorGeneratorOp<Generator, XprType> type; 46 }; 47 48 } // end namespace internal 49 50 51 52 template<typename Generator, typename XprType> 53 class TensorGeneratorOp : public TensorBase<TensorGeneratorOp<Generator, XprType>, ReadOnlyAccessors> 54 { 55 public: 56 typedef typename Eigen::internal::traits<TensorGeneratorOp>::Scalar Scalar; 57 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; 58 typedef typename XprType::CoeffReturnType CoeffReturnType; 59 typedef typename Eigen::internal::nested<TensorGeneratorOp>::type Nested; 60 typedef typename Eigen::internal::traits<TensorGeneratorOp>::StorageKind StorageKind; 61 typedef typename Eigen::internal::traits<TensorGeneratorOp>::Index Index; 62 63 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorGeneratorOp(const XprType& expr, const Generator& generator) 64 : m_xpr(expr), m_generator(generator) {} 65 66 EIGEN_DEVICE_FUNC 67 const Generator& generator() const { return m_generator; } 68 69 EIGEN_DEVICE_FUNC 70 const typename internal::remove_all<typename XprType::Nested>::type& 71 expression() const { return m_xpr; } 72 73 protected: 74 typename XprType::Nested m_xpr; 75 const Generator m_generator; 76 }; 77 78 79 // Eval as rvalue 80 template<typename Generator, typename ArgType, typename Device> 81 struct TensorEvaluator<const TensorGeneratorOp<Generator, ArgType>, Device> 82 { 83 typedef TensorGeneratorOp<Generator, ArgType> XprType; 84 typedef typename XprType::Index Index; 85 typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions; 86 static const int NumDims = internal::array_size<Dimensions>::value; 87 typedef typename XprType::Scalar Scalar; 88 typedef typename XprType::CoeffReturnType CoeffReturnType; 89 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType; 90 enum { 91 IsAligned = false, 92 PacketAccess = (internal::unpacket_traits<PacketReturnType>::size > 1), 93 BlockAccess = false, 94 Layout = TensorEvaluator<ArgType, Device>::Layout, 95 CoordAccess = false, // to be implemented 96 RawAccess = false 97 }; 98 99 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) 100 : m_generator(op.generator()) 101 { 102 TensorEvaluator<ArgType, Device> impl(op.expression(), device); 103 m_dimensions = impl.dimensions(); 104 105 if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) { 106 m_strides[0] = 1; 107 for (int i = 1; i < NumDims; ++i) { 108 m_strides[i] = m_strides[i - 1] * m_dimensions[i - 1]; 109 } 110 } else { 111 m_strides[NumDims - 1] = 1; 112 for (int i = NumDims - 2; i >= 0; --i) { 113 m_strides[i] = m_strides[i + 1] * m_dimensions[i + 1]; 114 } 115 } 116 } 117 118 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } 119 120 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) { 121 return true; 122 } 123 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { 124 } 125 126 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const 127 { 128 array<Index, NumDims> coords; 129 extract_coordinates(index, coords); 130 return m_generator(coords); 131 } 132 133 template<int LoadMode> 134 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const 135 { 136 const int packetSize = internal::unpacket_traits<PacketReturnType>::size; 137 EIGEN_STATIC_ASSERT((packetSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE) 138 eigen_assert(index+packetSize-1 < dimensions().TotalSize()); 139 140 EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[packetSize]; 141 for (int i = 0; i < packetSize; ++i) { 142 values[i] = coeff(index+i); 143 } 144 PacketReturnType rslt = internal::pload<PacketReturnType>(values); 145 return rslt; 146 } 147 148 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost 149 costPerCoeff(bool) const { 150 // TODO(rmlarsen): This is just a placeholder. Define interface to make 151 // generators return their cost. 152 return TensorOpCost(0, 0, TensorOpCost::AddCost<Scalar>() + 153 TensorOpCost::MulCost<Scalar>()); 154 } 155 156 EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; } 157 158 protected: 159 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE 160 void extract_coordinates(Index index, array<Index, NumDims>& coords) const { 161 if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) { 162 for (int i = NumDims - 1; i > 0; --i) { 163 const Index idx = index / m_strides[i]; 164 index -= idx * m_strides[i]; 165 coords[i] = idx; 166 } 167 coords[0] = index; 168 } else { 169 for (int i = 0; i < NumDims - 1; ++i) { 170 const Index idx = index / m_strides[i]; 171 index -= idx * m_strides[i]; 172 coords[i] = idx; 173 } 174 coords[NumDims-1] = index; 175 } 176 } 177 178 Dimensions m_dimensions; 179 array<Index, NumDims> m_strides; 180 Generator m_generator; 181 }; 182 183 } // end namespace Eigen 184 185 #endif // EIGEN_CXX11_TENSOR_TENSOR_GENERATOR_H 186