• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #include "main.h"
11 
12 #include <Eigen/CXX11/Tensor>
13 
14 using Eigen::Tensor;
15 
16 template<int DataLayout>
test_simple_chip()17 static void test_simple_chip()
18 {
19   Tensor<float, 5, DataLayout> tensor(2,3,5,7,11);
20   tensor.setRandom();
21 
22   Tensor<float, 4, DataLayout> chip1;
23   chip1 = tensor.template chip<0>(1);
24 
25   VERIFY_IS_EQUAL(chip1.dimension(0), 3);
26   VERIFY_IS_EQUAL(chip1.dimension(1), 5);
27   VERIFY_IS_EQUAL(chip1.dimension(2), 7);
28   VERIFY_IS_EQUAL(chip1.dimension(3), 11);
29 
30   for (int i = 0; i < 3; ++i) {
31     for (int j = 0; j < 5; ++j) {
32       for (int k = 0; k < 7; ++k) {
33         for (int l = 0; l < 11; ++l) {
34           VERIFY_IS_EQUAL(chip1(i,j,k,l), tensor(1,i,j,k,l));
35         }
36       }
37     }
38   }
39 
40   Tensor<float, 4, DataLayout> chip2 = tensor.template chip<1>(1);
41   VERIFY_IS_EQUAL(chip2.dimension(0), 2);
42   VERIFY_IS_EQUAL(chip2.dimension(1), 5);
43   VERIFY_IS_EQUAL(chip2.dimension(2), 7);
44   VERIFY_IS_EQUAL(chip2.dimension(3), 11);
45   for (int i = 0; i < 2; ++i) {
46     for (int j = 0; j < 3; ++j) {
47       for (int k = 0; k < 7; ++k) {
48         for (int l = 0; l < 11; ++l) {
49           VERIFY_IS_EQUAL(chip2(i,j,k,l), tensor(i,1,j,k,l));
50         }
51       }
52     }
53   }
54 
55   Tensor<float, 4, DataLayout> chip3 = tensor.template chip<2>(2);
56   VERIFY_IS_EQUAL(chip3.dimension(0), 2);
57   VERIFY_IS_EQUAL(chip3.dimension(1), 3);
58   VERIFY_IS_EQUAL(chip3.dimension(2), 7);
59   VERIFY_IS_EQUAL(chip3.dimension(3), 11);
60   for (int i = 0; i < 2; ++i) {
61     for (int j = 0; j < 3; ++j) {
62       for (int k = 0; k < 7; ++k) {
63         for (int l = 0; l < 11; ++l) {
64           VERIFY_IS_EQUAL(chip3(i,j,k,l), tensor(i,j,2,k,l));
65         }
66       }
67     }
68   }
69 
70   Tensor<float, 4, DataLayout> chip4(tensor.template chip<3>(5));
71   VERIFY_IS_EQUAL(chip4.dimension(0), 2);
72   VERIFY_IS_EQUAL(chip4.dimension(1), 3);
73   VERIFY_IS_EQUAL(chip4.dimension(2), 5);
74   VERIFY_IS_EQUAL(chip4.dimension(3), 11);
75   for (int i = 0; i < 2; ++i) {
76     for (int j = 0; j < 3; ++j) {
77       for (int k = 0; k < 5; ++k) {
78         for (int l = 0; l < 7; ++l) {
79           VERIFY_IS_EQUAL(chip4(i,j,k,l), tensor(i,j,k,5,l));
80         }
81       }
82     }
83   }
84 
85   Tensor<float, 4, DataLayout> chip5(tensor.template chip<4>(7));
86   VERIFY_IS_EQUAL(chip5.dimension(0), 2);
87   VERIFY_IS_EQUAL(chip5.dimension(1), 3);
88   VERIFY_IS_EQUAL(chip5.dimension(2), 5);
89   VERIFY_IS_EQUAL(chip5.dimension(3), 7);
90   for (int i = 0; i < 2; ++i) {
91     for (int j = 0; j < 3; ++j) {
92       for (int k = 0; k < 5; ++k) {
93         for (int l = 0; l < 7; ++l) {
94           VERIFY_IS_EQUAL(chip5(i,j,k,l), tensor(i,j,k,l,7));
95         }
96       }
97     }
98   }
99 }
100 
101 template<int DataLayout>
test_dynamic_chip()102 static void test_dynamic_chip()
103 {
104   Tensor<float, 5, DataLayout> tensor(2,3,5,7,11);
105   tensor.setRandom();
106 
107   Tensor<float, 4, DataLayout> chip1;
108   chip1 = tensor.chip(1, 0);
109   VERIFY_IS_EQUAL(chip1.dimension(0), 3);
110   VERIFY_IS_EQUAL(chip1.dimension(1), 5);
111   VERIFY_IS_EQUAL(chip1.dimension(2), 7);
112   VERIFY_IS_EQUAL(chip1.dimension(3), 11);
113   for (int i = 0; i < 3; ++i) {
114     for (int j = 0; j < 5; ++j) {
115       for (int k = 0; k < 7; ++k) {
116         for (int l = 0; l < 11; ++l) {
117           VERIFY_IS_EQUAL(chip1(i,j,k,l), tensor(1,i,j,k,l));
118         }
119       }
120     }
121   }
122 
123   Tensor<float, 4, DataLayout> chip2 = tensor.chip(1, 1);
124   VERIFY_IS_EQUAL(chip2.dimension(0), 2);
125   VERIFY_IS_EQUAL(chip2.dimension(1), 5);
126   VERIFY_IS_EQUAL(chip2.dimension(2), 7);
127   VERIFY_IS_EQUAL(chip2.dimension(3), 11);
128   for (int i = 0; i < 2; ++i) {
129     for (int j = 0; j < 3; ++j) {
130       for (int k = 0; k < 7; ++k) {
131         for (int l = 0; l < 11; ++l) {
132           VERIFY_IS_EQUAL(chip2(i,j,k,l), tensor(i,1,j,k,l));
133         }
134       }
135     }
136   }
137 
138   Tensor<float, 4, DataLayout> chip3 = tensor.chip(2, 2);
139   VERIFY_IS_EQUAL(chip3.dimension(0), 2);
140   VERIFY_IS_EQUAL(chip3.dimension(1), 3);
141   VERIFY_IS_EQUAL(chip3.dimension(2), 7);
142   VERIFY_IS_EQUAL(chip3.dimension(3), 11);
143   for (int i = 0; i < 2; ++i) {
144     for (int j = 0; j < 3; ++j) {
145       for (int k = 0; k < 7; ++k) {
146         for (int l = 0; l < 11; ++l) {
147           VERIFY_IS_EQUAL(chip3(i,j,k,l), tensor(i,j,2,k,l));
148         }
149       }
150     }
151   }
152 
153   Tensor<float, 4, DataLayout> chip4(tensor.chip(5, 3));
154   VERIFY_IS_EQUAL(chip4.dimension(0), 2);
155   VERIFY_IS_EQUAL(chip4.dimension(1), 3);
156   VERIFY_IS_EQUAL(chip4.dimension(2), 5);
157   VERIFY_IS_EQUAL(chip4.dimension(3), 11);
158   for (int i = 0; i < 2; ++i) {
159     for (int j = 0; j < 3; ++j) {
160       for (int k = 0; k < 5; ++k) {
161         for (int l = 0; l < 7; ++l) {
162           VERIFY_IS_EQUAL(chip4(i,j,k,l), tensor(i,j,k,5,l));
163         }
164       }
165     }
166   }
167 
168   Tensor<float, 4, DataLayout> chip5(tensor.chip(7, 4));
169   VERIFY_IS_EQUAL(chip5.dimension(0), 2);
170   VERIFY_IS_EQUAL(chip5.dimension(1), 3);
171   VERIFY_IS_EQUAL(chip5.dimension(2), 5);
172   VERIFY_IS_EQUAL(chip5.dimension(3), 7);
173   for (int i = 0; i < 2; ++i) {
174     for (int j = 0; j < 3; ++j) {
175       for (int k = 0; k < 5; ++k) {
176         for (int l = 0; l < 7; ++l) {
177           VERIFY_IS_EQUAL(chip5(i,j,k,l), tensor(i,j,k,l,7));
178         }
179       }
180     }
181   }
182 }
183 
184 template<int DataLayout>
test_chip_in_expr()185 static void test_chip_in_expr() {
186   Tensor<float, 5, DataLayout> input1(2,3,5,7,11);
187   input1.setRandom();
188   Tensor<float, 4, DataLayout> input2(3,5,7,11);
189   input2.setRandom();
190 
191   Tensor<float, 4, DataLayout> result = input1.template chip<0>(0) + input2;
192   for (int i = 0; i < 3; ++i) {
193     for (int j = 0; j < 5; ++j) {
194       for (int k = 0; k < 7; ++k) {
195         for (int l = 0; l < 11; ++l) {
196           float expected = input1(0,i,j,k,l) + input2(i,j,k,l);
197           VERIFY_IS_EQUAL(result(i,j,k,l), expected);
198         }
199       }
200     }
201   }
202 
203   Tensor<float, 3, DataLayout> input3(3,7,11);
204   input3.setRandom();
205   Tensor<float, 3, DataLayout> result2 = input1.template chip<0>(0).template chip<1>(2) + input3;
206   for (int i = 0; i < 3; ++i) {
207     for (int j = 0; j < 7; ++j) {
208       for (int k = 0; k < 11; ++k) {
209         float expected = input1(0,i,2,j,k) + input3(i,j,k);
210         VERIFY_IS_EQUAL(result2(i,j,k), expected);
211       }
212     }
213   }
214 }
215 
216 template<int DataLayout>
test_chip_as_lvalue()217 static void test_chip_as_lvalue()
218 {
219   Tensor<float, 5, DataLayout> input1(2,3,5,7,11);
220   input1.setRandom();
221 
222   Tensor<float, 4, DataLayout> input2(3,5,7,11);
223   input2.setRandom();
224   Tensor<float, 5, DataLayout> tensor = input1;
225   tensor.template chip<0>(1) = input2;
226   for (int i = 0; i < 2; ++i) {
227     for (int j = 0; j < 3; ++j) {
228       for (int k = 0; k < 5; ++k) {
229         for (int l = 0; l < 7; ++l) {
230           for (int m = 0; m < 11; ++m) {
231             if (i != 1) {
232               VERIFY_IS_EQUAL(tensor(i,j,k,l,m), input1(i,j,k,l,m));
233             } else {
234               VERIFY_IS_EQUAL(tensor(i,j,k,l,m), input2(j,k,l,m));
235             }
236           }
237         }
238       }
239     }
240   }
241 
242   Tensor<float, 4, DataLayout> input3(2,5,7,11);
243   input3.setRandom();
244   tensor = input1;
245   tensor.template chip<1>(1) = input3;
246   for (int i = 0; i < 2; ++i) {
247     for (int j = 0; j < 3; ++j) {
248       for (int k = 0; k < 5; ++k) {
249         for (int l = 0; l < 7; ++l) {
250           for (int m = 0; m < 11; ++m) {
251             if (j != 1) {
252               VERIFY_IS_EQUAL(tensor(i,j,k,l,m), input1(i,j,k,l,m));
253             } else {
254               VERIFY_IS_EQUAL(tensor(i,j,k,l,m), input3(i,k,l,m));
255             }
256           }
257         }
258       }
259     }
260   }
261 
262   Tensor<float, 4, DataLayout> input4(2,3,7,11);
263   input4.setRandom();
264   tensor = input1;
265   tensor.template chip<2>(3) = input4;
266   for (int i = 0; i < 2; ++i) {
267     for (int j = 0; j < 3; ++j) {
268       for (int k = 0; k < 5; ++k) {
269         for (int l = 0; l < 7; ++l) {
270           for (int m = 0; m < 11; ++m) {
271             if (k != 3) {
272               VERIFY_IS_EQUAL(tensor(i,j,k,l,m), input1(i,j,k,l,m));
273             } else {
274               VERIFY_IS_EQUAL(tensor(i,j,k,l,m), input4(i,j,l,m));
275             }
276           }
277         }
278       }
279     }
280   }
281 
282   Tensor<float, 4, DataLayout> input5(2,3,5,11);
283   input5.setRandom();
284   tensor = input1;
285   tensor.template chip<3>(4) = input5;
286   for (int i = 0; i < 2; ++i) {
287     for (int j = 0; j < 3; ++j) {
288       for (int k = 0; k < 5; ++k) {
289         for (int l = 0; l < 7; ++l) {
290           for (int m = 0; m < 11; ++m) {
291             if (l != 4) {
292               VERIFY_IS_EQUAL(tensor(i,j,k,l,m), input1(i,j,k,l,m));
293             } else {
294               VERIFY_IS_EQUAL(tensor(i,j,k,l,m), input5(i,j,k,m));
295             }
296           }
297         }
298       }
299     }
300   }
301 
302   Tensor<float, 4, DataLayout> input6(2,3,5,7);
303   input6.setRandom();
304   tensor = input1;
305   tensor.template chip<4>(5) = input6;
306   for (int i = 0; i < 2; ++i) {
307     for (int j = 0; j < 3; ++j) {
308       for (int k = 0; k < 5; ++k) {
309         for (int l = 0; l < 7; ++l) {
310           for (int m = 0; m < 11; ++m) {
311             if (m != 5) {
312               VERIFY_IS_EQUAL(tensor(i,j,k,l,m), input1(i,j,k,l,m));
313             } else {
314               VERIFY_IS_EQUAL(tensor(i,j,k,l,m), input6(i,j,k,l));
315             }
316           }
317         }
318       }
319     }
320   }
321 
322   Tensor<float, 5, DataLayout> input7(2,3,5,7,11);
323   input7.setRandom();
324   tensor = input1;
325   tensor.chip(0, 0) = input7.chip(0, 0);
326   for (int i = 0; i < 2; ++i) {
327     for (int j = 0; j < 3; ++j) {
328       for (int k = 0; k < 5; ++k) {
329         for (int l = 0; l < 7; ++l) {
330           for (int m = 0; m < 11; ++m) {
331             if (i != 0) {
332               VERIFY_IS_EQUAL(tensor(i,j,k,l,m), input1(i,j,k,l,m));
333             } else {
334               VERIFY_IS_EQUAL(tensor(i,j,k,l,m), input7(i,j,k,l,m));
335             }
336           }
337         }
338       }
339     }
340   }
341 }
342 
test_chip_raw_data_col_major()343 static void test_chip_raw_data_col_major()
344 {
345   Tensor<float, 5, ColMajor> tensor(2,3,5,7,11);
346   tensor.setRandom();
347 
348   typedef TensorEvaluator<decltype(tensor.chip<4>(3)), DefaultDevice> Evaluator4;
349   auto chip = Evaluator4(tensor.chip<4>(3), DefaultDevice());
350   for (int i = 0; i < 2; ++i) {
351     for (int j = 0; j < 3; ++j) {
352       for (int k = 0; k < 5; ++k) {
353         for (int l = 0; l < 7; ++l) {
354           int chip_index = i + 2 * (j + 3 * (k + 5 * l));
355           VERIFY_IS_EQUAL(chip.data()[chip_index], tensor(i,j,k,l,3));
356         }
357       }
358     }
359   }
360 
361   typedef TensorEvaluator<decltype(tensor.chip<0>(0)), DefaultDevice> Evaluator0;
362   auto chip0 = Evaluator0(tensor.chip<0>(0), DefaultDevice());
363   VERIFY_IS_EQUAL(chip0.data(), static_cast<float*>(0));
364 
365   typedef TensorEvaluator<decltype(tensor.chip<1>(0)), DefaultDevice> Evaluator1;
366   auto chip1 = Evaluator1(tensor.chip<1>(0), DefaultDevice());
367   VERIFY_IS_EQUAL(chip1.data(), static_cast<float*>(0));
368 
369   typedef TensorEvaluator<decltype(tensor.chip<2>(0)), DefaultDevice> Evaluator2;
370   auto chip2 = Evaluator2(tensor.chip<2>(0), DefaultDevice());
371   VERIFY_IS_EQUAL(chip2.data(), static_cast<float*>(0));
372 
373   typedef TensorEvaluator<decltype(tensor.chip<3>(0)), DefaultDevice> Evaluator3;
374   auto chip3 = Evaluator3(tensor.chip<3>(0), DefaultDevice());
375   VERIFY_IS_EQUAL(chip3.data(), static_cast<float*>(0));
376 }
377 
test_chip_raw_data_row_major()378 static void test_chip_raw_data_row_major()
379 {
380   Tensor<float, 5, RowMajor> tensor(11,7,5,3,2);
381   tensor.setRandom();
382 
383   typedef TensorEvaluator<decltype(tensor.chip<0>(3)), DefaultDevice> Evaluator0;
384   auto chip = Evaluator0(tensor.chip<0>(3), DefaultDevice());
385   for (int i = 0; i < 7; ++i) {
386     for (int j = 0; j < 5; ++j) {
387       for (int k = 0; k < 3; ++k) {
388         for (int l = 0; l < 2; ++l) {
389           int chip_index = l + 2 * (k + 3 * (j + 5 * i));
390           VERIFY_IS_EQUAL(chip.data()[chip_index], tensor(3,i,j,k,l));
391         }
392       }
393     }
394   }
395 
396   typedef TensorEvaluator<decltype(tensor.chip<1>(0)), DefaultDevice> Evaluator1;
397   auto chip1 = Evaluator1(tensor.chip<1>(0), DefaultDevice());
398   VERIFY_IS_EQUAL(chip1.data(), static_cast<float*>(0));
399 
400   typedef TensorEvaluator<decltype(tensor.chip<2>(0)), DefaultDevice> Evaluator2;
401   auto chip2 = Evaluator2(tensor.chip<2>(0), DefaultDevice());
402   VERIFY_IS_EQUAL(chip2.data(), static_cast<float*>(0));
403 
404   typedef TensorEvaluator<decltype(tensor.chip<3>(0)), DefaultDevice> Evaluator3;
405   auto chip3 = Evaluator3(tensor.chip<3>(0), DefaultDevice());
406   VERIFY_IS_EQUAL(chip3.data(), static_cast<float*>(0));
407 
408   typedef TensorEvaluator<decltype(tensor.chip<4>(0)), DefaultDevice> Evaluator4;
409   auto chip4 = Evaluator4(tensor.chip<4>(0), DefaultDevice());
410   VERIFY_IS_EQUAL(chip4.data(), static_cast<float*>(0));
411 }
412 
test_cxx11_tensor_chipping()413 void test_cxx11_tensor_chipping()
414 {
415   CALL_SUBTEST(test_simple_chip<ColMajor>());
416   CALL_SUBTEST(test_simple_chip<RowMajor>());
417   CALL_SUBTEST(test_dynamic_chip<ColMajor>());
418   CALL_SUBTEST(test_dynamic_chip<RowMajor>());
419   CALL_SUBTEST(test_chip_in_expr<ColMajor>());
420   CALL_SUBTEST(test_chip_in_expr<RowMajor>());
421   CALL_SUBTEST(test_chip_as_lvalue<ColMajor>());
422   CALL_SUBTEST(test_chip_as_lvalue<RowMajor>());
423   CALL_SUBTEST(test_chip_raw_data_col_major());
424   CALL_SUBTEST(test_chip_raw_data_row_major());
425 }
426