1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10 #include "main.h"
11
12 #include <Eigen/CXX11/Tensor>
13
14 using Eigen::Tensor;
15
16 template <int DataLayout>
test_simple_broadcasting()17 static void test_simple_broadcasting()
18 {
19 Tensor<float, 4, DataLayout> tensor(2,3,5,7);
20 tensor.setRandom();
21 array<ptrdiff_t, 4> broadcasts;
22 broadcasts[0] = 1;
23 broadcasts[1] = 1;
24 broadcasts[2] = 1;
25 broadcasts[3] = 1;
26
27 Tensor<float, 4, DataLayout> no_broadcast;
28 no_broadcast = tensor.broadcast(broadcasts);
29
30 VERIFY_IS_EQUAL(no_broadcast.dimension(0), 2);
31 VERIFY_IS_EQUAL(no_broadcast.dimension(1), 3);
32 VERIFY_IS_EQUAL(no_broadcast.dimension(2), 5);
33 VERIFY_IS_EQUAL(no_broadcast.dimension(3), 7);
34
35 for (int i = 0; i < 2; ++i) {
36 for (int j = 0; j < 3; ++j) {
37 for (int k = 0; k < 5; ++k) {
38 for (int l = 0; l < 7; ++l) {
39 VERIFY_IS_EQUAL(tensor(i,j,k,l), no_broadcast(i,j,k,l));
40 }
41 }
42 }
43 }
44
45 broadcasts[0] = 2;
46 broadcasts[1] = 3;
47 broadcasts[2] = 1;
48 broadcasts[3] = 4;
49 Tensor<float, 4, DataLayout> broadcast;
50 broadcast = tensor.broadcast(broadcasts);
51
52 VERIFY_IS_EQUAL(broadcast.dimension(0), 4);
53 VERIFY_IS_EQUAL(broadcast.dimension(1), 9);
54 VERIFY_IS_EQUAL(broadcast.dimension(2), 5);
55 VERIFY_IS_EQUAL(broadcast.dimension(3), 28);
56
57 for (int i = 0; i < 4; ++i) {
58 for (int j = 0; j < 9; ++j) {
59 for (int k = 0; k < 5; ++k) {
60 for (int l = 0; l < 28; ++l) {
61 VERIFY_IS_EQUAL(tensor(i%2,j%3,k%5,l%7), broadcast(i,j,k,l));
62 }
63 }
64 }
65 }
66 }
67
68
69 template <int DataLayout>
test_vectorized_broadcasting()70 static void test_vectorized_broadcasting()
71 {
72 Tensor<float, 3, DataLayout> tensor(8,3,5);
73 tensor.setRandom();
74 array<ptrdiff_t, 3> broadcasts;
75 broadcasts[0] = 2;
76 broadcasts[1] = 3;
77 broadcasts[2] = 4;
78
79 Tensor<float, 3, DataLayout> broadcast;
80 broadcast = tensor.broadcast(broadcasts);
81
82 VERIFY_IS_EQUAL(broadcast.dimension(0), 16);
83 VERIFY_IS_EQUAL(broadcast.dimension(1), 9);
84 VERIFY_IS_EQUAL(broadcast.dimension(2), 20);
85
86 for (int i = 0; i < 16; ++i) {
87 for (int j = 0; j < 9; ++j) {
88 for (int k = 0; k < 20; ++k) {
89 VERIFY_IS_EQUAL(tensor(i%8,j%3,k%5), broadcast(i,j,k));
90 }
91 }
92 }
93
94 tensor.resize(11,3,5);
95 tensor.setRandom();
96 broadcast = tensor.broadcast(broadcasts);
97
98 VERIFY_IS_EQUAL(broadcast.dimension(0), 22);
99 VERIFY_IS_EQUAL(broadcast.dimension(1), 9);
100 VERIFY_IS_EQUAL(broadcast.dimension(2), 20);
101
102 for (int i = 0; i < 22; ++i) {
103 for (int j = 0; j < 9; ++j) {
104 for (int k = 0; k < 20; ++k) {
105 VERIFY_IS_EQUAL(tensor(i%11,j%3,k%5), broadcast(i,j,k));
106 }
107 }
108 }
109 }
110
111
112 template <int DataLayout>
test_static_broadcasting()113 static void test_static_broadcasting()
114 {
115 Tensor<float, 3, DataLayout> tensor(8,3,5);
116 tensor.setRandom();
117
118 #if EIGEN_HAS_CONSTEXPR
119 Eigen::IndexList<Eigen::type2index<2>, Eigen::type2index<3>, Eigen::type2index<4>> broadcasts;
120 #else
121 Eigen::array<int, 3> broadcasts;
122 broadcasts[0] = 2;
123 broadcasts[1] = 3;
124 broadcasts[2] = 4;
125 #endif
126
127 Tensor<float, 3, DataLayout> broadcast;
128 broadcast = tensor.broadcast(broadcasts);
129
130 VERIFY_IS_EQUAL(broadcast.dimension(0), 16);
131 VERIFY_IS_EQUAL(broadcast.dimension(1), 9);
132 VERIFY_IS_EQUAL(broadcast.dimension(2), 20);
133
134 for (int i = 0; i < 16; ++i) {
135 for (int j = 0; j < 9; ++j) {
136 for (int k = 0; k < 20; ++k) {
137 VERIFY_IS_EQUAL(tensor(i%8,j%3,k%5), broadcast(i,j,k));
138 }
139 }
140 }
141
142 tensor.resize(11,3,5);
143 tensor.setRandom();
144 broadcast = tensor.broadcast(broadcasts);
145
146 VERIFY_IS_EQUAL(broadcast.dimension(0), 22);
147 VERIFY_IS_EQUAL(broadcast.dimension(1), 9);
148 VERIFY_IS_EQUAL(broadcast.dimension(2), 20);
149
150 for (int i = 0; i < 22; ++i) {
151 for (int j = 0; j < 9; ++j) {
152 for (int k = 0; k < 20; ++k) {
153 VERIFY_IS_EQUAL(tensor(i%11,j%3,k%5), broadcast(i,j,k));
154 }
155 }
156 }
157 }
158
159
template <int DataLayout>
static void test_fixed_size_broadcasting()
{
  // Intended test: broadcast a 1-element TensorFixedSize, and a TensorMap
  // over the same data, to length 10, then add it to a dynamic tensor.
  // Disabled: this needs an operator[] on the Sizes class (the fixed
  // dimensions type used by TensorFixedSize below) before it can work.
  // While the #if 0 remains, this function does nothing.
#if 0
  Tensor<float, 1, DataLayout> t1(10);
  t1.setRandom();
  TensorFixedSize<float, Sizes<1>, DataLayout> t2;
  t2 = t2.constant(20.0f);

  // Broadcast the fixed-size tensor directly.
  Tensor<float, 1, DataLayout> t3 = t1 + t2.broadcast(Eigen::array<int, 1>{{10}});
  for (int i = 0; i < 10; ++i) {
    VERIFY_IS_APPROX(t3(i), t1(i) + t2(0));
  }

  // Same check through a TensorMap viewing t2's storage.
  TensorMap<TensorFixedSize<float, Sizes<1>, DataLayout> > t4(t2.data(), {{1}});
  Tensor<float, 1, DataLayout> t5 = t1 + t4.broadcast(Eigen::array<int, 1>{{10}});
  for (int i = 0; i < 10; ++i) {
    VERIFY_IS_APPROX(t5(i), t1(i) + t2(0));
  }
#endif
}
182
183
// Runs every broadcasting test in this file for both ColMajor and RowMajor
// layouts, each wrapped in CALL_SUBTEST. The order is kept fixed because the
// tests draw random data via setRandom(). Note that
// test_fixed_size_broadcasting is currently a no-op (its body is under #if 0).
void test_cxx11_tensor_broadcasting()
{
  CALL_SUBTEST(test_simple_broadcasting<ColMajor>());
  CALL_SUBTEST(test_simple_broadcasting<RowMajor>());
  CALL_SUBTEST(test_vectorized_broadcasting<ColMajor>());
  CALL_SUBTEST(test_vectorized_broadcasting<RowMajor>());
  CALL_SUBTEST(test_static_broadcasting<ColMajor>());
  CALL_SUBTEST(test_static_broadcasting<RowMajor>());
  CALL_SUBTEST(test_fixed_size_broadcasting<ColMajor>());
  CALL_SUBTEST(test_fixed_size_broadcasting<RowMajor>());
}
195