• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #include "main.h"
11 
12 #include <Eigen/CXX11/Tensor>
13 
14 using Eigen::Tensor;
15 using Eigen::array;
16 
17 template <int DataLayout>
test_simple_shuffling()18 static void test_simple_shuffling()
19 {
20   Tensor<float, 4, DataLayout> tensor(2,3,5,7);
21   tensor.setRandom();
22   array<ptrdiff_t, 4> shuffles;
23   shuffles[0] = 0;
24   shuffles[1] = 1;
25   shuffles[2] = 2;
26   shuffles[3] = 3;
27 
28   Tensor<float, 4, DataLayout> no_shuffle;
29   no_shuffle = tensor.shuffle(shuffles);
30 
31   VERIFY_IS_EQUAL(no_shuffle.dimension(0), 2);
32   VERIFY_IS_EQUAL(no_shuffle.dimension(1), 3);
33   VERIFY_IS_EQUAL(no_shuffle.dimension(2), 5);
34   VERIFY_IS_EQUAL(no_shuffle.dimension(3), 7);
35 
36   for (int i = 0; i < 2; ++i) {
37     for (int j = 0; j < 3; ++j) {
38       for (int k = 0; k < 5; ++k) {
39         for (int l = 0; l < 7; ++l) {
40           VERIFY_IS_EQUAL(tensor(i,j,k,l), no_shuffle(i,j,k,l));
41         }
42       }
43     }
44   }
45 
46   shuffles[0] = 2;
47   shuffles[1] = 3;
48   shuffles[2] = 1;
49   shuffles[3] = 0;
50   Tensor<float, 4, DataLayout> shuffle;
51   shuffle = tensor.shuffle(shuffles);
52 
53   VERIFY_IS_EQUAL(shuffle.dimension(0), 5);
54   VERIFY_IS_EQUAL(shuffle.dimension(1), 7);
55   VERIFY_IS_EQUAL(shuffle.dimension(2), 3);
56   VERIFY_IS_EQUAL(shuffle.dimension(3), 2);
57 
58   for (int i = 0; i < 2; ++i) {
59     for (int j = 0; j < 3; ++j) {
60       for (int k = 0; k < 5; ++k) {
61         for (int l = 0; l < 7; ++l) {
62           VERIFY_IS_EQUAL(tensor(i,j,k,l), shuffle(k,l,j,i));
63         }
64       }
65     }
66   }
67 }
68 
69 
70 template <int DataLayout>
test_expr_shuffling()71 static void test_expr_shuffling()
72 {
73   Tensor<float, 4, DataLayout> tensor(2,3,5,7);
74   tensor.setRandom();
75 
76   array<ptrdiff_t, 4> shuffles;
77   shuffles[0] = 2;
78   shuffles[1] = 3;
79   shuffles[2] = 1;
80   shuffles[3] = 0;
81   Tensor<float, 4, DataLayout> expected;
82   expected = tensor.shuffle(shuffles);
83 
84   Tensor<float, 4, DataLayout> result(5,7,3,2);
85 
86   array<int, 4> src_slice_dim{{2,3,1,7}};
87   array<int, 4> src_slice_start{{0,0,0,0}};
88   array<int, 4> dst_slice_dim{{1,7,3,2}};
89   array<int, 4> dst_slice_start{{0,0,0,0}};
90 
91   for (int i = 0; i < 5; ++i) {
92     result.slice(dst_slice_start, dst_slice_dim) =
93         tensor.slice(src_slice_start, src_slice_dim).shuffle(shuffles);
94     src_slice_start[2] += 1;
95     dst_slice_start[0] += 1;
96   }
97 
98   VERIFY_IS_EQUAL(result.dimension(0), 5);
99   VERIFY_IS_EQUAL(result.dimension(1), 7);
100   VERIFY_IS_EQUAL(result.dimension(2), 3);
101   VERIFY_IS_EQUAL(result.dimension(3), 2);
102 
103   for (int i = 0; i < expected.dimension(0); ++i) {
104     for (int j = 0; j < expected.dimension(1); ++j) {
105       for (int k = 0; k < expected.dimension(2); ++k) {
106         for (int l = 0; l < expected.dimension(3); ++l) {
107           VERIFY_IS_EQUAL(result(i,j,k,l), expected(i,j,k,l));
108         }
109       }
110     }
111   }
112 
113   dst_slice_start[0] = 0;
114   result.setRandom();
115   for (int i = 0; i < 5; ++i) {
116     result.slice(dst_slice_start, dst_slice_dim) =
117         tensor.shuffle(shuffles).slice(dst_slice_start, dst_slice_dim);
118     dst_slice_start[0] += 1;
119   }
120 
121   for (int i = 0; i < expected.dimension(0); ++i) {
122     for (int j = 0; j < expected.dimension(1); ++j) {
123       for (int k = 0; k < expected.dimension(2); ++k) {
124         for (int l = 0; l < expected.dimension(3); ++l) {
125           VERIFY_IS_EQUAL(result(i,j,k,l), expected(i,j,k,l));
126         }
127       }
128     }
129   }
130 }
131 
132 
133 template <int DataLayout>
test_shuffling_as_value()134 static void test_shuffling_as_value()
135 {
136   Tensor<float, 4, DataLayout> tensor(2,3,5,7);
137   tensor.setRandom();
138   array<ptrdiff_t, 4> shuffles;
139   shuffles[2] = 0;
140   shuffles[3] = 1;
141   shuffles[1] = 2;
142   shuffles[0] = 3;
143   Tensor<float, 4, DataLayout> shuffle(5,7,3,2);
144   shuffle.shuffle(shuffles) = tensor;
145 
146   VERIFY_IS_EQUAL(shuffle.dimension(0), 5);
147   VERIFY_IS_EQUAL(shuffle.dimension(1), 7);
148   VERIFY_IS_EQUAL(shuffle.dimension(2), 3);
149   VERIFY_IS_EQUAL(shuffle.dimension(3), 2);
150 
151   for (int i = 0; i < 2; ++i) {
152     for (int j = 0; j < 3; ++j) {
153       for (int k = 0; k < 5; ++k) {
154         for (int l = 0; l < 7; ++l) {
155           VERIFY_IS_EQUAL(tensor(i,j,k,l), shuffle(k,l,j,i));
156         }
157       }
158     }
159   }
160 
161   array<ptrdiff_t, 4> no_shuffle;
162   no_shuffle[0] = 0;
163   no_shuffle[1] = 1;
164   no_shuffle[2] = 2;
165   no_shuffle[3] = 3;
166   Tensor<float, 4, DataLayout> shuffle2(5,7,3,2);
167   shuffle2.shuffle(shuffles) = tensor.shuffle(no_shuffle);
168   for (int i = 0; i < 5; ++i) {
169     for (int j = 0; j < 7; ++j) {
170       for (int k = 0; k < 3; ++k) {
171         for (int l = 0; l < 2; ++l) {
172           VERIFY_IS_EQUAL(shuffle2(i,j,k,l), shuffle(i,j,k,l));
173         }
174       }
175     }
176   }
177 }
178 
179 
180 template <int DataLayout>
test_shuffle_unshuffle()181 static void test_shuffle_unshuffle()
182 {
183   Tensor<float, 4, DataLayout> tensor(2,3,5,7);
184   tensor.setRandom();
185 
186   // Choose a random permutation.
187   array<ptrdiff_t, 4> shuffles;
188   for (int i = 0; i < 4; ++i) {
189     shuffles[i] = i;
190   }
191   array<ptrdiff_t, 4> shuffles_inverse;
192   for (int i = 0; i < 4; ++i) {
193     const ptrdiff_t index = internal::random<ptrdiff_t>(i, 3);
194     shuffles_inverse[shuffles[index]] = i;
195     std::swap(shuffles[i], shuffles[index]);
196   }
197 
198   Tensor<float, 4, DataLayout> shuffle;
199   shuffle = tensor.shuffle(shuffles).shuffle(shuffles_inverse);
200 
201   VERIFY_IS_EQUAL(shuffle.dimension(0), 2);
202   VERIFY_IS_EQUAL(shuffle.dimension(1), 3);
203   VERIFY_IS_EQUAL(shuffle.dimension(2), 5);
204   VERIFY_IS_EQUAL(shuffle.dimension(3), 7);
205 
206   for (int i = 0; i < 2; ++i) {
207     for (int j = 0; j < 3; ++j) {
208       for (int k = 0; k < 5; ++k) {
209         for (int l = 0; l < 7; ++l) {
210           VERIFY_IS_EQUAL(tensor(i,j,k,l), shuffle(i,j,k,l));
211         }
212       }
213     }
214   }
215 }
216 
217 
test_cxx11_tensor_shuffling()218 void test_cxx11_tensor_shuffling()
219 {
220   CALL_SUBTEST(test_simple_shuffling<ColMajor>());
221   CALL_SUBTEST(test_simple_shuffling<RowMajor>());
222   CALL_SUBTEST(test_expr_shuffling<ColMajor>());
223   CALL_SUBTEST(test_expr_shuffling<RowMajor>());
224   CALL_SUBTEST(test_shuffling_as_value<ColMajor>());
225   CALL_SUBTEST(test_shuffling_as_value<RowMajor>());
226   CALL_SUBTEST(test_shuffle_unshuffle<ColMajor>());
227   CALL_SUBTEST(test_shuffle_unshuffle<RowMajor>());
228 }
229