// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#include "main.h"

#include <Eigen/CXX11/Tensor>

using Eigen::Tensor;
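
// Broadcasting every dimension by a factor of 1 should reproduce the input
// exactly; broadcasting by (2,3,1,4) should tile it, so each output element
// equals the source element at the same indices modulo the original sizes.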
template <int DataLayout>
static void test_simple_broadcasting()
{
  Tensor<float, 4, DataLayout> tensor(2,3,5,7);
  tensor.setRandom();
  array<ptrdiff_t, 4> broadcasts;
  broadcasts[0] = 1;
  broadcasts[1] = 1;
  broadcasts[2] = 1;
  broadcasts[3] = 1;

  Tensor<float, 4, DataLayout> no_broadcast;
  no_broadcast = tensor.broadcast(broadcasts);

  VERIFY_IS_EQUAL(no_broadcast.dimension(0), 2);
  VERIFY_IS_EQUAL(no_broadcast.dimension(1), 3);
  VERIFY_IS_EQUAL(no_broadcast.dimension(2), 5);
  VERIFY_IS_EQUAL(no_broadcast.dimension(3), 7);

  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 3; ++j) {
      for (int k = 0; k < 5; ++k) {
        for (int l = 0; l < 7; ++l) {
          VERIFY_IS_EQUAL(tensor(i,j,k,l), no_broadcast(i,j,k,l));
        }
      }
    }
  }

  broadcasts[0] = 2;
  broadcasts[1] = 3;
  broadcasts[2] = 1;
  broadcasts[3] = 4;
  Tensor<float, 4, DataLayout> broadcast;
  broadcast = tensor.broadcast(broadcasts);

  VERIFY_IS_EQUAL(broadcast.dimension(0), 4);
  VERIFY_IS_EQUAL(broadcast.dimension(1), 9);
  VERIFY_IS_EQUAL(broadcast.dimension(2), 5);
  VERIFY_IS_EQUAL(broadcast.dimension(3), 28);

  for (int i = 0; i < 4; ++i) {
    for (int j = 0; j < 9; ++j) {
      for (int k = 0; k < 5; ++k) {
        for (int l = 0; l < 28; ++l) {
          VERIFY_IS_EQUAL(tensor(i%2,j%3,k%5,l%7), broadcast(i,j,k,l));
        }
      }
    }
  }
}
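
// Same broadcast factors, but applied to a first dimension of 8 and then 11,
// so both the (presumably vectorized) packet-friendly path and the odd-sized
// scalar fallback are exercised.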
template <int DataLayout>
static void test_vectorized_broadcasting()
{
  Tensor<float, 3, DataLayout> tensor(8,3,5);
  tensor.setRandom();
  array<ptrdiff_t, 3> broadcasts;
  broadcasts[0] = 2;
  broadcasts[1] = 3;
  broadcasts[2] = 4;

  Tensor<float, 3, DataLayout> broadcast;
  broadcast = tensor.broadcast(broadcasts);

  VERIFY_IS_EQUAL(broadcast.dimension(0), 16);
  VERIFY_IS_EQUAL(broadcast.dimension(1), 9);
  VERIFY_IS_EQUAL(broadcast.dimension(2), 20);

  for (int i = 0; i < 16; ++i) {
    for (int j = 0; j < 9; ++j) {
      for (int k = 0; k < 20; ++k) {
        VERIFY_IS_EQUAL(tensor(i%8,j%3,k%5), broadcast(i,j,k));
      }
    }
  }

  tensor.resize(11,3,5);
  tensor.setRandom();
  broadcast = tensor.broadcast(broadcasts);

  VERIFY_IS_EQUAL(broadcast.dimension(0), 22);
  VERIFY_IS_EQUAL(broadcast.dimension(1), 9);
  VERIFY_IS_EQUAL(broadcast.dimension(2), 20);

  for (int i = 0; i < 22; ++i) {
    for (int j = 0; j < 9; ++j) {
      for (int k = 0; k < 20; ++k) {
        VERIFY_IS_EQUAL(tensor(i%11,j%3,k%5), broadcast(i,j,k));
      }
    }
  }
}
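
// Repeats the previous checks with broadcast factors known at compile time
// (an Eigen::IndexList of type2index values) when constexpr support is
// available, falling back to a runtime Eigen::array otherwise.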
template <int DataLayout>
static void test_static_broadcasting()
{
  Tensor<float, 3, DataLayout> tensor(8,3,5);
  tensor.setRandom();

#if EIGEN_HAS_CONSTEXPR
  Eigen::IndexList<Eigen::type2index<2>, Eigen::type2index<3>, Eigen::type2index<4>> broadcasts;
#else
  Eigen::array<int, 3> broadcasts;
  broadcasts[0] = 2;
  broadcasts[1] = 3;
  broadcasts[2] = 4;
#endif

  Tensor<float, 3, DataLayout> broadcast;
  broadcast = tensor.broadcast(broadcasts);

  VERIFY_IS_EQUAL(broadcast.dimension(0), 16);
  VERIFY_IS_EQUAL(broadcast.dimension(1), 9);
  VERIFY_IS_EQUAL(broadcast.dimension(2), 20);

  for (int i = 0; i < 16; ++i) {
    for (int j = 0; j < 9; ++j) {
      for (int k = 0; k < 20; ++k) {
        VERIFY_IS_EQUAL(tensor(i%8,j%3,k%5), broadcast(i,j,k));
      }
    }
  }

  tensor.resize(11,3,5);
  tensor.setRandom();
  broadcast = tensor.broadcast(broadcasts);

  VERIFY_IS_EQUAL(broadcast.dimension(0), 22);
  VERIFY_IS_EQUAL(broadcast.dimension(1), 9);
  VERIFY_IS_EQUAL(broadcast.dimension(2), 20);

  for (int i = 0; i < 22; ++i) {
    for (int j = 0; j < 9; ++j) {
      for (int k = 0; k < 20; ++k) {
        VERIFY_IS_EQUAL(tensor(i%11,j%3,k%5), broadcast(i,j,k));
      }
    }
  }
}
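
// Broadcasting a TensorFixedSize. The body is disabled (#if 0) until the
// missing operator[] mentioned in the comment below is added.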
template <int DataLayout>
static void test_fixed_size_broadcasting()
{
  // Need to add a [] operator to the Size class for this to work
#if 0
  Tensor<float, 1, DataLayout> t1(10);
  t1.setRandom();
  TensorFixedSize<float, Sizes<1>, DataLayout> t2;
  t2 = t2.constant(20.0f);

  Tensor<float, 1, DataLayout> t3 = t1 + t2.broadcast(Eigen::array<int, 1>{{10}});
  for (int i = 0; i < 10; ++i) {
    VERIFY_IS_APPROX(t3(i), t1(i) + t2(0));
  }

  TensorMap<TensorFixedSize<float, Sizes<1>, DataLayout> > t4(t2.data(), {{1}});
  Tensor<float, 1, DataLayout> t5 = t1 + t4.broadcast(Eigen::array<int, 1>{{10}});
  for (int i = 0; i < 10; ++i) {
    VERIFY_IS_APPROX(t5(i), t1(i) + t2(0));
  }
#endif
}
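
// Entry point: run each variant for both column-major and row-major layouts.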
void test_cxx11_tensor_broadcasting()
{
  CALL_SUBTEST(test_simple_broadcasting<ColMajor>());
  CALL_SUBTEST(test_simple_broadcasting<RowMajor>());
  CALL_SUBTEST(test_vectorized_broadcasting<ColMajor>());
  CALL_SUBTEST(test_vectorized_broadcasting<RowMajor>());
  CALL_SUBTEST(test_static_broadcasting<ColMajor>());
  CALL_SUBTEST(test_static_broadcasting<RowMajor>());
  CALL_SUBTEST(test_fixed_size_broadcasting<ColMajor>());
  CALL_SUBTEST(test_fixed_size_broadcasting<RowMajor>());
}