1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2015 Eugene Brevdo <ebrevdo@google.com>
5 // Benoit Steiner <benoit.steiner.goog@gmail.com>
6 //
7 // This Source Code Form is subject to the terms of the Mozilla
8 // Public License v. 2.0. If a copy of the MPL was not distributed
9 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
10
11 #include "main.h"
12
13 #include <Eigen/CXX11/Tensor>
14
15 using Eigen::Tensor;
16 using Eigen::array;
17 using Eigen::Tuple;
18
19 template <int DataLayout>
test_simple_index_tuples()20 static void test_simple_index_tuples()
21 {
22 Tensor<float, 4, DataLayout> tensor(2,3,5,7);
23 tensor.setRandom();
24 tensor = (tensor + tensor.constant(0.5)).log();
25
26 Tensor<Tuple<DenseIndex, float>, 4, DataLayout> index_tuples(2,3,5,7);
27 index_tuples = tensor.index_tuples();
28
29 for (DenseIndex n = 0; n < 2*3*5*7; ++n) {
30 const Tuple<DenseIndex, float>& v = index_tuples.coeff(n);
31 VERIFY_IS_EQUAL(v.first, n);
32 VERIFY_IS_EQUAL(v.second, tensor.coeff(n));
33 }
34 }
35
36 template <int DataLayout>
test_index_tuples_dim()37 static void test_index_tuples_dim()
38 {
39 Tensor<float, 4, DataLayout> tensor(2,3,5,7);
40 tensor.setRandom();
41 tensor = (tensor + tensor.constant(0.5)).log();
42
43 Tensor<Tuple<DenseIndex, float>, 4, DataLayout> index_tuples(2,3,5,7);
44
45 index_tuples = tensor.index_tuples();
46
47 for (Eigen::DenseIndex n = 0; n < tensor.size(); ++n) {
48 const Tuple<DenseIndex, float>& v = index_tuples(n); //(i, j, k, l);
49 VERIFY_IS_EQUAL(v.first, n);
50 VERIFY_IS_EQUAL(v.second, tensor(n));
51 }
52 }
53
54 template <int DataLayout>
test_argmax_tuple_reducer()55 static void test_argmax_tuple_reducer()
56 {
57 Tensor<float, 4, DataLayout> tensor(2,3,5,7);
58 tensor.setRandom();
59 tensor = (tensor + tensor.constant(0.5)).log();
60
61 Tensor<Tuple<DenseIndex, float>, 4, DataLayout> index_tuples(2,3,5,7);
62 index_tuples = tensor.index_tuples();
63
64 Tensor<Tuple<DenseIndex, float>, 0, DataLayout> reduced;
65 DimensionList<DenseIndex, 4> dims;
66 reduced = index_tuples.reduce(
67 dims, internal::ArgMaxTupleReducer<Tuple<DenseIndex, float> >());
68
69 Tensor<float, 0, DataLayout> maxi = tensor.maximum();
70
71 VERIFY_IS_EQUAL(maxi(), reduced(0).second);
72
73 array<DenseIndex, 3> reduce_dims;
74 for (int d = 0; d < 3; ++d) reduce_dims[d] = d;
75 Tensor<Tuple<DenseIndex, float>, 1, DataLayout> reduced_by_dims(7);
76 reduced_by_dims = index_tuples.reduce(
77 reduce_dims, internal::ArgMaxTupleReducer<Tuple<DenseIndex, float> >());
78
79 Tensor<float, 1, DataLayout> max_by_dims = tensor.maximum(reduce_dims);
80
81 for (int l = 0; l < 7; ++l) {
82 VERIFY_IS_EQUAL(max_by_dims(l), reduced_by_dims(l).second);
83 }
84 }
85
86 template <int DataLayout>
test_argmin_tuple_reducer()87 static void test_argmin_tuple_reducer()
88 {
89 Tensor<float, 4, DataLayout> tensor(2,3,5,7);
90 tensor.setRandom();
91 tensor = (tensor + tensor.constant(0.5)).log();
92
93 Tensor<Tuple<DenseIndex, float>, 4, DataLayout> index_tuples(2,3,5,7);
94 index_tuples = tensor.index_tuples();
95
96 Tensor<Tuple<DenseIndex, float>, 0, DataLayout> reduced;
97 DimensionList<DenseIndex, 4> dims;
98 reduced = index_tuples.reduce(
99 dims, internal::ArgMinTupleReducer<Tuple<DenseIndex, float> >());
100
101 Tensor<float, 0, DataLayout> mini = tensor.minimum();
102
103 VERIFY_IS_EQUAL(mini(), reduced(0).second);
104
105 array<DenseIndex, 3> reduce_dims;
106 for (int d = 0; d < 3; ++d) reduce_dims[d] = d;
107 Tensor<Tuple<DenseIndex, float>, 1, DataLayout> reduced_by_dims(7);
108 reduced_by_dims = index_tuples.reduce(
109 reduce_dims, internal::ArgMinTupleReducer<Tuple<DenseIndex, float> >());
110
111 Tensor<float, 1, DataLayout> min_by_dims = tensor.minimum(reduce_dims);
112
113 for (int l = 0; l < 7; ++l) {
114 VERIFY_IS_EQUAL(min_by_dims(l), reduced_by_dims(l).second);
115 }
116 }
117
118 template <int DataLayout>
test_simple_argmax()119 static void test_simple_argmax()
120 {
121 Tensor<float, 4, DataLayout> tensor(2,3,5,7);
122 tensor.setRandom();
123 tensor = (tensor + tensor.constant(0.5)).log();
124 tensor(0,0,0,0) = 10.0;
125
126 Tensor<DenseIndex, 0, DataLayout> tensor_argmax;
127
128 tensor_argmax = tensor.argmax();
129
130 VERIFY_IS_EQUAL(tensor_argmax(0), 0);
131
132 tensor(1,2,4,6) = 20.0;
133
134 tensor_argmax = tensor.argmax();
135
136 VERIFY_IS_EQUAL(tensor_argmax(0), 2*3*5*7 - 1);
137 }
138
139 template <int DataLayout>
test_simple_argmin()140 static void test_simple_argmin()
141 {
142 Tensor<float, 4, DataLayout> tensor(2,3,5,7);
143 tensor.setRandom();
144 tensor = (tensor + tensor.constant(0.5)).log();
145 tensor(0,0,0,0) = -10.0;
146
147 Tensor<DenseIndex, 0, DataLayout> tensor_argmin;
148
149 tensor_argmin = tensor.argmin();
150
151 VERIFY_IS_EQUAL(tensor_argmin(0), 0);
152
153 tensor(1,2,4,6) = -20.0;
154
155 tensor_argmin = tensor.argmin();
156
157 VERIFY_IS_EQUAL(tensor_argmin(0), 2*3*5*7 - 1);
158 }
159
160 template <int DataLayout>
test_argmax_dim()161 static void test_argmax_dim()
162 {
163 Tensor<float, 4, DataLayout> tensor(2,3,5,7);
164 std::vector<int> dims {2, 3, 5, 7};
165
166 for (int dim = 0; dim < 4; ++dim) {
167 tensor.setRandom();
168 tensor = (tensor + tensor.constant(0.5)).log();
169
170 Tensor<DenseIndex, 3, DataLayout> tensor_argmax;
171 array<DenseIndex, 4> ix;
172 for (int i = 0; i < 2; ++i) {
173 for (int j = 0; j < 3; ++j) {
174 for (int k = 0; k < 5; ++k) {
175 for (int l = 0; l < 7; ++l) {
176 ix[0] = i; ix[1] = j; ix[2] = k; ix[3] = l;
177 if (ix[dim] != 0) continue;
178 // suppose dim == 1, then for all i, k, l, set tensor(i, 0, k, l) = 10.0
179 tensor(ix) = 10.0;
180 }
181 }
182 }
183 }
184
185 tensor_argmax = tensor.argmax(dim);
186
187 VERIFY_IS_EQUAL(tensor_argmax.size(),
188 ptrdiff_t(2*3*5*7 / tensor.dimension(dim)));
189 for (ptrdiff_t n = 0; n < tensor_argmax.size(); ++n) {
190 // Expect max to be in the first index of the reduced dimension
191 VERIFY_IS_EQUAL(tensor_argmax.data()[n], 0);
192 }
193
194 for (int i = 0; i < 2; ++i) {
195 for (int j = 0; j < 3; ++j) {
196 for (int k = 0; k < 5; ++k) {
197 for (int l = 0; l < 7; ++l) {
198 ix[0] = i; ix[1] = j; ix[2] = k; ix[3] = l;
199 if (ix[dim] != tensor.dimension(dim) - 1) continue;
200 // suppose dim == 1, then for all i, k, l, set tensor(i, 2, k, l) = 20.0
201 tensor(ix) = 20.0;
202 }
203 }
204 }
205 }
206
207 tensor_argmax = tensor.argmax(dim);
208
209 VERIFY_IS_EQUAL(tensor_argmax.size(),
210 ptrdiff_t(2*3*5*7 / tensor.dimension(dim)));
211 for (ptrdiff_t n = 0; n < tensor_argmax.size(); ++n) {
212 // Expect max to be in the last index of the reduced dimension
213 VERIFY_IS_EQUAL(tensor_argmax.data()[n], tensor.dimension(dim) - 1);
214 }
215 }
216 }
217
218 template <int DataLayout>
test_argmin_dim()219 static void test_argmin_dim()
220 {
221 Tensor<float, 4, DataLayout> tensor(2,3,5,7);
222 std::vector<int> dims {2, 3, 5, 7};
223
224 for (int dim = 0; dim < 4; ++dim) {
225 tensor.setRandom();
226 tensor = (tensor + tensor.constant(0.5)).log();
227
228 Tensor<DenseIndex, 3, DataLayout> tensor_argmin;
229 array<DenseIndex, 4> ix;
230 for (int i = 0; i < 2; ++i) {
231 for (int j = 0; j < 3; ++j) {
232 for (int k = 0; k < 5; ++k) {
233 for (int l = 0; l < 7; ++l) {
234 ix[0] = i; ix[1] = j; ix[2] = k; ix[3] = l;
235 if (ix[dim] != 0) continue;
236 // suppose dim == 1, then for all i, k, l, set tensor(i, 0, k, l) = -10.0
237 tensor(ix) = -10.0;
238 }
239 }
240 }
241 }
242
243 tensor_argmin = tensor.argmin(dim);
244
245 VERIFY_IS_EQUAL(tensor_argmin.size(),
246 ptrdiff_t(2*3*5*7 / tensor.dimension(dim)));
247 for (ptrdiff_t n = 0; n < tensor_argmin.size(); ++n) {
248 // Expect min to be in the first index of the reduced dimension
249 VERIFY_IS_EQUAL(tensor_argmin.data()[n], 0);
250 }
251
252 for (int i = 0; i < 2; ++i) {
253 for (int j = 0; j < 3; ++j) {
254 for (int k = 0; k < 5; ++k) {
255 for (int l = 0; l < 7; ++l) {
256 ix[0] = i; ix[1] = j; ix[2] = k; ix[3] = l;
257 if (ix[dim] != tensor.dimension(dim) - 1) continue;
258 // suppose dim == 1, then for all i, k, l, set tensor(i, 2, k, l) = -20.0
259 tensor(ix) = -20.0;
260 }
261 }
262 }
263 }
264
265 tensor_argmin = tensor.argmin(dim);
266
267 VERIFY_IS_EQUAL(tensor_argmin.size(),
268 ptrdiff_t(2*3*5*7 / tensor.dimension(dim)));
269 for (ptrdiff_t n = 0; n < tensor_argmin.size(); ++n) {
270 // Expect min to be in the last index of the reduced dimension
271 VERIFY_IS_EQUAL(tensor_argmin.data()[n], tensor.dimension(dim) - 1);
272 }
273 }
274 }
275
test_cxx11_tensor_argmax()276 void test_cxx11_tensor_argmax()
277 {
278 CALL_SUBTEST(test_simple_index_tuples<RowMajor>());
279 CALL_SUBTEST(test_simple_index_tuples<ColMajor>());
280 CALL_SUBTEST(test_index_tuples_dim<RowMajor>());
281 CALL_SUBTEST(test_index_tuples_dim<ColMajor>());
282 CALL_SUBTEST(test_argmax_tuple_reducer<RowMajor>());
283 CALL_SUBTEST(test_argmax_tuple_reducer<ColMajor>());
284 CALL_SUBTEST(test_argmin_tuple_reducer<RowMajor>());
285 CALL_SUBTEST(test_argmin_tuple_reducer<ColMajor>());
286 CALL_SUBTEST(test_simple_argmax<RowMajor>());
287 CALL_SUBTEST(test_simple_argmax<ColMajor>());
288 CALL_SUBTEST(test_simple_argmin<RowMajor>());
289 CALL_SUBTEST(test_simple_argmin<ColMajor>());
290 CALL_SUBTEST(test_argmax_dim<RowMajor>());
291 CALL_SUBTEST(test_argmax_dim<ColMajor>());
292 CALL_SUBTEST(test_argmin_dim<RowMajor>());
293 CALL_SUBTEST(test_argmin_dim<ColMajor>());
294 }
295