// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10 #include "main.h"
11
12 #include <Eigen/CXX11/Tensor>
13
14 using Eigen::Tensor;
15 using Eigen::RowMajor;
16
test_simple_lvalue_ref()17 static void test_simple_lvalue_ref()
18 {
19 Tensor<int, 1> input(6);
20 input.setRandom();
21
22 TensorRef<Tensor<int, 1>> ref3(input);
23 TensorRef<Tensor<int, 1>> ref4 = input;
24
25 VERIFY_IS_EQUAL(ref3.data(), input.data());
26 VERIFY_IS_EQUAL(ref4.data(), input.data());
27
28 for (int i = 0; i < 6; ++i) {
29 VERIFY_IS_EQUAL(ref3(i), input(i));
30 VERIFY_IS_EQUAL(ref4(i), input(i));
31 }
32
33 for (int i = 0; i < 6; ++i) {
34 ref3.coeffRef(i) = i;
35 }
36 for (int i = 0; i < 6; ++i) {
37 VERIFY_IS_EQUAL(input(i), i);
38 }
39 for (int i = 0; i < 6; ++i) {
40 ref4.coeffRef(i) = -i * 2;
41 }
42 for (int i = 0; i < 6; ++i) {
43 VERIFY_IS_EQUAL(input(i), -i*2);
44 }
45 }
46
47
test_simple_rvalue_ref()48 static void test_simple_rvalue_ref()
49 {
50 Tensor<int, 1> input1(6);
51 input1.setRandom();
52 Tensor<int, 1> input2(6);
53 input2.setRandom();
54
55 TensorRef<Tensor<int, 1>> ref3(input1 + input2);
56 TensorRef<Tensor<int, 1>> ref4 = input1 + input2;
57
58 VERIFY_IS_NOT_EQUAL(ref3.data(), input1.data());
59 VERIFY_IS_NOT_EQUAL(ref4.data(), input1.data());
60 VERIFY_IS_NOT_EQUAL(ref3.data(), input2.data());
61 VERIFY_IS_NOT_EQUAL(ref4.data(), input2.data());
62
63 for (int i = 0; i < 6; ++i) {
64 VERIFY_IS_EQUAL(ref3(i), input1(i) + input2(i));
65 VERIFY_IS_EQUAL(ref4(i), input1(i) + input2(i));
66 }
67 }
68
69
test_multiple_dims()70 static void test_multiple_dims()
71 {
72 Tensor<float, 3> input(3,5,7);
73 input.setRandom();
74
75 TensorRef<Tensor<float, 3>> ref(input);
76 VERIFY_IS_EQUAL(ref.data(), input.data());
77 VERIFY_IS_EQUAL(ref.dimension(0), 3);
78 VERIFY_IS_EQUAL(ref.dimension(1), 5);
79 VERIFY_IS_EQUAL(ref.dimension(2), 7);
80
81 for (int i = 0; i < 3; ++i) {
82 for (int j = 0; j < 5; ++j) {
83 for (int k = 0; k < 7; ++k) {
84 VERIFY_IS_EQUAL(ref(i,j,k), input(i,j,k));
85 }
86 }
87 }
88 }
89
90
test_slice()91 static void test_slice()
92 {
93 Tensor<float, 5> tensor(2,3,5,7,11);
94 tensor.setRandom();
95
96 Eigen::DSizes<ptrdiff_t, 5> indices(1,2,3,4,5);
97 Eigen::DSizes<ptrdiff_t, 5> sizes(1,1,1,1,1);
98 TensorRef<Tensor<float, 5>> slice = tensor.slice(indices, sizes);
99 VERIFY_IS_EQUAL(slice(0,0,0,0,0), tensor(1,2,3,4,5));
100
101 Eigen::DSizes<ptrdiff_t, 5> indices2(1,1,3,4,5);
102 Eigen::DSizes<ptrdiff_t, 5> sizes2(1,1,2,2,3);
103 slice = tensor.slice(indices2, sizes2);
104 for (int i = 0; i < 2; ++i) {
105 for (int j = 0; j < 2; ++j) {
106 for (int k = 0; k < 3; ++k) {
107 VERIFY_IS_EQUAL(slice(0,0,i,j,k), tensor(1,1,3+i,4+j,5+k));
108 }
109 }
110 }
111
112 Eigen::DSizes<ptrdiff_t, 5> indices3(0,0,0,0,0);
113 Eigen::DSizes<ptrdiff_t, 5> sizes3(2,3,1,1,1);
114 slice = tensor.slice(indices3, sizes3);
115 VERIFY_IS_EQUAL(slice.data(), tensor.data());
116 }
117
118
test_ref_of_ref()119 static void test_ref_of_ref()
120 {
121 Tensor<float, 3> input(3,5,7);
122 input.setRandom();
123
124 TensorRef<Tensor<float, 3>> ref(input);
125 TensorRef<Tensor<float, 3>> ref_of_ref(ref);
126 TensorRef<Tensor<float, 3>> ref_of_ref2;
127 ref_of_ref2 = ref;
128
129 VERIFY_IS_EQUAL(ref_of_ref.data(), input.data());
130 VERIFY_IS_EQUAL(ref_of_ref.dimension(0), 3);
131 VERIFY_IS_EQUAL(ref_of_ref.dimension(1), 5);
132 VERIFY_IS_EQUAL(ref_of_ref.dimension(2), 7);
133
134 VERIFY_IS_EQUAL(ref_of_ref2.data(), input.data());
135 VERIFY_IS_EQUAL(ref_of_ref2.dimension(0), 3);
136 VERIFY_IS_EQUAL(ref_of_ref2.dimension(1), 5);
137 VERIFY_IS_EQUAL(ref_of_ref2.dimension(2), 7);
138
139 for (int i = 0; i < 3; ++i) {
140 for (int j = 0; j < 5; ++j) {
141 for (int k = 0; k < 7; ++k) {
142 VERIFY_IS_EQUAL(ref_of_ref(i,j,k), input(i,j,k));
143 VERIFY_IS_EQUAL(ref_of_ref2(i,j,k), input(i,j,k));
144 }
145 }
146 }
147 }
148
149
test_ref_in_expr()150 static void test_ref_in_expr()
151 {
152 Tensor<float, 3> input(3,5,7);
153 input.setRandom();
154 TensorRef<Tensor<float, 3>> input_ref(input);
155
156 Tensor<float, 3> result(3,5,7);
157 result.setRandom();
158 TensorRef<Tensor<float, 3>> result_ref(result);
159
160 Tensor<float, 3> bias(3,5,7);
161 bias.setRandom();
162
163 result_ref = input_ref + bias;
164 for (int i = 0; i < 3; ++i) {
165 for (int j = 0; j < 5; ++j) {
166 for (int k = 0; k < 7; ++k) {
167 VERIFY_IS_EQUAL(result_ref(i,j,k), input(i,j,k) + bias(i,j,k));
168 VERIFY_IS_NOT_EQUAL(result(i,j,k), input(i,j,k) + bias(i,j,k));
169 }
170 }
171 }
172
173 result = result_ref;
174 for (int i = 0; i < 3; ++i) {
175 for (int j = 0; j < 5; ++j) {
176 for (int k = 0; k < 7; ++k) {
177 VERIFY_IS_EQUAL(result(i,j,k), input(i,j,k) + bias(i,j,k));
178 }
179 }
180 }
181 }
182
183
test_coeff_ref()184 static void test_coeff_ref()
185 {
186 Tensor<float, 5> tensor(2,3,5,7,11);
187 tensor.setRandom();
188 Tensor<float, 5> original = tensor;
189
190 TensorRef<Tensor<float, 4>> slice = tensor.chip(7, 4);
191 slice.coeffRef(0, 0, 0, 0) = 1.0f;
192 slice.coeffRef(1, 0, 0, 0) += 2.0f;
193
194 VERIFY_IS_EQUAL(tensor(0,0,0,0,7), 1.0f);
195 VERIFY_IS_EQUAL(tensor(1,0,0,0,7), original(1,0,0,0,7) + 2.0f);
196 }
197
198
test_nested_ops_with_ref()199 static void test_nested_ops_with_ref()
200 {
201 Tensor<float, 4> t(2, 3, 5, 7);
202 t.setRandom();
203 TensorMap<Tensor<const float, 4> > m(t.data(), 2, 3, 5, 7);
204 array<std::pair<ptrdiff_t, ptrdiff_t>, 4> paddings;
205 paddings[0] = std::make_pair(0, 0);
206 paddings[1] = std::make_pair(2, 1);
207 paddings[2] = std::make_pair(3, 4);
208 paddings[3] = std::make_pair(0, 0);
209 DSizes<Eigen::DenseIndex, 4> shuffle_dims(0, 1, 2, 3);
210 TensorRef<Tensor<const float, 4> > ref(m.pad(paddings));
211 array<std::pair<ptrdiff_t, ptrdiff_t>, 4> trivial;
212 trivial[0] = std::make_pair(0, 0);
213 trivial[1] = std::make_pair(0, 0);
214 trivial[2] = std::make_pair(0, 0);
215 trivial[3] = std::make_pair(0, 0);
216 Tensor<float, 4> padded = ref.shuffle(shuffle_dims).pad(trivial);
217 VERIFY_IS_EQUAL(padded.dimension(0), 2+0);
218 VERIFY_IS_EQUAL(padded.dimension(1), 3+3);
219 VERIFY_IS_EQUAL(padded.dimension(2), 5+7);
220 VERIFY_IS_EQUAL(padded.dimension(3), 7+0);
221
222 for (int i = 0; i < 2; ++i) {
223 for (int j = 0; j < 6; ++j) {
224 for (int k = 0; k < 12; ++k) {
225 for (int l = 0; l < 7; ++l) {
226 if (j >= 2 && j < 5 && k >= 3 && k < 8) {
227 VERIFY_IS_EQUAL(padded(i,j,k,l), t(i,j-2,k-3,l));
228 } else {
229 VERIFY_IS_EQUAL(padded(i,j,k,l), 0.0f);
230 }
231 }
232 }
233 }
234 }
235 }
236
237
test_cxx11_tensor_ref()238 void test_cxx11_tensor_ref()
239 {
240 CALL_SUBTEST(test_simple_lvalue_ref());
241 CALL_SUBTEST(test_simple_rvalue_ref());
242 CALL_SUBTEST(test_multiple_dims());
243 CALL_SUBTEST(test_slice());
244 CALL_SUBTEST(test_ref_of_ref());
245 CALL_SUBTEST(test_ref_in_expr());
246 CALL_SUBTEST(test_coeff_ref());
247 CALL_SUBTEST(test_nested_ops_with_ref());
248 }
249