• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2013 Christian Seiler <christian@iwakd.de>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSORSYMMETRY_STATICSYMMETRY_H
11 #define EIGEN_CXX11_TENSORSYMMETRY_STATICSYMMETRY_H
12 
13 namespace Eigen {
14 
15 namespace internal {
16 
17 template<typename list> struct tensor_static_symgroup_permutate;
18 
19 template<int... nn>
20 struct tensor_static_symgroup_permutate<numeric_list<int, nn...>>
21 {
22   constexpr static std::size_t N = sizeof...(nn);
23 
24   template<typename T>
25   constexpr static inline std::array<T, N> run(const std::array<T, N>& indices)
26   {
27     return {{indices[nn]...}};
28   }
29 };
30 
// A single element of a static symmetry group: a compile-time index list
// type paired with an integer flag value.
template<typename indices_, int flags_>
struct tensor_static_symgroup_element
{
  // The index list type this element was instantiated with.
  using indices = indices_;
  // The flag value this element was instantiated with.
  static constexpr int flags = flags_;
};
37 
38 template<typename Gen, int N>
39 struct tensor_static_symgroup_element_ctor
40 {
41   typedef tensor_static_symgroup_element<
42     typename gen_numeric_list_swapped_pair<int, N, Gen::One, Gen::Two>::type,
43     Gen::Flags
44   > type;
45 };
46 
47 template<int N>
48 struct tensor_static_symgroup_identity_ctor
49 {
50   typedef tensor_static_symgroup_element<
51     typename gen_numeric_list<int, N>::type,
52     0
53   > type;
54 };
55 
56 template<typename iib>
57 struct tensor_static_symgroup_multiply_helper
58 {
59   template<int... iia>
60   constexpr static inline numeric_list<int, get<iia, iib>::value...> helper(numeric_list<int, iia...>) {
61     return numeric_list<int, get<iia, iib>::value...>();
62   }
63 };
64 
65 template<typename A, typename B>
66 struct tensor_static_symgroup_multiply
67 {
68   private:
69     typedef typename A::indices iia;
70     typedef typename B::indices iib;
71     constexpr static int ffa = A::flags;
72     constexpr static int ffb = B::flags;
73 
74   public:
75     static_assert(iia::count == iib::count, "Cannot multiply symmetry elements with different number of indices.");
76 
77     typedef tensor_static_symgroup_element<
78       decltype(tensor_static_symgroup_multiply_helper<iib>::helper(iia())),
79       ffa ^ ffb
80     > type;
81 };
82 
83 template<typename A, typename B>
84 struct tensor_static_symgroup_equality
85 {
86     typedef typename A::indices iia;
87     typedef typename B::indices iib;
88     constexpr static int ffa = A::flags;
89     constexpr static int ffb = B::flags;
90     static_assert(iia::count == iib::count, "Cannot compare symmetry elements with different number of indices.");
91 
92     constexpr static bool value = is_same<iia, iib>::value;
93 
94   private:
95     /* this should be zero if they are identical, or else the tensor
96      * will be forced to be pure real, pure imaginary or even pure zero
97      */
98     constexpr static int flags_cmp_ = ffa ^ ffb;
99 
100     /* either they are not equal, then we don't care whether the flags
101      * match, or they are equal, and then we have to check
102      */
103     constexpr static bool is_zero      = value && flags_cmp_ == NegationFlag;
104     constexpr static bool is_real      = value && flags_cmp_ == ConjugationFlag;
105     constexpr static bool is_imag      = value && flags_cmp_ == (NegationFlag | ConjugationFlag);
106 
107   public:
108     constexpr static int global_flags =
109       (is_real ? GlobalRealFlag : 0) |
110       (is_imag ? GlobalImagFlag : 0) |
111       (is_zero ? GlobalZeroFlag : 0);
112 };
113 
114 template<std::size_t NumIndices, typename... Gen>
115 struct tensor_static_symgroup
116 {
117   typedef StaticSGroup<Gen...> type;
118   constexpr static std::size_t size = type::static_size;
119 };
120 
121 template<typename Index, std::size_t N, int... ii, int... jj>
122 constexpr static inline std::array<Index, N> tensor_static_symgroup_index_permute(std::array<Index, N> idx, internal::numeric_list<int, ii...>, internal::numeric_list<int, jj...>)
123 {
124   return {{ idx[ii]..., idx[jj]... }};
125 }
126 
127 template<typename Index, int... ii>
128 static inline std::vector<Index> tensor_static_symgroup_index_permute(std::vector<Index> idx, internal::numeric_list<int, ii...>)
129 {
130   std::vector<Index> result{{ idx[ii]... }};
131   std::size_t target_size = idx.size();
132   for (std::size_t i = result.size(); i < target_size; i++)
133     result.push_back(idx[i]);
134   return result;
135 }
136 
137 template<typename T> struct tensor_static_symgroup_do_apply;
138 
139 template<typename first, typename... next>
140 struct tensor_static_symgroup_do_apply<internal::type_list<first, next...>>
141 {
142   template<typename Op, typename RV, std::size_t SGNumIndices, typename Index, std::size_t NumIndices, typename... Args>
143   static inline RV run(const std::array<Index, NumIndices>& idx, RV initial, Args&&... args)
144   {
145     static_assert(NumIndices >= SGNumIndices, "Can only apply symmetry group to objects that have at least the required amount of indices.");
146     typedef typename internal::gen_numeric_list<int, NumIndices - SGNumIndices, SGNumIndices>::type remaining_indices;
147     initial = Op::run(tensor_static_symgroup_index_permute(idx, typename first::indices(), remaining_indices()), first::flags, initial, std::forward<Args>(args)...);
148     return tensor_static_symgroup_do_apply<internal::type_list<next...>>::template run<Op, RV, SGNumIndices>(idx, initial, args...);
149   }
150 
151   template<typename Op, typename RV, std::size_t SGNumIndices, typename Index, typename... Args>
152   static inline RV run(const std::vector<Index>& idx, RV initial, Args&&... args)
153   {
154     eigen_assert(idx.size() >= SGNumIndices && "Can only apply symmetry group to objects that have at least the required amount of indices.");
155     initial = Op::run(tensor_static_symgroup_index_permute(idx, typename first::indices()), first::flags, initial, std::forward<Args>(args)...);
156     return tensor_static_symgroup_do_apply<internal::type_list<next...>>::template run<Op, RV, SGNumIndices>(idx, initial, args...);
157   }
158 };
159 
160 template<EIGEN_TPL_PP_SPEC_HACK_DEF(typename, empty)>
161 struct tensor_static_symgroup_do_apply<internal::type_list<EIGEN_TPL_PP_SPEC_HACK_USE(empty)>>
162 {
163   template<typename Op, typename RV, std::size_t SGNumIndices, typename Index, std::size_t NumIndices, typename... Args>
164   static inline RV run(const std::array<Index, NumIndices>&, RV initial, Args&&...)
165   {
166     // do nothing
167     return initial;
168   }
169 
170   template<typename Op, typename RV, std::size_t SGNumIndices, typename Index, typename... Args>
171   static inline RV run(const std::vector<Index>&, RV initial, Args&&...)
172   {
173     // do nothing
174     return initial;
175   }
176 };
177 
178 } // end namespace internal
179 
180 template<typename... Gen>
181 class StaticSGroup
182 {
183     constexpr static std::size_t NumIndices = internal::tensor_symmetry_num_indices<Gen...>::value;
184     typedef internal::group_theory::enumerate_group_elements<
185       internal::tensor_static_symgroup_multiply,
186       internal::tensor_static_symgroup_equality,
187       typename internal::tensor_static_symgroup_identity_ctor<NumIndices>::type,
188       internal::type_list<typename internal::tensor_static_symgroup_element_ctor<Gen, NumIndices>::type...>
189     > group_elements;
190     typedef typename group_elements::type ge;
191   public:
192     constexpr inline StaticSGroup() {}
193     constexpr inline StaticSGroup(const StaticSGroup<Gen...>&) {}
194     constexpr inline StaticSGroup(StaticSGroup<Gen...>&&) {}
195 
196     template<typename Op, typename RV, typename Index, std::size_t N, typename... Args>
197     static inline RV apply(const std::array<Index, N>& idx, RV initial, Args&&... args)
198     {
199       return internal::tensor_static_symgroup_do_apply<ge>::template run<Op, RV, NumIndices>(idx, initial, args...);
200     }
201 
202     template<typename Op, typename RV, typename Index, typename... Args>
203     static inline RV apply(const std::vector<Index>& idx, RV initial, Args&&... args)
204     {
205       eigen_assert(idx.size() == NumIndices);
206       return internal::tensor_static_symgroup_do_apply<ge>::template run<Op, RV, NumIndices>(idx, initial, args...);
207     }
208 
209     constexpr static std::size_t static_size = ge::count;
210 
211     constexpr static inline std::size_t size() {
212       return ge::count;
213     }
214     constexpr static inline int globalFlags() { return group_elements::global_flags; }
215 
216     template<typename Tensor_, typename... IndexTypes>
217     inline internal::tensor_symmetry_value_setter<Tensor_, StaticSGroup<Gen...>> operator()(Tensor_& tensor, typename Tensor_::Index firstIndex, IndexTypes... otherIndices) const
218     {
219       static_assert(sizeof...(otherIndices) + 1 == Tensor_::NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor.");
220       return operator()(tensor, std::array<typename Tensor_::Index, Tensor_::NumIndices>{{firstIndex, otherIndices...}});
221     }
222 
223     template<typename Tensor_>
224     inline internal::tensor_symmetry_value_setter<Tensor_, StaticSGroup<Gen...>> operator()(Tensor_& tensor, std::array<typename Tensor_::Index, Tensor_::NumIndices> const& indices) const
225     {
226       return internal::tensor_symmetry_value_setter<Tensor_, StaticSGroup<Gen...>>(tensor, *this, indices);
227     }
228 };
229 
230 } // end namespace Eigen
231 
232 #endif // EIGEN_CXX11_TENSORSYMMETRY_STATICSYMMETRY_H
233 
234 /*
235  * kate: space-indent on; indent-width 2; mixedindent off; indent-mode cstyle;
236  */
237