• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2015 Eugene Brevdo <ebrevdo@google.com>
5 //                    Benoit Steiner <benoit.steiner.goog@gmail.com>
6 //
7 // This Source Code Form is subject to the terms of the Mozilla
8 // Public License v. 2.0. If a copy of the MPL was not distributed
9 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
10 
11 #include "main.h"
12 
13 #include <Eigen/CXX11/Tensor>
14 
15 using Eigen::Tensor;
16 using Eigen::array;
17 using Eigen::Tuple;
18 
19 template <int DataLayout>
test_simple_index_tuples()20 static void test_simple_index_tuples()
21 {
22   Tensor<float, 4, DataLayout> tensor(2,3,5,7);
23   tensor.setRandom();
24   tensor = (tensor + tensor.constant(0.5)).log();
25 
26   Tensor<Tuple<DenseIndex, float>, 4, DataLayout> index_tuples(2,3,5,7);
27   index_tuples = tensor.index_tuples();
28 
29   for (DenseIndex n = 0; n < 2*3*5*7; ++n) {
30     const Tuple<DenseIndex, float>& v = index_tuples.coeff(n);
31     VERIFY_IS_EQUAL(v.first, n);
32     VERIFY_IS_EQUAL(v.second, tensor.coeff(n));
33   }
34 }
35 
36 template <int DataLayout>
test_index_tuples_dim()37 static void test_index_tuples_dim()
38 {
39   Tensor<float, 4, DataLayout> tensor(2,3,5,7);
40   tensor.setRandom();
41   tensor = (tensor + tensor.constant(0.5)).log();
42 
43   Tensor<Tuple<DenseIndex, float>, 4, DataLayout> index_tuples(2,3,5,7);
44 
45   index_tuples = tensor.index_tuples();
46 
47   for (Eigen::DenseIndex n = 0; n < tensor.size(); ++n) {
48     const Tuple<DenseIndex, float>& v = index_tuples(n); //(i, j, k, l);
49     VERIFY_IS_EQUAL(v.first, n);
50     VERIFY_IS_EQUAL(v.second, tensor(n));
51   }
52 }
53 
54 template <int DataLayout>
test_argmax_tuple_reducer()55 static void test_argmax_tuple_reducer()
56 {
57   Tensor<float, 4, DataLayout> tensor(2,3,5,7);
58   tensor.setRandom();
59   tensor = (tensor + tensor.constant(0.5)).log();
60 
61   Tensor<Tuple<DenseIndex, float>, 4, DataLayout> index_tuples(2,3,5,7);
62   index_tuples = tensor.index_tuples();
63 
64   Tensor<Tuple<DenseIndex, float>, 0, DataLayout> reduced;
65   DimensionList<DenseIndex, 4> dims;
66   reduced = index_tuples.reduce(
67       dims, internal::ArgMaxTupleReducer<Tuple<DenseIndex, float> >());
68 
69   Tensor<float, 0, DataLayout> maxi = tensor.maximum();
70 
71   VERIFY_IS_EQUAL(maxi(), reduced(0).second);
72 
73   array<DenseIndex, 3> reduce_dims;
74   for (int d = 0; d < 3; ++d) reduce_dims[d] = d;
75   Tensor<Tuple<DenseIndex, float>, 1, DataLayout> reduced_by_dims(7);
76   reduced_by_dims = index_tuples.reduce(
77       reduce_dims, internal::ArgMaxTupleReducer<Tuple<DenseIndex, float> >());
78 
79   Tensor<float, 1, DataLayout> max_by_dims = tensor.maximum(reduce_dims);
80 
81   for (int l = 0; l < 7; ++l) {
82     VERIFY_IS_EQUAL(max_by_dims(l), reduced_by_dims(l).second);
83   }
84 }
85 
86 template <int DataLayout>
test_argmin_tuple_reducer()87 static void test_argmin_tuple_reducer()
88 {
89   Tensor<float, 4, DataLayout> tensor(2,3,5,7);
90   tensor.setRandom();
91   tensor = (tensor + tensor.constant(0.5)).log();
92 
93   Tensor<Tuple<DenseIndex, float>, 4, DataLayout> index_tuples(2,3,5,7);
94   index_tuples = tensor.index_tuples();
95 
96   Tensor<Tuple<DenseIndex, float>, 0, DataLayout> reduced;
97   DimensionList<DenseIndex, 4> dims;
98   reduced = index_tuples.reduce(
99       dims, internal::ArgMinTupleReducer<Tuple<DenseIndex, float> >());
100 
101   Tensor<float, 0, DataLayout> mini = tensor.minimum();
102 
103   VERIFY_IS_EQUAL(mini(), reduced(0).second);
104 
105   array<DenseIndex, 3> reduce_dims;
106   for (int d = 0; d < 3; ++d) reduce_dims[d] = d;
107   Tensor<Tuple<DenseIndex, float>, 1, DataLayout> reduced_by_dims(7);
108   reduced_by_dims = index_tuples.reduce(
109       reduce_dims, internal::ArgMinTupleReducer<Tuple<DenseIndex, float> >());
110 
111   Tensor<float, 1, DataLayout> min_by_dims = tensor.minimum(reduce_dims);
112 
113   for (int l = 0; l < 7; ++l) {
114     VERIFY_IS_EQUAL(min_by_dims(l), reduced_by_dims(l).second);
115   }
116 }
117 
118 template <int DataLayout>
test_simple_argmax()119 static void test_simple_argmax()
120 {
121   Tensor<float, 4, DataLayout> tensor(2,3,5,7);
122   tensor.setRandom();
123   tensor = (tensor + tensor.constant(0.5)).log();
124   tensor(0,0,0,0) = 10.0;
125 
126   Tensor<DenseIndex, 0, DataLayout> tensor_argmax;
127 
128   tensor_argmax = tensor.argmax();
129 
130   VERIFY_IS_EQUAL(tensor_argmax(0), 0);
131 
132   tensor(1,2,4,6) = 20.0;
133 
134   tensor_argmax = tensor.argmax();
135 
136   VERIFY_IS_EQUAL(tensor_argmax(0), 2*3*5*7 - 1);
137 }
138 
139 template <int DataLayout>
test_simple_argmin()140 static void test_simple_argmin()
141 {
142   Tensor<float, 4, DataLayout> tensor(2,3,5,7);
143   tensor.setRandom();
144   tensor = (tensor + tensor.constant(0.5)).log();
145   tensor(0,0,0,0) = -10.0;
146 
147   Tensor<DenseIndex, 0, DataLayout> tensor_argmin;
148 
149   tensor_argmin = tensor.argmin();
150 
151   VERIFY_IS_EQUAL(tensor_argmin(0), 0);
152 
153   tensor(1,2,4,6) = -20.0;
154 
155   tensor_argmin = tensor.argmin();
156 
157   VERIFY_IS_EQUAL(tensor_argmin(0), 2*3*5*7 - 1);
158 }
159 
160 template <int DataLayout>
test_argmax_dim()161 static void test_argmax_dim()
162 {
163   Tensor<float, 4, DataLayout> tensor(2,3,5,7);
164   std::vector<int> dims {2, 3, 5, 7};
165 
166   for (int dim = 0; dim < 4; ++dim) {
167     tensor.setRandom();
168     tensor = (tensor + tensor.constant(0.5)).log();
169 
170     Tensor<DenseIndex, 3, DataLayout> tensor_argmax;
171     array<DenseIndex, 4> ix;
172     for (int i = 0; i < 2; ++i) {
173       for (int j = 0; j < 3; ++j) {
174         for (int k = 0; k < 5; ++k) {
175           for (int l = 0; l < 7; ++l) {
176             ix[0] = i; ix[1] = j; ix[2] = k; ix[3] = l;
177             if (ix[dim] != 0) continue;
178             // suppose dim == 1, then for all i, k, l, set tensor(i, 0, k, l) = 10.0
179             tensor(ix) = 10.0;
180           }
181         }
182       }
183     }
184 
185     tensor_argmax = tensor.argmax(dim);
186 
187     VERIFY_IS_EQUAL(tensor_argmax.size(),
188                     ptrdiff_t(2*3*5*7 / tensor.dimension(dim)));
189     for (ptrdiff_t n = 0; n < tensor_argmax.size(); ++n) {
190       // Expect max to be in the first index of the reduced dimension
191       VERIFY_IS_EQUAL(tensor_argmax.data()[n], 0);
192     }
193 
194     for (int i = 0; i < 2; ++i) {
195       for (int j = 0; j < 3; ++j) {
196         for (int k = 0; k < 5; ++k) {
197           for (int l = 0; l < 7; ++l) {
198             ix[0] = i; ix[1] = j; ix[2] = k; ix[3] = l;
199             if (ix[dim] != tensor.dimension(dim) - 1) continue;
200             // suppose dim == 1, then for all i, k, l, set tensor(i, 2, k, l) = 20.0
201             tensor(ix) = 20.0;
202           }
203         }
204       }
205     }
206 
207     tensor_argmax = tensor.argmax(dim);
208 
209     VERIFY_IS_EQUAL(tensor_argmax.size(),
210                     ptrdiff_t(2*3*5*7 / tensor.dimension(dim)));
211     for (ptrdiff_t n = 0; n < tensor_argmax.size(); ++n) {
212       // Expect max to be in the last index of the reduced dimension
213       VERIFY_IS_EQUAL(tensor_argmax.data()[n], tensor.dimension(dim) - 1);
214     }
215   }
216 }
217 
218 template <int DataLayout>
test_argmin_dim()219 static void test_argmin_dim()
220 {
221   Tensor<float, 4, DataLayout> tensor(2,3,5,7);
222   std::vector<int> dims {2, 3, 5, 7};
223 
224   for (int dim = 0; dim < 4; ++dim) {
225     tensor.setRandom();
226     tensor = (tensor + tensor.constant(0.5)).log();
227 
228     Tensor<DenseIndex, 3, DataLayout> tensor_argmin;
229     array<DenseIndex, 4> ix;
230     for (int i = 0; i < 2; ++i) {
231       for (int j = 0; j < 3; ++j) {
232         for (int k = 0; k < 5; ++k) {
233           for (int l = 0; l < 7; ++l) {
234             ix[0] = i; ix[1] = j; ix[2] = k; ix[3] = l;
235             if (ix[dim] != 0) continue;
236             // suppose dim == 1, then for all i, k, l, set tensor(i, 0, k, l) = -10.0
237             tensor(ix) = -10.0;
238           }
239         }
240       }
241     }
242 
243     tensor_argmin = tensor.argmin(dim);
244 
245     VERIFY_IS_EQUAL(tensor_argmin.size(),
246                     ptrdiff_t(2*3*5*7 / tensor.dimension(dim)));
247     for (ptrdiff_t n = 0; n < tensor_argmin.size(); ++n) {
248       // Expect min to be in the first index of the reduced dimension
249       VERIFY_IS_EQUAL(tensor_argmin.data()[n], 0);
250     }
251 
252     for (int i = 0; i < 2; ++i) {
253       for (int j = 0; j < 3; ++j) {
254         for (int k = 0; k < 5; ++k) {
255           for (int l = 0; l < 7; ++l) {
256             ix[0] = i; ix[1] = j; ix[2] = k; ix[3] = l;
257             if (ix[dim] != tensor.dimension(dim) - 1) continue;
258             // suppose dim == 1, then for all i, k, l, set tensor(i, 2, k, l) = -20.0
259             tensor(ix) = -20.0;
260           }
261         }
262       }
263     }
264 
265     tensor_argmin = tensor.argmin(dim);
266 
267     VERIFY_IS_EQUAL(tensor_argmin.size(),
268                     ptrdiff_t(2*3*5*7 / tensor.dimension(dim)));
269     for (ptrdiff_t n = 0; n < tensor_argmin.size(); ++n) {
270       // Expect min to be in the last index of the reduced dimension
271       VERIFY_IS_EQUAL(tensor_argmin.data()[n], tensor.dimension(dim) - 1);
272     }
273   }
274 }
275 
// Test entry point: runs each index_tuples / argmax / argmin subtest once for
// the RowMajor layout and once for the ColMajor layout, in the order below.
void test_cxx11_tensor_argmax()
{
  // index_tuples(): (linear index, value) pairing.
  CALL_SUBTEST(test_simple_index_tuples<RowMajor>());
  CALL_SUBTEST(test_simple_index_tuples<ColMajor>());
  CALL_SUBTEST(test_index_tuples_dim<RowMajor>());
  CALL_SUBTEST(test_index_tuples_dim<ColMajor>());
  // Raw tuple reducers used underneath argmax/argmin.
  CALL_SUBTEST(test_argmax_tuple_reducer<RowMajor>());
  CALL_SUBTEST(test_argmax_tuple_reducer<ColMajor>());
  CALL_SUBTEST(test_argmin_tuple_reducer<RowMajor>());
  CALL_SUBTEST(test_argmin_tuple_reducer<ColMajor>());
  // Whole-tensor argmax()/argmin().
  CALL_SUBTEST(test_simple_argmax<RowMajor>());
  CALL_SUBTEST(test_simple_argmax<ColMajor>());
  CALL_SUBTEST(test_simple_argmin<RowMajor>());
  CALL_SUBTEST(test_simple_argmin<ColMajor>());
  // Per-axis argmax(dim)/argmin(dim).
  CALL_SUBTEST(test_argmax_dim<RowMajor>());
  CALL_SUBTEST(test_argmax_dim<ColMajor>());
  CALL_SUBTEST(test_argmin_dim<RowMajor>());
  CALL_SUBTEST(test_argmin_dim<ColMajor>());
}
295