1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10 #include "main.h"
11
12 #include <Eigen/CXX11/Tensor>
13
14 using Eigen::Tensor;
15 using Eigen::array;
16
17 template <int DataLayout>
test_simple_shuffling()18 static void test_simple_shuffling()
19 {
20 Tensor<float, 4, DataLayout> tensor(2,3,5,7);
21 tensor.setRandom();
22 array<ptrdiff_t, 4> shuffles;
23 shuffles[0] = 0;
24 shuffles[1] = 1;
25 shuffles[2] = 2;
26 shuffles[3] = 3;
27
28 Tensor<float, 4, DataLayout> no_shuffle;
29 no_shuffle = tensor.shuffle(shuffles);
30
31 VERIFY_IS_EQUAL(no_shuffle.dimension(0), 2);
32 VERIFY_IS_EQUAL(no_shuffle.dimension(1), 3);
33 VERIFY_IS_EQUAL(no_shuffle.dimension(2), 5);
34 VERIFY_IS_EQUAL(no_shuffle.dimension(3), 7);
35
36 for (int i = 0; i < 2; ++i) {
37 for (int j = 0; j < 3; ++j) {
38 for (int k = 0; k < 5; ++k) {
39 for (int l = 0; l < 7; ++l) {
40 VERIFY_IS_EQUAL(tensor(i,j,k,l), no_shuffle(i,j,k,l));
41 }
42 }
43 }
44 }
45
46 shuffles[0] = 2;
47 shuffles[1] = 3;
48 shuffles[2] = 1;
49 shuffles[3] = 0;
50 Tensor<float, 4, DataLayout> shuffle;
51 shuffle = tensor.shuffle(shuffles);
52
53 VERIFY_IS_EQUAL(shuffle.dimension(0), 5);
54 VERIFY_IS_EQUAL(shuffle.dimension(1), 7);
55 VERIFY_IS_EQUAL(shuffle.dimension(2), 3);
56 VERIFY_IS_EQUAL(shuffle.dimension(3), 2);
57
58 for (int i = 0; i < 2; ++i) {
59 for (int j = 0; j < 3; ++j) {
60 for (int k = 0; k < 5; ++k) {
61 for (int l = 0; l < 7; ++l) {
62 VERIFY_IS_EQUAL(tensor(i,j,k,l), shuffle(k,l,j,i));
63 }
64 }
65 }
66 }
67 }
68
69
70 template <int DataLayout>
test_expr_shuffling()71 static void test_expr_shuffling()
72 {
73 Tensor<float, 4, DataLayout> tensor(2,3,5,7);
74 tensor.setRandom();
75
76 array<ptrdiff_t, 4> shuffles;
77 shuffles[0] = 2;
78 shuffles[1] = 3;
79 shuffles[2] = 1;
80 shuffles[3] = 0;
81 Tensor<float, 4, DataLayout> expected;
82 expected = tensor.shuffle(shuffles);
83
84 Tensor<float, 4, DataLayout> result(5,7,3,2);
85
86 array<int, 4> src_slice_dim{{2,3,1,7}};
87 array<int, 4> src_slice_start{{0,0,0,0}};
88 array<int, 4> dst_slice_dim{{1,7,3,2}};
89 array<int, 4> dst_slice_start{{0,0,0,0}};
90
91 for (int i = 0; i < 5; ++i) {
92 result.slice(dst_slice_start, dst_slice_dim) =
93 tensor.slice(src_slice_start, src_slice_dim).shuffle(shuffles);
94 src_slice_start[2] += 1;
95 dst_slice_start[0] += 1;
96 }
97
98 VERIFY_IS_EQUAL(result.dimension(0), 5);
99 VERIFY_IS_EQUAL(result.dimension(1), 7);
100 VERIFY_IS_EQUAL(result.dimension(2), 3);
101 VERIFY_IS_EQUAL(result.dimension(3), 2);
102
103 for (int i = 0; i < expected.dimension(0); ++i) {
104 for (int j = 0; j < expected.dimension(1); ++j) {
105 for (int k = 0; k < expected.dimension(2); ++k) {
106 for (int l = 0; l < expected.dimension(3); ++l) {
107 VERIFY_IS_EQUAL(result(i,j,k,l), expected(i,j,k,l));
108 }
109 }
110 }
111 }
112
113 dst_slice_start[0] = 0;
114 result.setRandom();
115 for (int i = 0; i < 5; ++i) {
116 result.slice(dst_slice_start, dst_slice_dim) =
117 tensor.shuffle(shuffles).slice(dst_slice_start, dst_slice_dim);
118 dst_slice_start[0] += 1;
119 }
120
121 for (int i = 0; i < expected.dimension(0); ++i) {
122 for (int j = 0; j < expected.dimension(1); ++j) {
123 for (int k = 0; k < expected.dimension(2); ++k) {
124 for (int l = 0; l < expected.dimension(3); ++l) {
125 VERIFY_IS_EQUAL(result(i,j,k,l), expected(i,j,k,l));
126 }
127 }
128 }
129 }
130 }
131
132
133 template <int DataLayout>
test_shuffling_as_value()134 static void test_shuffling_as_value()
135 {
136 Tensor<float, 4, DataLayout> tensor(2,3,5,7);
137 tensor.setRandom();
138 array<ptrdiff_t, 4> shuffles;
139 shuffles[2] = 0;
140 shuffles[3] = 1;
141 shuffles[1] = 2;
142 shuffles[0] = 3;
143 Tensor<float, 4, DataLayout> shuffle(5,7,3,2);
144 shuffle.shuffle(shuffles) = tensor;
145
146 VERIFY_IS_EQUAL(shuffle.dimension(0), 5);
147 VERIFY_IS_EQUAL(shuffle.dimension(1), 7);
148 VERIFY_IS_EQUAL(shuffle.dimension(2), 3);
149 VERIFY_IS_EQUAL(shuffle.dimension(3), 2);
150
151 for (int i = 0; i < 2; ++i) {
152 for (int j = 0; j < 3; ++j) {
153 for (int k = 0; k < 5; ++k) {
154 for (int l = 0; l < 7; ++l) {
155 VERIFY_IS_EQUAL(tensor(i,j,k,l), shuffle(k,l,j,i));
156 }
157 }
158 }
159 }
160
161 array<ptrdiff_t, 4> no_shuffle;
162 no_shuffle[0] = 0;
163 no_shuffle[1] = 1;
164 no_shuffle[2] = 2;
165 no_shuffle[3] = 3;
166 Tensor<float, 4, DataLayout> shuffle2(5,7,3,2);
167 shuffle2.shuffle(shuffles) = tensor.shuffle(no_shuffle);
168 for (int i = 0; i < 5; ++i) {
169 for (int j = 0; j < 7; ++j) {
170 for (int k = 0; k < 3; ++k) {
171 for (int l = 0; l < 2; ++l) {
172 VERIFY_IS_EQUAL(shuffle2(i,j,k,l), shuffle(i,j,k,l));
173 }
174 }
175 }
176 }
177 }
178
179
180 template <int DataLayout>
test_shuffle_unshuffle()181 static void test_shuffle_unshuffle()
182 {
183 Tensor<float, 4, DataLayout> tensor(2,3,5,7);
184 tensor.setRandom();
185
186 // Choose a random permutation.
187 array<ptrdiff_t, 4> shuffles;
188 for (int i = 0; i < 4; ++i) {
189 shuffles[i] = i;
190 }
191 array<ptrdiff_t, 4> shuffles_inverse;
192 for (int i = 0; i < 4; ++i) {
193 const ptrdiff_t index = internal::random<ptrdiff_t>(i, 3);
194 shuffles_inverse[shuffles[index]] = i;
195 std::swap(shuffles[i], shuffles[index]);
196 }
197
198 Tensor<float, 4, DataLayout> shuffle;
199 shuffle = tensor.shuffle(shuffles).shuffle(shuffles_inverse);
200
201 VERIFY_IS_EQUAL(shuffle.dimension(0), 2);
202 VERIFY_IS_EQUAL(shuffle.dimension(1), 3);
203 VERIFY_IS_EQUAL(shuffle.dimension(2), 5);
204 VERIFY_IS_EQUAL(shuffle.dimension(3), 7);
205
206 for (int i = 0; i < 2; ++i) {
207 for (int j = 0; j < 3; ++j) {
208 for (int k = 0; k < 5; ++k) {
209 for (int l = 0; l < 7; ++l) {
210 VERIFY_IS_EQUAL(tensor(i,j,k,l), shuffle(i,j,k,l));
211 }
212 }
213 }
214 }
215 }
216
217
test_cxx11_tensor_shuffling()218 void test_cxx11_tensor_shuffling()
219 {
220 CALL_SUBTEST(test_simple_shuffling<ColMajor>());
221 CALL_SUBTEST(test_simple_shuffling<RowMajor>());
222 CALL_SUBTEST(test_expr_shuffling<ColMajor>());
223 CALL_SUBTEST(test_expr_shuffling<RowMajor>());
224 CALL_SUBTEST(test_shuffling_as_value<ColMajor>());
225 CALL_SUBTEST(test_shuffling_as_value<RowMajor>());
226 CALL_SUBTEST(test_shuffle_unshuffle<ColMajor>());
227 CALL_SUBTEST(test_shuffle_unshuffle<RowMajor>());
228 }
229