• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_MAP_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_MAP_H
12 
13 namespace Eigen {
14 
15 /** \class TensorMap
16   * \ingroup CXX11_Tensor_Module
17   *
18   * \brief A tensor expression mapping an existing array of data.
19   *
20   */
21 /// template <class> class MakePointer_ is added to convert the host pointer to the device pointer.
22 /// It is added due to the fact that for our device compiler T* is not allowed.
23 /// If we wanted to use the same Evaluator functions we have to convert that type to our pointer T.
24 /// This is done through our MakePointer_ class. By default the Type in the MakePointer_<T> is T* .
25 /// Therefore, by adding the default value, we managed to convert the type and it does not break any
26 /// existing code as its default value is T*.
27 template<typename PlainObjectType, int Options_, template <class> class MakePointer_> class TensorMap : public TensorBase<TensorMap<PlainObjectType, Options_, MakePointer_> >
28 {
29   public:
30     typedef TensorMap<PlainObjectType, Options_, MakePointer_> Self;
31     typedef typename PlainObjectType::Base Base;
32     typedef typename Eigen::internal::nested<Self>::type Nested;
33     typedef typename internal::traits<PlainObjectType>::StorageKind StorageKind;
34     typedef typename internal::traits<PlainObjectType>::Index Index;
35     typedef typename internal::traits<PlainObjectType>::Scalar Scalar;
36     typedef typename NumTraits<Scalar>::Real RealScalar;
37     typedef typename Base::CoeffReturnType CoeffReturnType;
38 
39   /*    typedef typename internal::conditional<
40                          bool(internal::is_lvalue<PlainObjectType>::value),
41                          Scalar *,
42                          const Scalar *>::type
43                      PointerType;*/
44     typedef typename MakePointer_<Scalar>::Type PointerType;
45     typedef PointerType PointerArgType;
46 
47     static const int Options = Options_;
48 
49     static const Index NumIndices = PlainObjectType::NumIndices;
50     typedef typename PlainObjectType::Dimensions Dimensions;
51 
52     enum {
53       IsAligned = ((int(Options_)&Aligned)==Aligned),
54       Layout = PlainObjectType::Layout,
55       CoordAccess = true,
56       RawAccess = true
57     };
58 
59     EIGEN_DEVICE_FUNC
TensorMap(PointerArgType dataPtr)60     EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr) : m_data(dataPtr), m_dimensions() {
61       // The number of dimensions used to construct a tensor must be equal to the rank of the tensor.
62       EIGEN_STATIC_ASSERT((0 == NumIndices || NumIndices == Dynamic), YOU_MADE_A_PROGRAMMING_MISTAKE)
63     }
64 
65 #if EIGEN_HAS_VARIADIC_TEMPLATES
66     template<typename... IndexTypes> EIGEN_DEVICE_FUNC
TensorMap(PointerArgType dataPtr,Index firstDimension,IndexTypes...otherDimensions)67     EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension, IndexTypes... otherDimensions) : m_data(dataPtr), m_dimensions(firstDimension, otherDimensions...) {
68       // The number of dimensions used to construct a tensor must be equal to the rank of the tensor.
69       EIGEN_STATIC_ASSERT((sizeof...(otherDimensions) + 1 == NumIndices || NumIndices == Dynamic), YOU_MADE_A_PROGRAMMING_MISTAKE)
70     }
71 #else
72     EIGEN_DEVICE_FUNC
TensorMap(PointerArgType dataPtr,Index firstDimension)73     EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension) : m_data(dataPtr), m_dimensions(firstDimension) {
74       // The number of dimensions used to construct a tensor must be equal to the rank of the tensor.
75       EIGEN_STATIC_ASSERT((1 == NumIndices || NumIndices == Dynamic), YOU_MADE_A_PROGRAMMING_MISTAKE)
76     }
77     EIGEN_DEVICE_FUNC
TensorMap(PointerArgType dataPtr,Index dim1,Index dim2)78     EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index dim1, Index dim2) : m_data(dataPtr), m_dimensions(dim1, dim2) {
79       EIGEN_STATIC_ASSERT(2 == NumIndices || NumIndices == Dynamic, YOU_MADE_A_PROGRAMMING_MISTAKE)
80     }
81     EIGEN_DEVICE_FUNC
TensorMap(PointerArgType dataPtr,Index dim1,Index dim2,Index dim3)82     EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index dim1, Index dim2, Index dim3) : m_data(dataPtr), m_dimensions(dim1, dim2, dim3) {
83       EIGEN_STATIC_ASSERT(3 == NumIndices || NumIndices == Dynamic, YOU_MADE_A_PROGRAMMING_MISTAKE)
84     }
85     EIGEN_DEVICE_FUNC
TensorMap(PointerArgType dataPtr,Index dim1,Index dim2,Index dim3,Index dim4)86     EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index dim1, Index dim2, Index dim3, Index dim4) : m_data(dataPtr), m_dimensions(dim1, dim2, dim3, dim4) {
87       EIGEN_STATIC_ASSERT(4 == NumIndices || NumIndices == Dynamic, YOU_MADE_A_PROGRAMMING_MISTAKE)
88     }
89     EIGEN_DEVICE_FUNC
TensorMap(PointerArgType dataPtr,Index dim1,Index dim2,Index dim3,Index dim4,Index dim5)90     EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index dim1, Index dim2, Index dim3, Index dim4, Index dim5) : m_data(dataPtr), m_dimensions(dim1, dim2, dim3, dim4, dim5) {
91       EIGEN_STATIC_ASSERT(5 == NumIndices || NumIndices == Dynamic, YOU_MADE_A_PROGRAMMING_MISTAKE)
92     }
93 #endif
94 
TensorMap(PointerArgType dataPtr,const array<Index,NumIndices> & dimensions)95    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, const array<Index, NumIndices>& dimensions)
96       : m_data(dataPtr), m_dimensions(dimensions)
97     { }
98 
99     template <typename Dimensions>
TensorMap(PointerArgType dataPtr,const Dimensions & dimensions)100     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, const Dimensions& dimensions)
101       : m_data(dataPtr), m_dimensions(dimensions)
102     { }
103 
TensorMap(PlainObjectType & tensor)104     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMap(PlainObjectType& tensor)
105       : m_data(tensor.data()), m_dimensions(tensor.dimensions())
106     { }
107 
108     EIGEN_DEVICE_FUNC
rank()109     EIGEN_STRONG_INLINE Index rank() const { return m_dimensions.rank(); }
110     EIGEN_DEVICE_FUNC
dimension(Index n)111     EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_dimensions[n]; }
112     EIGEN_DEVICE_FUNC
dimensions()113     EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
114     EIGEN_DEVICE_FUNC
size()115     EIGEN_STRONG_INLINE Index size() const { return m_dimensions.TotalSize(); }
116     EIGEN_DEVICE_FUNC
data()117     EIGEN_STRONG_INLINE PointerType data() { return m_data; }
118     EIGEN_DEVICE_FUNC
data()119     EIGEN_STRONG_INLINE const PointerType data() const { return m_data; }
120 
121     EIGEN_DEVICE_FUNC
operator()122     EIGEN_STRONG_INLINE const Scalar& operator()(const array<Index, NumIndices>& indices) const
123     {
124       //      eigen_assert(checkIndexRange(indices));
125       if (PlainObjectType::Options&RowMajor) {
126         const Index index = m_dimensions.IndexOfRowMajor(indices);
127         return m_data[index];
128       } else {
129         const Index index = m_dimensions.IndexOfColMajor(indices);
130         return m_data[index];
131       }
132     }
133 
134     EIGEN_DEVICE_FUNC
operator()135     EIGEN_STRONG_INLINE const Scalar& operator()() const
136     {
137       EIGEN_STATIC_ASSERT(NumIndices == 0, YOU_MADE_A_PROGRAMMING_MISTAKE)
138       return m_data[0];
139     }
140 
141     EIGEN_DEVICE_FUNC
operator()142     EIGEN_STRONG_INLINE const Scalar& operator()(Index index) const
143     {
144       eigen_internal_assert(index >= 0 && index < size());
145       return m_data[index];
146     }
147 
148 #if EIGEN_HAS_VARIADIC_TEMPLATES
149     template<typename... IndexTypes> EIGEN_DEVICE_FUNC
operator()150     EIGEN_STRONG_INLINE const Scalar& operator()(Index firstIndex, Index secondIndex, IndexTypes... otherIndices) const
151     {
152       EIGEN_STATIC_ASSERT(sizeof...(otherIndices) + 2 == NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE)
153       if (PlainObjectType::Options&RowMajor) {
154         const Index index = m_dimensions.IndexOfRowMajor(array<Index, NumIndices>{{firstIndex, secondIndex, otherIndices...}});
155         return m_data[index];
156       } else {
157         const Index index = m_dimensions.IndexOfColMajor(array<Index, NumIndices>{{firstIndex, secondIndex, otherIndices...}});
158         return m_data[index];
159       }
160     }
161 #else
162     EIGEN_DEVICE_FUNC
operator()163     EIGEN_STRONG_INLINE const Scalar& operator()(Index i0, Index i1) const
164     {
165       if (PlainObjectType::Options&RowMajor) {
166         const Index index = i1 + i0 * m_dimensions[1];
167         return m_data[index];
168       } else {
169         const Index index = i0 + i1 * m_dimensions[0];
170         return m_data[index];
171       }
172     }
173     EIGEN_DEVICE_FUNC
operator()174     EIGEN_STRONG_INLINE const Scalar& operator()(Index i0, Index i1, Index i2) const
175     {
176       if (PlainObjectType::Options&RowMajor) {
177          const Index index = i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0);
178          return m_data[index];
179       } else {
180          const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * i2);
181         return m_data[index];
182       }
183     }
184     EIGEN_DEVICE_FUNC
operator()185     EIGEN_STRONG_INLINE const Scalar& operator()(Index i0, Index i1, Index i2, Index i3) const
186     {
187       if (PlainObjectType::Options&RowMajor) {
188         const Index index = i3 + m_dimensions[3] * (i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0));
189         return m_data[index];
190       } else {
191         const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * (i2 + m_dimensions[2] * i3));
192         return m_data[index];
193       }
194     }
195     EIGEN_DEVICE_FUNC
operator()196     EIGEN_STRONG_INLINE const Scalar& operator()(Index i0, Index i1, Index i2, Index i3, Index i4) const
197     {
198       if (PlainObjectType::Options&RowMajor) {
199         const Index index = i4 + m_dimensions[4] * (i3 + m_dimensions[3] * (i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0)));
200         return m_data[index];
201       } else {
202         const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * (i2 + m_dimensions[2] * (i3 + m_dimensions[3] * i4)));
203         return m_data[index];
204       }
205     }
206 #endif
207 
208     EIGEN_DEVICE_FUNC
operator()209     EIGEN_STRONG_INLINE Scalar& operator()(const array<Index, NumIndices>& indices)
210     {
211       //      eigen_assert(checkIndexRange(indices));
212       if (PlainObjectType::Options&RowMajor) {
213         const Index index = m_dimensions.IndexOfRowMajor(indices);
214         return m_data[index];
215       } else {
216         const Index index = m_dimensions.IndexOfColMajor(indices);
217         return m_data[index];
218       }
219     }
220 
221     EIGEN_DEVICE_FUNC
operator()222     EIGEN_STRONG_INLINE Scalar& operator()()
223     {
224       EIGEN_STATIC_ASSERT(NumIndices == 0, YOU_MADE_A_PROGRAMMING_MISTAKE)
225       return m_data[0];
226     }
227 
228     EIGEN_DEVICE_FUNC
operator()229     EIGEN_STRONG_INLINE Scalar& operator()(Index index)
230     {
231       eigen_internal_assert(index >= 0 && index < size());
232       return m_data[index];
233     }
234 
235 #if EIGEN_HAS_VARIADIC_TEMPLATES
236     template<typename... IndexTypes> EIGEN_DEVICE_FUNC
operator()237     EIGEN_STRONG_INLINE Scalar& operator()(Index firstIndex, Index secondIndex, IndexTypes... otherIndices)
238     {
239       static_assert(sizeof...(otherIndices) + 2 == NumIndices || NumIndices == Dynamic, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor.");
240       const std::size_t NumDims = sizeof...(otherIndices) + 2;
241       if (PlainObjectType::Options&RowMajor) {
242         const Index index = m_dimensions.IndexOfRowMajor(array<Index, NumDims>{{firstIndex, secondIndex, otherIndices...}});
243         return m_data[index];
244       } else {
245         const Index index = m_dimensions.IndexOfColMajor(array<Index, NumDims>{{firstIndex, secondIndex, otherIndices...}});
246         return m_data[index];
247       }
248     }
249 #else
250     EIGEN_DEVICE_FUNC
operator()251     EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1)
252     {
253        if (PlainObjectType::Options&RowMajor) {
254          const Index index = i1 + i0 * m_dimensions[1];
255         return m_data[index];
256       } else {
257         const Index index = i0 + i1 * m_dimensions[0];
258         return m_data[index];
259       }
260     }
261     EIGEN_DEVICE_FUNC
operator()262     EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2)
263     {
264        if (PlainObjectType::Options&RowMajor) {
265          const Index index = i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0);
266         return m_data[index];
267       } else {
268          const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * i2);
269         return m_data[index];
270       }
271     }
272     EIGEN_DEVICE_FUNC
operator()273     EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2, Index i3)
274     {
275       if (PlainObjectType::Options&RowMajor) {
276         const Index index = i3 + m_dimensions[3] * (i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0));
277         return m_data[index];
278       } else {
279         const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * (i2 + m_dimensions[2] * i3));
280         return m_data[index];
281       }
282     }
283     EIGEN_DEVICE_FUNC
operator()284     EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2, Index i3, Index i4)
285     {
286       if (PlainObjectType::Options&RowMajor) {
287         const Index index = i4 + m_dimensions[4] * (i3 + m_dimensions[3] * (i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0)));
288         return m_data[index];
289       } else {
290         const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * (i2 + m_dimensions[2] * (i3 + m_dimensions[3] * i4)));
291         return m_data[index];
292       }
293     }
294 #endif
295 
296     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Self& operator=(const Self& other)
297     {
298       typedef TensorAssignOp<Self, const Self> Assign;
299       Assign assign(*this, other);
300       internal::TensorExecutor<const Assign, DefaultDevice>::run(assign, DefaultDevice());
301       return *this;
302     }
303 
304     template<typename OtherDerived>
305     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
306     Self& operator=(const OtherDerived& other)
307     {
308       typedef TensorAssignOp<Self, const OtherDerived> Assign;
309       Assign assign(*this, other);
310       internal::TensorExecutor<const Assign, DefaultDevice>::run(assign, DefaultDevice());
311       return *this;
312     }
313 
314   private:
315     typename MakePointer_<Scalar>::Type m_data;
316     Dimensions m_dimensions;
317 };
318 
319 } // end namespace Eigen
320 
321 #endif // EIGEN_CXX11_TENSOR_TENSOR_MAP_H
322