• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #include "main.h"
11 
12 #include <Eigen/CXX11/Tensor>
13 
14 using Eigen::Tensor;
15 
16 template<int DataLayout>
test_dimension_failures()17 static void test_dimension_failures()
18 {
19   Tensor<int, 3, DataLayout> left(2, 3, 1);
20   Tensor<int, 3, DataLayout> right(3, 3, 1);
21   left.setRandom();
22   right.setRandom();
23 
24   // Okay; other dimensions are equal.
25   Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0);
26 
27   // Dimension mismatches.
28   VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 1));
29   VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 2));
30 
31   // Axis > NumDims or < 0.
32   VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 3));
33   VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, -1));
34 }
35 
36 template<int DataLayout>
test_static_dimension_failure()37 static void test_static_dimension_failure()
38 {
39   Tensor<int, 2, DataLayout> left(2, 3);
40   Tensor<int, 3, DataLayout> right(2, 3, 1);
41 
42 #ifdef CXX11_TENSOR_CONCATENATION_STATIC_DIMENSION_FAILURE
43   // Technically compatible, but we static assert that the inputs have same
44   // NumDims.
45   Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0);
46 #endif
47 
48   // This can be worked around in this case.
49   Tensor<int, 3, DataLayout> concatenation = left
50       .reshape(Tensor<int, 3>::Dimensions(2, 3, 1))
51       .concatenate(right, 0);
52   Tensor<int, 2, DataLayout> alternative = left
53       .concatenate(right.reshape(Tensor<int, 2>::Dimensions{{{2, 3}}}), 0);
54 }
55 
56 template<int DataLayout>
test_simple_concatenation()57 static void test_simple_concatenation()
58 {
59   Tensor<int, 3, DataLayout> left(2, 3, 1);
60   Tensor<int, 3, DataLayout> right(2, 3, 1);
61   left.setRandom();
62   right.setRandom();
63 
64   Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0);
65   VERIFY_IS_EQUAL(concatenation.dimension(0), 4);
66   VERIFY_IS_EQUAL(concatenation.dimension(1), 3);
67   VERIFY_IS_EQUAL(concatenation.dimension(2), 1);
68   for (int j = 0; j < 3; ++j) {
69     for (int i = 0; i < 2; ++i) {
70       VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0));
71     }
72     for (int i = 2; i < 4; ++i) {
73       VERIFY_IS_EQUAL(concatenation(i, j, 0), right(i - 2, j, 0));
74     }
75   }
76 
77   concatenation = left.concatenate(right, 1);
78   VERIFY_IS_EQUAL(concatenation.dimension(0), 2);
79   VERIFY_IS_EQUAL(concatenation.dimension(1), 6);
80   VERIFY_IS_EQUAL(concatenation.dimension(2), 1);
81   for (int i = 0; i < 2; ++i) {
82     for (int j = 0; j < 3; ++j) {
83       VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0));
84     }
85     for (int j = 3; j < 6; ++j) {
86       VERIFY_IS_EQUAL(concatenation(i, j, 0), right(i, j - 3, 0));
87     }
88   }
89 
90   concatenation = left.concatenate(right, 2);
91   VERIFY_IS_EQUAL(concatenation.dimension(0), 2);
92   VERIFY_IS_EQUAL(concatenation.dimension(1), 3);
93   VERIFY_IS_EQUAL(concatenation.dimension(2), 2);
94   for (int i = 0; i < 2; ++i) {
95     for (int j = 0; j < 3; ++j) {
96       VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0));
97       VERIFY_IS_EQUAL(concatenation(i, j, 1), right(i, j, 0));
98     }
99   }
100 }
101 
102 
103 // TODO(phli): Add test once we have a real vectorized implementation.
104 // static void test_vectorized_concatenation() {}
105 
test_concatenation_as_lvalue()106 static void test_concatenation_as_lvalue()
107 {
108   Tensor<int, 2> t1(2, 3);
109   Tensor<int, 2> t2(2, 3);
110   t1.setRandom();
111   t2.setRandom();
112 
113   Tensor<int, 2> result(4, 3);
114   result.setRandom();
115   t1.concatenate(t2, 0) = result;
116 
117   for (int i = 0; i < 2; ++i) {
118     for (int j = 0; j < 3; ++j) {
119       VERIFY_IS_EQUAL(t1(i, j), result(i, j));
120       VERIFY_IS_EQUAL(t2(i, j), result(i+2, j));
121     }
122   }
123 }
124 
125 
test_cxx11_tensor_concatenation()126 void test_cxx11_tensor_concatenation()
127 {
128    CALL_SUBTEST(test_dimension_failures<ColMajor>());
129    CALL_SUBTEST(test_dimension_failures<RowMajor>());
130    CALL_SUBTEST(test_static_dimension_failure<ColMajor>());
131    CALL_SUBTEST(test_static_dimension_failure<RowMajor>());
132    CALL_SUBTEST(test_simple_concatenation<ColMajor>());
133    CALL_SUBTEST(test_simple_concatenation<RowMajor>());
134    // CALL_SUBTEST(test_vectorized_concatenation());
135    CALL_SUBTEST(test_concatenation_as_lvalue());
136 
137 }
138