• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_DIMENSIONS_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_DIMENSIONS_H
12 
13 
14 namespace Eigen {
15 
16 /** \internal
17   *
18   * \class TensorDimensions
19   * \ingroup CXX11_Tensor_Module
20   *
21   * \brief Set of classes used to encode and store the dimensions of a Tensor.
22   *
23   * The Sizes class encodes as part of the type the number of dimensions and the
24   * sizes corresponding to each dimension. It uses no storage space since it is
25   * entirely known at compile time.
26   * The DSizes class is its dynamic sibling: the number of dimensions is known
27   * at compile time but the sizes are set during execution.
28   *
29   * \sa Tensor
30   */
31 
32 // Boilerplate code
33 namespace internal {
34 
35 template<std::size_t n, typename Dimension> struct dget {
36   static const std::size_t value = get<n, Dimension>::value;
37 };
38 
39 
40 template<typename Index, std::size_t NumIndices, std::size_t n, bool RowMajor>
41 struct fixed_size_tensor_index_linearization_helper
42 {
43   template <typename Dimensions> EIGEN_DEVICE_FUNC
runfixed_size_tensor_index_linearization_helper44   static inline Index run(array<Index, NumIndices> const& indices,
45                           const Dimensions& dimensions)
46   {
47     return array_get<RowMajor ? n - 1 : (NumIndices - n)>(indices) +
48         dget<RowMajor ? n - 1 : (NumIndices - n), Dimensions>::value *
49         fixed_size_tensor_index_linearization_helper<Index, NumIndices, n - 1, RowMajor>::run(indices, dimensions);
50   }
51 };
52 
53 template<typename Index, std::size_t NumIndices, bool RowMajor>
54 struct fixed_size_tensor_index_linearization_helper<Index, NumIndices, 0, RowMajor>
55 {
56   template <typename Dimensions> EIGEN_DEVICE_FUNC
57   static inline Index run(array<Index, NumIndices> const&, const Dimensions&)
58   {
59     return 0;
60   }
61 };
62 
63 template<typename Index, std::size_t n>
64 struct fixed_size_tensor_index_extraction_helper
65 {
66   template <typename Dimensions> EIGEN_DEVICE_FUNC
67   static inline Index run(const Index index,
68                           const Dimensions& dimensions)
69   {
70     const Index mult = (index == n-1) ? 1 : 0;
71     return array_get<n-1>(dimensions) * mult +
72         fixed_size_tensor_index_extraction_helper<Index, n - 1>::run(index, dimensions);
73   }
74 };
75 
76 template<typename Index>
77 struct fixed_size_tensor_index_extraction_helper<Index, 0>
78 {
79   template <typename Dimensions> EIGEN_DEVICE_FUNC
80   static inline Index run(const Index,
81                           const Dimensions&)
82   {
83     return 0;
84   }
85   };
86 
87 }  // end namespace internal
88 
89 
90 // Fixed size
91 #ifndef EIGEN_EMULATE_CXX11_META_H
92 template <typename std::ptrdiff_t... Indices>
93 struct Sizes : internal::numeric_list<std::ptrdiff_t, Indices...> {
94   typedef internal::numeric_list<std::ptrdiff_t, Indices...> Base;
95   static const std::ptrdiff_t total_size = internal::arg_prod(Indices...);
96 
97   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t rank() const {
98     return Base::count;
99   }
100 
101   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t TotalSize() {
102     return internal::arg_prod(Indices...);
103   }
104 
105   EIGEN_DEVICE_FUNC Sizes() { }
106   template <typename DenseIndex>
107   explicit EIGEN_DEVICE_FUNC Sizes(const array<DenseIndex, Base::count>& /*indices*/) {
108     // todo: add assertion
109   }
110 #if EIGEN_HAS_VARIADIC_TEMPLATES
111   template <typename... DenseIndex> EIGEN_DEVICE_FUNC Sizes(DenseIndex...) { }
112   explicit EIGEN_DEVICE_FUNC Sizes(std::initializer_list<std::ptrdiff_t> /*l*/) {
113     // todo: add assertion
114   }
115 #endif
116 
117   template <typename T> Sizes& operator = (const T& /*other*/) {
118     // add assertion failure if the size of other is different
119     return *this;
120   }
121 
122   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t operator[] (const std::size_t index) const {
123     return internal::fixed_size_tensor_index_extraction_helper<std::ptrdiff_t, Base::count>::run(index, *this);
124   }
125 
126   template <typename DenseIndex> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
127   size_t IndexOfColMajor(const array<DenseIndex, Base::count>& indices) const {
128     return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count, false>::run(indices, *static_cast<const Base*>(this));
129   }
130   template <typename DenseIndex> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
131   size_t IndexOfRowMajor(const array<DenseIndex, Base::count>& indices) const {
132     return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count, true>::run(indices, *static_cast<const Base*>(this));
133   }
134 };
135 
136 namespace internal {
137 template <typename std::ptrdiff_t... Indices>
138 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t array_prod(const Sizes<Indices...>&) {
139   return Sizes<Indices...>::total_size;
140 }
141 }
142 
143 #else
144 
145 template <std::size_t n>
146 struct non_zero_size {
147   typedef internal::type2val<std::size_t, n> type;
148 };
149 template <>
150 struct non_zero_size<0> {
151   typedef internal::null_type type;
152 };
153 
154 template <std::size_t V1=0, std::size_t V2=0, std::size_t V3=0, std::size_t V4=0, std::size_t V5=0> struct Sizes {
155   typedef typename internal::make_type_list<typename non_zero_size<V1>::type, typename non_zero_size<V2>::type, typename non_zero_size<V3>::type, typename non_zero_size<V4>::type, typename non_zero_size<V5>::type >::type Base;
156   static const size_t count = Base::count;
157   static const std::size_t total_size = internal::arg_prod<Base>::value;
158 
159   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE size_t rank() const {
160     return count;
161   }
162 
163   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE size_t TotalSize() {
164     return internal::arg_prod<Base>::value;
165   }
166 
167   Sizes() { }
168   template <typename DenseIndex>
169   explicit Sizes(const array<DenseIndex, Base::count>& /*indices*/) {
170     // todo: add assertion
171   }
172   template <typename T> Sizes& operator = (const T& /*other*/) {
173     // add assertion failure if the size of other is different
174     return *this;
175   }
176 
177 #if EIGEN_HAS_VARIADIC_TEMPLATES
178   template <typename... DenseIndex> Sizes(DenseIndex... /*indices*/) { }
179   explicit Sizes(std::initializer_list<std::size_t>) {
180     // todo: add assertion
181   }
182 #else
183   EIGEN_DEVICE_FUNC explicit Sizes(const DenseIndex) {
184   }
185   EIGEN_DEVICE_FUNC Sizes(const DenseIndex, const DenseIndex) {
186   }
187   EIGEN_DEVICE_FUNC Sizes(const DenseIndex, const DenseIndex, const DenseIndex) {
188   }
189   EIGEN_DEVICE_FUNC Sizes(const DenseIndex, const DenseIndex, const DenseIndex, const DenseIndex) {
190   }
191   EIGEN_DEVICE_FUNC Sizes(const DenseIndex, const DenseIndex, const DenseIndex, const DenseIndex, const DenseIndex) {
192   }
193 #endif
194 
195   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DenseIndex operator[] (const int index) const {
196     switch (index) {
197       case 0:
198         return internal::get<0, Base>::value;
199       case 1:
200         return internal::get<1, Base>::value;
201       case 2:
202         return internal::get<2, Base>::value;
203       case 3:
204         return internal::get<3, Base>::value;
205       case 4:
206         return internal::get<4, Base>::value;
207       default:
208         eigen_assert(false && "index overflow");
209         return static_cast<DenseIndex>(-1);
210     }
211   }
212 
213   template <typename DenseIndex> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
214   size_t IndexOfColMajor(const array<DenseIndex, Base::count>& indices) const {
215     return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count, false>::run(indices, *reinterpret_cast<const Base*>(this));
216   }
217   template <typename DenseIndex> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
218   size_t IndexOfRowMajor(const array<DenseIndex, Base::count>& indices) const {
219     return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count, true>::run(indices, *reinterpret_cast<const Base*>(this));
220   }
221 };
222 
223 namespace internal {
224 template <std::size_t V1, std::size_t V2, std::size_t V3, std::size_t V4, std::size_t V5>
225 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::size_t array_prod(const Sizes<V1, V2, V3, V4, V5>&) {
226   return Sizes<V1, V2, V3, V4, V5>::total_size;
227 }
228 }
229 
230 #endif
231 
232 // Boilerplate
233 namespace internal {
234 template<typename Index, std::size_t NumIndices, std::size_t n, bool RowMajor>
235 struct tensor_index_linearization_helper
236 {
237   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
238   Index run(array<Index, NumIndices> const& indices, array<Index, NumIndices> const& dimensions)
239   {
240     return array_get<RowMajor ? n : (NumIndices - n - 1)>(indices) +
241       array_get<RowMajor ? n : (NumIndices - n - 1)>(dimensions) *
242         tensor_index_linearization_helper<Index, NumIndices, n - 1, RowMajor>::run(indices, dimensions);
243   }
244 };
245 
246 template<typename Index, std::size_t NumIndices, bool RowMajor>
247 struct tensor_index_linearization_helper<Index, NumIndices, 0, RowMajor>
248 {
249   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
250   Index run(array<Index, NumIndices> const& indices, array<Index, NumIndices> const&)
251   {
252     return array_get<RowMajor ? 0 : NumIndices - 1>(indices);
253   }
254 };
255 }  // end namespace internal
256 
257 
258 
259 // Dynamic size
260 template <typename DenseIndex, int NumDims>
261 struct DSizes : array<DenseIndex, NumDims> {
262   typedef array<DenseIndex, NumDims> Base;
263   static const int count = NumDims;
264 
265   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE size_t rank() const {
266     return NumDims;
267   }
268 
269   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DenseIndex TotalSize() const {
270     return (NumDims == 0) ? 1 : internal::array_prod(*static_cast<const Base*>(this));
271   }
272 
273   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DSizes() {
274     for (int i = 0 ; i < NumDims; ++i) {
275       (*this)[i] = 0;
276     }
277   }
278   EIGEN_DEVICE_FUNC explicit DSizes(const array<DenseIndex, NumDims>& a) : Base(a) { }
279 
280   EIGEN_DEVICE_FUNC explicit DSizes(const DenseIndex i0) {
281     eigen_assert(NumDims == 1);
282     (*this)[0] = i0;
283   }
284 
285 #if EIGEN_HAS_VARIADIC_TEMPLATES
286   template<typename... IndexTypes> EIGEN_DEVICE_FUNC
287   EIGEN_STRONG_INLINE explicit DSizes(DenseIndex firstDimension, DenseIndex secondDimension, IndexTypes... otherDimensions) : Base({{firstDimension, secondDimension, otherDimensions...}}) {
288     EIGEN_STATIC_ASSERT(sizeof...(otherDimensions) + 2 == NumDims, YOU_MADE_A_PROGRAMMING_MISTAKE)
289   }
290 #else
291   EIGEN_DEVICE_FUNC DSizes(const DenseIndex i0, const DenseIndex i1) {
292     eigen_assert(NumDims == 2);
293     (*this)[0] = i0;
294     (*this)[1] = i1;
295   }
296   EIGEN_DEVICE_FUNC DSizes(const DenseIndex i0, const DenseIndex i1, const DenseIndex i2) {
297     eigen_assert(NumDims == 3);
298     (*this)[0] = i0;
299     (*this)[1] = i1;
300     (*this)[2] = i2;
301   }
302   EIGEN_DEVICE_FUNC DSizes(const DenseIndex i0, const DenseIndex i1, const DenseIndex i2, const DenseIndex i3) {
303     eigen_assert(NumDims == 4);
304     (*this)[0] = i0;
305     (*this)[1] = i1;
306     (*this)[2] = i2;
307     (*this)[3] = i3;
308   }
309   EIGEN_DEVICE_FUNC DSizes(const DenseIndex i0, const DenseIndex i1, const DenseIndex i2, const DenseIndex i3, const DenseIndex i4) {
310     eigen_assert(NumDims == 5);
311     (*this)[0] = i0;
312     (*this)[1] = i1;
313     (*this)[2] = i2;
314     (*this)[3] = i3;
315     (*this)[4] = i4;
316   }
317 #endif
318 
319   EIGEN_DEVICE_FUNC DSizes& operator = (const array<DenseIndex, NumDims>& other) {
320     *static_cast<Base*>(this) = other;
321     return *this;
322   }
323 
324   // A constexpr would be so much better here
325   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DenseIndex IndexOfColMajor(const array<DenseIndex, NumDims>& indices) const {
326     return internal::tensor_index_linearization_helper<DenseIndex, NumDims, NumDims - 1, false>::run(indices, *static_cast<const Base*>(this));
327   }
328   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DenseIndex IndexOfRowMajor(const array<DenseIndex, NumDims>& indices) const {
329     return internal::tensor_index_linearization_helper<DenseIndex, NumDims, NumDims - 1, true>::run(indices, *static_cast<const Base*>(this));
330   }
331 };
332 
333 
334 
335 
336 // Boilerplate
337 namespace internal {
338 template<typename Index, std::size_t NumIndices, std::size_t n, bool RowMajor>
339 struct tensor_vsize_index_linearization_helper
340 {
341   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
342   Index run(array<Index, NumIndices> const& indices, std::vector<DenseIndex> const& dimensions)
343   {
344     return array_get<RowMajor ? n : (NumIndices - n - 1)>(indices) +
345       array_get<RowMajor ? n : (NumIndices - n - 1)>(dimensions) *
346         tensor_vsize_index_linearization_helper<Index, NumIndices, n - 1, RowMajor>::run(indices, dimensions);
347   }
348 };
349 
350 template<typename Index, std::size_t NumIndices, bool RowMajor>
351 struct tensor_vsize_index_linearization_helper<Index, NumIndices, 0, RowMajor>
352 {
353   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
354   Index run(array<Index, NumIndices> const& indices, std::vector<DenseIndex> const&)
355   {
356     return array_get<RowMajor ? 0 : NumIndices - 1>(indices);
357   }
358 };
359 }  // end namespace internal
360 
361 
362 namespace internal {
363 
364 template <typename DenseIndex, int NumDims> struct array_size<const DSizes<DenseIndex, NumDims> > {
365   static const size_t value = NumDims;
366 };
367 template <typename DenseIndex, int NumDims> struct array_size<DSizes<DenseIndex, NumDims> > {
368   static const size_t value = NumDims;
369 };
370 #ifndef EIGEN_EMULATE_CXX11_META_H
371 template <typename std::ptrdiff_t... Indices> struct array_size<const Sizes<Indices...> > {
372 static const std::ptrdiff_t value = Sizes<Indices...>::count;
373 };
374 template <typename std::ptrdiff_t... Indices> struct array_size<Sizes<Indices...> > {
375 static const std::ptrdiff_t value = Sizes<Indices...>::count;
376 };
377 template <std::ptrdiff_t n, typename std::ptrdiff_t... Indices> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t array_get(const Sizes<Indices...>&) {
378   return get<n, internal::numeric_list<std::size_t, Indices...> >::value;
379 }
380 template <std::ptrdiff_t n> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t array_get(const Sizes<>&) {
381   eigen_assert(false && "should never be called");
382   return -1;
383 }
384 #else
385 template <std::size_t V1, std::size_t V2, std::size_t V3, std::size_t V4, std::size_t V5> struct array_size<const Sizes<V1,V2,V3,V4,V5> > {
386   static const size_t value = Sizes<V1,V2,V3,V4,V5>::count;
387 };
388 template <std::size_t V1, std::size_t V2, std::size_t V3, std::size_t V4, std::size_t V5> struct array_size<Sizes<V1,V2,V3,V4,V5> > {
389   static const size_t value = Sizes<V1,V2,V3,V4,V5>::count;
390 };
391 template <std::size_t n, std::size_t V1, std::size_t V2, std::size_t V3, std::size_t V4, std::size_t V5> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::size_t array_get(const Sizes<V1,V2,V3,V4,V5>&) {
392   return get<n, typename Sizes<V1,V2,V3,V4,V5>::Base>::value;
393 }
394 
395 #endif
396 
397 
398 template <typename Dims1, typename Dims2, size_t n, size_t m>
399 struct sizes_match_below_dim {
400   static EIGEN_DEVICE_FUNC  inline bool run(Dims1&, Dims2&) {
401     return false;
402   }
403 };
404 template <typename Dims1, typename Dims2, size_t n>
405 struct sizes_match_below_dim<Dims1, Dims2, n, n> {
406   static EIGEN_DEVICE_FUNC  inline bool run(Dims1& dims1, Dims2& dims2) {
407     return (array_get<n-1>(dims1) == array_get<n-1>(dims2)) &
408         sizes_match_below_dim<Dims1, Dims2, n-1, n-1>::run(dims1, dims2);
409   }
410 };
411 template <typename Dims1, typename Dims2>
412 struct sizes_match_below_dim<Dims1, Dims2, 0, 0> {
413   static EIGEN_DEVICE_FUNC  inline bool run(Dims1&, Dims2&) {
414     return true;
415   }
416 };
417 
418 } // end namespace internal
419 
420 
421 template <typename Dims1, typename Dims2>
422 EIGEN_DEVICE_FUNC bool dimensions_match(Dims1& dims1, Dims2& dims2) {
423   return internal::sizes_match_below_dim<Dims1, Dims2, internal::array_size<Dims1>::value, internal::array_size<Dims2>::value>::run(dims1, dims2);
424 }
425 
426 } // end namespace Eigen
427 
428 #endif // EIGEN_CXX11_TENSOR_TENSOR_DIMENSIONS_H
429