1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_MAP_H 11 #define EIGEN_CXX11_TENSOR_TENSOR_MAP_H 12 13 namespace Eigen { 14 15 /** \class TensorMap 16 * \ingroup CXX11_Tensor_Module 17 * 18 * \brief A tensor expression mapping an existing array of data. 19 * 20 */ 21 /// template <class> class MakePointer_ is added to convert the host pointer to the device pointer. 22 /// It is added due to the fact that for our device compiler T* is not allowed. 23 /// If we wanted to use the same Evaluator functions we have to convert that type to our pointer T. 24 /// This is done through our MakePointer_ class. By default the Type in the MakePointer_<T> is T* . 25 /// Therefore, by adding the default value, we managed to convert the type and it does not break any 26 /// existing code as its default value is T*. 27 template<typename PlainObjectType, int Options_, template <class> class MakePointer_> class TensorMap : public TensorBase<TensorMap<PlainObjectType, Options_, MakePointer_> > 28 { 29 public: 30 typedef TensorMap<PlainObjectType, Options_, MakePointer_> Self; 31 typedef typename PlainObjectType::Base Base; 32 typedef typename Eigen::internal::nested<Self>::type Nested; 33 typedef typename internal::traits<PlainObjectType>::StorageKind StorageKind; 34 typedef typename internal::traits<PlainObjectType>::Index Index; 35 typedef typename internal::traits<PlainObjectType>::Scalar Scalar; 36 typedef typename NumTraits<Scalar>::Real RealScalar; 37 typedef typename Base::CoeffReturnType CoeffReturnType; 38 39 /* typedef typename internal::conditional< 40 bool(internal::is_lvalue<PlainObjectType>::value), 41 Scalar *, 42 const Scalar *>::type 43 PointerType;*/ 44 typedef typename MakePointer_<Scalar>::Type PointerType; 45 typedef PointerType PointerArgType; 46 47 static const int Options = Options_; 48 49 static const Index NumIndices = PlainObjectType::NumIndices; 50 typedef typename PlainObjectType::Dimensions Dimensions; 51 52 enum { 53 IsAligned = ((int(Options_)&Aligned)==Aligned), 54 Layout = PlainObjectType::Layout, 55 CoordAccess = true, 56 RawAccess = true 57 }; 58 59 EIGEN_DEVICE_FUNC TensorMap(PointerArgType dataPtr)60 EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr) : m_data(dataPtr), m_dimensions() { 61 // The number of dimensions used to construct a tensor must be equal to the rank of the tensor. 62 EIGEN_STATIC_ASSERT((0 == NumIndices || NumIndices == Dynamic), YOU_MADE_A_PROGRAMMING_MISTAKE) 63 } 64 65 #if EIGEN_HAS_VARIADIC_TEMPLATES 66 template<typename... IndexTypes> EIGEN_DEVICE_FUNC TensorMap(PointerArgType dataPtr,Index firstDimension,IndexTypes...otherDimensions)67 EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension, IndexTypes... otherDimensions) : m_data(dataPtr), m_dimensions(firstDimension, otherDimensions...) { 68 // The number of dimensions used to construct a tensor must be equal to the rank of the tensor. 69 EIGEN_STATIC_ASSERT((sizeof...(otherDimensions) + 1 == NumIndices || NumIndices == Dynamic), YOU_MADE_A_PROGRAMMING_MISTAKE) 70 } 71 #else 72 EIGEN_DEVICE_FUNC TensorMap(PointerArgType dataPtr,Index firstDimension)73 EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension) : m_data(dataPtr), m_dimensions(firstDimension) { 74 // The number of dimensions used to construct a tensor must be equal to the rank of the tensor. 75 EIGEN_STATIC_ASSERT((1 == NumIndices || NumIndices == Dynamic), YOU_MADE_A_PROGRAMMING_MISTAKE) 76 } 77 EIGEN_DEVICE_FUNC TensorMap(PointerArgType dataPtr,Index dim1,Index dim2)78 EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index dim1, Index dim2) : m_data(dataPtr), m_dimensions(dim1, dim2) { 79 EIGEN_STATIC_ASSERT(2 == NumIndices || NumIndices == Dynamic, YOU_MADE_A_PROGRAMMING_MISTAKE) 80 } 81 EIGEN_DEVICE_FUNC TensorMap(PointerArgType dataPtr,Index dim1,Index dim2,Index dim3)82 EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index dim1, Index dim2, Index dim3) : m_data(dataPtr), m_dimensions(dim1, dim2, dim3) { 83 EIGEN_STATIC_ASSERT(3 == NumIndices || NumIndices == Dynamic, YOU_MADE_A_PROGRAMMING_MISTAKE) 84 } 85 EIGEN_DEVICE_FUNC TensorMap(PointerArgType dataPtr,Index dim1,Index dim2,Index dim3,Index dim4)86 EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index dim1, Index dim2, Index dim3, Index dim4) : m_data(dataPtr), m_dimensions(dim1, dim2, dim3, dim4) { 87 EIGEN_STATIC_ASSERT(4 == NumIndices || NumIndices == Dynamic, YOU_MADE_A_PROGRAMMING_MISTAKE) 88 } 89 EIGEN_DEVICE_FUNC TensorMap(PointerArgType dataPtr,Index dim1,Index dim2,Index dim3,Index dim4,Index dim5)90 EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index dim1, Index dim2, Index dim3, Index dim4, Index dim5) : m_data(dataPtr), m_dimensions(dim1, dim2, dim3, dim4, dim5) { 91 EIGEN_STATIC_ASSERT(5 == NumIndices || NumIndices == Dynamic, YOU_MADE_A_PROGRAMMING_MISTAKE) 92 } 93 #endif 94 TensorMap(PointerArgType dataPtr,const array<Index,NumIndices> & dimensions)95 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, const array<Index, NumIndices>& dimensions) 96 : m_data(dataPtr), m_dimensions(dimensions) 97 { } 98 99 template <typename Dimensions> TensorMap(PointerArgType dataPtr,const Dimensions & dimensions)100 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, const Dimensions& dimensions) 101 : m_data(dataPtr), m_dimensions(dimensions) 102 { } 103 TensorMap(PlainObjectType & tensor)104 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMap(PlainObjectType& tensor) 105 : m_data(tensor.data()), m_dimensions(tensor.dimensions()) 106 { } 107 108 EIGEN_DEVICE_FUNC rank()109 EIGEN_STRONG_INLINE Index rank() const { return m_dimensions.rank(); } 110 EIGEN_DEVICE_FUNC dimension(Index n)111 EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_dimensions[n]; } 112 EIGEN_DEVICE_FUNC dimensions()113 EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } 114 EIGEN_DEVICE_FUNC size()115 EIGEN_STRONG_INLINE Index size() const { return m_dimensions.TotalSize(); } 116 EIGEN_DEVICE_FUNC data()117 EIGEN_STRONG_INLINE PointerType data() { return m_data; } 118 EIGEN_DEVICE_FUNC data()119 EIGEN_STRONG_INLINE const PointerType data() const { return m_data; } 120 121 EIGEN_DEVICE_FUNC operator()122 EIGEN_STRONG_INLINE const Scalar& operator()(const array<Index, NumIndices>& indices) const 123 { 124 // eigen_assert(checkIndexRange(indices)); 125 if (PlainObjectType::Options&RowMajor) { 126 const Index index = m_dimensions.IndexOfRowMajor(indices); 127 return m_data[index]; 128 } else { 129 const Index index = m_dimensions.IndexOfColMajor(indices); 130 return m_data[index]; 131 } 132 } 133 134 EIGEN_DEVICE_FUNC operator()135 EIGEN_STRONG_INLINE const Scalar& operator()() const 136 { 137 EIGEN_STATIC_ASSERT(NumIndices == 0, YOU_MADE_A_PROGRAMMING_MISTAKE) 138 return m_data[0]; 139 } 140 141 EIGEN_DEVICE_FUNC operator()142 EIGEN_STRONG_INLINE const Scalar& operator()(Index index) const 143 { 144 eigen_internal_assert(index >= 0 && index < size()); 145 return m_data[index]; 146 } 147 148 #if EIGEN_HAS_VARIADIC_TEMPLATES 149 template<typename... IndexTypes> EIGEN_DEVICE_FUNC operator()150 EIGEN_STRONG_INLINE const Scalar& operator()(Index firstIndex, Index secondIndex, IndexTypes... otherIndices) const 151 { 152 EIGEN_STATIC_ASSERT(sizeof...(otherIndices) + 2 == NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE) 153 if (PlainObjectType::Options&RowMajor) { 154 const Index index = m_dimensions.IndexOfRowMajor(array<Index, NumIndices>{{firstIndex, secondIndex, otherIndices...}}); 155 return m_data[index]; 156 } else { 157 const Index index = m_dimensions.IndexOfColMajor(array<Index, NumIndices>{{firstIndex, secondIndex, otherIndices...}}); 158 return m_data[index]; 159 } 160 } 161 #else 162 EIGEN_DEVICE_FUNC operator()163 EIGEN_STRONG_INLINE const Scalar& operator()(Index i0, Index i1) const 164 { 165 if (PlainObjectType::Options&RowMajor) { 166 const Index index = i1 + i0 * m_dimensions[1]; 167 return m_data[index]; 168 } else { 169 const Index index = i0 + i1 * m_dimensions[0]; 170 return m_data[index]; 171 } 172 } 173 EIGEN_DEVICE_FUNC operator()174 EIGEN_STRONG_INLINE const Scalar& operator()(Index i0, Index i1, Index i2) const 175 { 176 if (PlainObjectType::Options&RowMajor) { 177 const Index index = i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0); 178 return m_data[index]; 179 } else { 180 const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * i2); 181 return m_data[index]; 182 } 183 } 184 EIGEN_DEVICE_FUNC operator()185 EIGEN_STRONG_INLINE const Scalar& operator()(Index i0, Index i1, Index i2, Index i3) const 186 { 187 if (PlainObjectType::Options&RowMajor) { 188 const Index index = i3 + m_dimensions[3] * (i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0)); 189 return m_data[index]; 190 } else { 191 const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * (i2 + m_dimensions[2] * i3)); 192 return m_data[index]; 193 } 194 } 195 EIGEN_DEVICE_FUNC operator()196 EIGEN_STRONG_INLINE const Scalar& operator()(Index i0, Index i1, Index i2, Index i3, Index i4) const 197 { 198 if (PlainObjectType::Options&RowMajor) { 199 const Index index = i4 + m_dimensions[4] * (i3 + m_dimensions[3] * (i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0))); 200 return m_data[index]; 201 } else { 202 const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * (i2 + m_dimensions[2] * (i3 + m_dimensions[3] * i4))); 203 return m_data[index]; 204 } 205 } 206 #endif 207 208 EIGEN_DEVICE_FUNC operator()209 EIGEN_STRONG_INLINE Scalar& operator()(const array<Index, NumIndices>& indices) 210 { 211 // eigen_assert(checkIndexRange(indices)); 212 if (PlainObjectType::Options&RowMajor) { 213 const Index index = m_dimensions.IndexOfRowMajor(indices); 214 return m_data[index]; 215 } else { 216 const Index index = m_dimensions.IndexOfColMajor(indices); 217 return m_data[index]; 218 } 219 } 220 221 EIGEN_DEVICE_FUNC operator()222 EIGEN_STRONG_INLINE Scalar& operator()() 223 { 224 EIGEN_STATIC_ASSERT(NumIndices == 0, YOU_MADE_A_PROGRAMMING_MISTAKE) 225 return m_data[0]; 226 } 227 228 EIGEN_DEVICE_FUNC operator()229 EIGEN_STRONG_INLINE Scalar& operator()(Index index) 230 { 231 eigen_internal_assert(index >= 0 && index < size()); 232 return m_data[index]; 233 } 234 235 #if EIGEN_HAS_VARIADIC_TEMPLATES 236 template<typename... IndexTypes> EIGEN_DEVICE_FUNC operator()237 EIGEN_STRONG_INLINE Scalar& operator()(Index firstIndex, Index secondIndex, IndexTypes... otherIndices) 238 { 239 static_assert(sizeof...(otherIndices) + 2 == NumIndices || NumIndices == Dynamic, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor."); 240 const std::size_t NumDims = sizeof...(otherIndices) + 2; 241 if (PlainObjectType::Options&RowMajor) { 242 const Index index = m_dimensions.IndexOfRowMajor(array<Index, NumDims>{{firstIndex, secondIndex, otherIndices...}}); 243 return m_data[index]; 244 } else { 245 const Index index = m_dimensions.IndexOfColMajor(array<Index, NumDims>{{firstIndex, secondIndex, otherIndices...}}); 246 return m_data[index]; 247 } 248 } 249 #else 250 EIGEN_DEVICE_FUNC operator()251 EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1) 252 { 253 if (PlainObjectType::Options&RowMajor) { 254 const Index index = i1 + i0 * m_dimensions[1]; 255 return m_data[index]; 256 } else { 257 const Index index = i0 + i1 * m_dimensions[0]; 258 return m_data[index]; 259 } 260 } 261 EIGEN_DEVICE_FUNC operator()262 EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2) 263 { 264 if (PlainObjectType::Options&RowMajor) { 265 const Index index = i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0); 266 return m_data[index]; 267 } else { 268 const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * i2); 269 return m_data[index]; 270 } 271 } 272 EIGEN_DEVICE_FUNC operator()273 EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2, Index i3) 274 { 275 if (PlainObjectType::Options&RowMajor) { 276 const Index index = i3 + m_dimensions[3] * (i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0)); 277 return m_data[index]; 278 } else { 279 const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * (i2 + m_dimensions[2] * i3)); 280 return m_data[index]; 281 } 282 } 283 EIGEN_DEVICE_FUNC operator()284 EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2, Index i3, Index i4) 285 { 286 if (PlainObjectType::Options&RowMajor) { 287 const Index index = i4 + m_dimensions[4] * (i3 + m_dimensions[3] * (i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0))); 288 return m_data[index]; 289 } else { 290 const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * (i2 + m_dimensions[2] * (i3 + m_dimensions[3] * i4))); 291 return m_data[index]; 292 } 293 } 294 #endif 295 296 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Self& operator=(const Self& other) 297 { 298 typedef TensorAssignOp<Self, const Self> Assign; 299 Assign assign(*this, other); 300 internal::TensorExecutor<const Assign, DefaultDevice>::run(assign, DefaultDevice()); 301 return *this; 302 } 303 304 template<typename OtherDerived> 305 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE 306 Self& operator=(const OtherDerived& other) 307 { 308 typedef TensorAssignOp<Self, const OtherDerived> Assign; 309 Assign assign(*this, other); 310 internal::TensorExecutor<const Assign, DefaultDevice>::run(assign, DefaultDevice()); 311 return *this; 312 } 313 314 private: 315 typename MakePointer_<Scalar>::Type m_data; 316 Dimensions m_dimensions; 317 }; 318 319 } // end namespace Eigen 320 321 #endif // EIGEN_CXX11_TENSOR_TENSOR_MAP_H 322