1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10 #include "main.h"
11
12 #include <Eigen/CXX11/Tensor>
13
14 using Eigen::Tensor;
15
16 template<int DataLayout>
test_dimension_failures()17 static void test_dimension_failures()
18 {
19 Tensor<int, 3, DataLayout> left(2, 3, 1);
20 Tensor<int, 3, DataLayout> right(3, 3, 1);
21 left.setRandom();
22 right.setRandom();
23
24 // Okay; other dimensions are equal.
25 Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0);
26
27 // Dimension mismatches.
28 VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 1));
29 VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 2));
30
31 // Axis > NumDims or < 0.
32 VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 3));
33 VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, -1));
34 }
35
36 template<int DataLayout>
test_static_dimension_failure()37 static void test_static_dimension_failure()
38 {
39 Tensor<int, 2, DataLayout> left(2, 3);
40 Tensor<int, 3, DataLayout> right(2, 3, 1);
41
42 #ifdef CXX11_TENSOR_CONCATENATION_STATIC_DIMENSION_FAILURE
43 // Technically compatible, but we static assert that the inputs have same
44 // NumDims.
45 Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0);
46 #endif
47
48 // This can be worked around in this case.
49 Tensor<int, 3, DataLayout> concatenation = left
50 .reshape(Tensor<int, 3>::Dimensions(2, 3, 1))
51 .concatenate(right, 0);
52 Tensor<int, 2, DataLayout> alternative = left
53 .concatenate(right.reshape(Tensor<int, 2>::Dimensions{{{2, 3}}}), 0);
54 }
55
56 template<int DataLayout>
test_simple_concatenation()57 static void test_simple_concatenation()
58 {
59 Tensor<int, 3, DataLayout> left(2, 3, 1);
60 Tensor<int, 3, DataLayout> right(2, 3, 1);
61 left.setRandom();
62 right.setRandom();
63
64 Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0);
65 VERIFY_IS_EQUAL(concatenation.dimension(0), 4);
66 VERIFY_IS_EQUAL(concatenation.dimension(1), 3);
67 VERIFY_IS_EQUAL(concatenation.dimension(2), 1);
68 for (int j = 0; j < 3; ++j) {
69 for (int i = 0; i < 2; ++i) {
70 VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0));
71 }
72 for (int i = 2; i < 4; ++i) {
73 VERIFY_IS_EQUAL(concatenation(i, j, 0), right(i - 2, j, 0));
74 }
75 }
76
77 concatenation = left.concatenate(right, 1);
78 VERIFY_IS_EQUAL(concatenation.dimension(0), 2);
79 VERIFY_IS_EQUAL(concatenation.dimension(1), 6);
80 VERIFY_IS_EQUAL(concatenation.dimension(2), 1);
81 for (int i = 0; i < 2; ++i) {
82 for (int j = 0; j < 3; ++j) {
83 VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0));
84 }
85 for (int j = 3; j < 6; ++j) {
86 VERIFY_IS_EQUAL(concatenation(i, j, 0), right(i, j - 3, 0));
87 }
88 }
89
90 concatenation = left.concatenate(right, 2);
91 VERIFY_IS_EQUAL(concatenation.dimension(0), 2);
92 VERIFY_IS_EQUAL(concatenation.dimension(1), 3);
93 VERIFY_IS_EQUAL(concatenation.dimension(2), 2);
94 for (int i = 0; i < 2; ++i) {
95 for (int j = 0; j < 3; ++j) {
96 VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0));
97 VERIFY_IS_EQUAL(concatenation(i, j, 1), right(i, j, 0));
98 }
99 }
100 }
101
102
103 // TODO(phli): Add test once we have a real vectorized implementation.
104 // static void test_vectorized_concatenation() {}
105
test_concatenation_as_lvalue()106 static void test_concatenation_as_lvalue()
107 {
108 Tensor<int, 2> t1(2, 3);
109 Tensor<int, 2> t2(2, 3);
110 t1.setRandom();
111 t2.setRandom();
112
113 Tensor<int, 2> result(4, 3);
114 result.setRandom();
115 t1.concatenate(t2, 0) = result;
116
117 for (int i = 0; i < 2; ++i) {
118 for (int j = 0; j < 3; ++j) {
119 VERIFY_IS_EQUAL(t1(i, j), result(i, j));
120 VERIFY_IS_EQUAL(t2(i, j), result(i+2, j));
121 }
122 }
123 }
124
125
test_cxx11_tensor_concatenation()126 void test_cxx11_tensor_concatenation()
127 {
128 CALL_SUBTEST(test_dimension_failures<ColMajor>());
129 CALL_SUBTEST(test_dimension_failures<RowMajor>());
130 CALL_SUBTEST(test_static_dimension_failure<ColMajor>());
131 CALL_SUBTEST(test_static_dimension_failure<RowMajor>());
132 CALL_SUBTEST(test_simple_concatenation<ColMajor>());
133 CALL_SUBTEST(test_simple_concatenation<RowMajor>());
134 // CALL_SUBTEST(test_vectorized_concatenation());
135 CALL_SUBTEST(test_concatenation_as_lvalue());
136
137 }
138