1 #define GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS
2
3 #include "test.h"
4
5 #include "../internal/fixedpoint.h"
6
7 using namespace gemmlowp;
8
9 template <int tIntegerBits>
test_convert(FixedPoint<int32_t,tIntegerBits> x)10 void test_convert(FixedPoint<int32_t, tIntegerBits> x) {
11 typedef FixedPoint<int32_t, tIntegerBits> F;
12 F y = ToFixedPoint<int32_t, tIntegerBits>(ToDouble(x));
13 Check(y == x);
14 }
15
16 template <int tIntegerBits_a, int tIntegerBits_b>
test_Rescale(FixedPoint<int32_t,tIntegerBits_a> a)17 void test_Rescale(FixedPoint<int32_t, tIntegerBits_a> a) {
18 FixedPoint<int32_t, tIntegerBits_b> actual = Rescale<tIntegerBits_b>(a);
19 FixedPoint<int32_t, tIntegerBits_b> expected =
20 ToFixedPoint<int32_t, tIntegerBits_b>(ToDouble(a));
21 Check(actual == expected);
22 }
23
24 template <int tIntegerBits_a, int tIntegerBits_b>
test_Rescale(const std::vector<int32_t> & testvals_int32)25 void test_Rescale(const std::vector<int32_t>& testvals_int32) {
26 for (auto a : testvals_int32) {
27 FixedPoint<int32_t, tIntegerBits_a> aq;
28 aq.raw() = a;
29 test_Rescale<tIntegerBits_a, tIntegerBits_b>(aq);
30 }
31 }
32
33 template <int tIntegerBits_a, int tIntegerBits_b>
test_mul(FixedPoint<int32_t,tIntegerBits_a> a,FixedPoint<int32_t,tIntegerBits_b> b)34 void test_mul(FixedPoint<int32_t, tIntegerBits_a> a,
35 FixedPoint<int32_t, tIntegerBits_b> b) {
36 static const int IntegerBits_ab = tIntegerBits_a + tIntegerBits_b;
37 FixedPoint<int32_t, IntegerBits_ab> ab;
38 ab = a * b;
39 double a_double = ToDouble(a);
40 double b_double = ToDouble(b);
41 double ab_double = a_double * b_double;
42 FixedPoint<int32_t, IntegerBits_ab> expected =
43 ToFixedPoint<int32_t, IntegerBits_ab>(ab_double);
44 int64_t diff = int64_t(ab.raw()) - int64_t(expected.raw());
45 Check(std::abs(diff) <= 1);
46 }
47
48 template <int tIntegerBits_a, int tIntegerBits_b>
test_mul(const std::vector<int32_t> & testvals_int32)49 void test_mul(const std::vector<int32_t>& testvals_int32) {
50 for (auto a : testvals_int32) {
51 for (auto b : testvals_int32) {
52 FixedPoint<int32_t, tIntegerBits_a> aq;
53 FixedPoint<int32_t, tIntegerBits_b> bq;
54 aq.raw() = a;
55 bq.raw() = b;
56 test_mul(aq, bq);
57 }
58 }
59 }
60
61 template <int tExponent, int tIntegerBits_a>
test_ExactMulByPot(FixedPoint<int32_t,tIntegerBits_a> a)62 void test_ExactMulByPot(FixedPoint<int32_t, tIntegerBits_a> a) {
63 double x = ToDouble(a) * std::pow(2.0, tExponent);
64 double y = ToDouble(ExactMulByPot<tExponent>(a));
65 Check(x == y);
66 }
67
68 template <int tExponent, int tIntegerBits_a>
test_ExactMulByPot(const std::vector<int32_t> & testvals_int32)69 void test_ExactMulByPot(const std::vector<int32_t>& testvals_int32) {
70 for (auto a : testvals_int32) {
71 FixedPoint<int32_t, tIntegerBits_a> aq;
72 aq.raw() = a;
73 test_ExactMulByPot<tExponent, tIntegerBits_a>(aq);
74 }
75 }
76
test_exp_on_interval_between_negative_one_quarter_and_0_excl(FixedPoint<int32_t,0> a)77 void test_exp_on_interval_between_negative_one_quarter_and_0_excl(
78 FixedPoint<int32_t, 0> a) {
79 double a_double = ToDouble(a);
80 double expected = std::exp(a_double);
81 double actual =
82 ToDouble(exp_on_interval_between_negative_one_quarter_and_0_excl(a));
83 double error = expected - actual;
84 Check(std::abs(error) < 3e-7);
85 }
86
test_exp_on_interval_between_negative_one_quarter_and_0_excl(const std::vector<int32_t> & testvals_int32)87 void test_exp_on_interval_between_negative_one_quarter_and_0_excl(
88 const std::vector<int32_t>& testvals_int32) {
89 for (auto a : testvals_int32) {
90 typedef FixedPoint<int32_t, 0> F;
91 F aq = SaturatingRoundingMultiplyByPOT<-3>(F::FromRaw(a)) -
92 F::ConstantPOT<-3>();
93 test_exp_on_interval_between_negative_one_quarter_and_0_excl(aq);
94 }
95 }
96
97 template <int tIntegerBits>
test_exp_on_negative_values(FixedPoint<int32_t,tIntegerBits> a)98 void test_exp_on_negative_values(FixedPoint<int32_t, tIntegerBits> a) {
99 double a_double = ToDouble(a);
100 double expected = std::exp(a_double);
101 double actual = ToDouble(exp_on_negative_values(a));
102 double error = expected - actual;
103 Check(std::abs(error) < 3e-7);
104 }
105
106 template <int tIntegerBits>
test_exp_on_negative_values(const std::vector<int32_t> & testvals_int32)107 void test_exp_on_negative_values(const std::vector<int32_t>& testvals_int32) {
108 for (auto a : testvals_int32) {
109 if (a < 0) {
110 FixedPoint<int32_t, tIntegerBits> aq;
111 aq.raw() = a;
112 test_exp_on_negative_values(aq);
113 }
114 }
115 }
116
test_one_minus_x_over_one_plus_x_for_x_in_0_1(FixedPoint<int32_t,0> a)117 void test_one_minus_x_over_one_plus_x_for_x_in_0_1(FixedPoint<int32_t, 0> a) {
118 double a_double = ToDouble(a);
119 double expected = (1 - a_double) / (1 + a_double);
120 FixedPoint<int32_t, 0> retval = one_minus_x_over_one_plus_x_for_x_in_0_1(a);
121 double actual = ToDouble(retval);
122 double error = expected - actual;
123 Check(std::abs(error) < 6e-9);
124 }
125
test_one_minus_x_over_one_plus_x_for_x_in_0_1(const std::vector<int32_t> & testvals_int32)126 void test_one_minus_x_over_one_plus_x_for_x_in_0_1(
127 const std::vector<int32_t>& testvals_int32) {
128 for (auto a : testvals_int32) {
129 if (a > 0) {
130 FixedPoint<int32_t, 0> aq;
131 aq.raw() = a;
132 test_one_minus_x_over_one_plus_x_for_x_in_0_1(aq);
133 }
134 }
135 }
136
137 template <int tIntegerBits>
test_tanh(FixedPoint<int32_t,tIntegerBits> a)138 void test_tanh(FixedPoint<int32_t, tIntegerBits> a) {
139 double a_double = ToDouble(a);
140 double expected = std::tanh(a_double);
141 double actual = ToDouble(tanh(a));
142 double error = expected - actual;
143 Check(std::abs(error) < 1.5e-7);
144 }
145
146 template <int tIntegerBits>
test_tanh(const std::vector<int32_t> & testvals_int32)147 void test_tanh(const std::vector<int32_t>& testvals_int32) {
148 for (auto a : testvals_int32) {
149 FixedPoint<int32_t, tIntegerBits> aq;
150 aq.raw() = a;
151 test_tanh(aq);
152 }
153 }
154
155 #ifdef GEMMLOWP_NEON
test_int32x4(const std::vector<int32_t> & testvals_int32)156 void test_int32x4(const std::vector<int32_t>& testvals_int32) {
157 size_t n = testvals_int32.size();
158 size_t n4 = n - (n % 4);
159 std::vector<int32_t> results_int32(n4);
160 std::vector<int32_t> results_int32x4(n4);
161
162 for (size_t i = 0; i < n4; i++) {
163 results_int32[i] =
164 tanh(FixedPoint<int32_t, 4>::FromRaw(testvals_int32[i])).raw();
165 }
166 for (size_t i = 0; i < n4; i++) {
167 vst1q_s32(
168 &results_int32x4[i],
169 tanh(FixedPoint<int32x4_t, 4>::FromRaw(vld1q_s32(&testvals_int32[i])))
170 .raw());
171 }
172
173 for (size_t i = 0; i < n4; i++) {
174 Check(results_int32[i] == results_int32x4[i]);
175 }
176 }
177 #endif // GEMMLOWP_NEON
178
main()179 int main() {
180 std::vector<int32_t> testvals_int32;
181
182 for (int i = 0; i < 31; i++) {
183 testvals_int32.push_back((1 << i) - 2);
184 testvals_int32.push_back((1 << i) - 1);
185 testvals_int32.push_back((1 << i));
186 testvals_int32.push_back((1 << i) + 1);
187 testvals_int32.push_back((1 << i) + 2);
188 testvals_int32.push_back(-(1 << i) - 2);
189 testvals_int32.push_back(-(1 << i) - 1);
190 testvals_int32.push_back(-(1 << i));
191 testvals_int32.push_back(-(1 << i) + 1);
192 testvals_int32.push_back(-(1 << i) + 2);
193 }
194 testvals_int32.push_back(std::numeric_limits<int32_t>::min());
195 testvals_int32.push_back(std::numeric_limits<int32_t>::min() + 1);
196 testvals_int32.push_back(std::numeric_limits<int32_t>::min() + 2);
197 testvals_int32.push_back(std::numeric_limits<int32_t>::max() - 2);
198 testvals_int32.push_back(std::numeric_limits<int32_t>::max() - 1);
199 testvals_int32.push_back(std::numeric_limits<int32_t>::max());
200
201 uint32_t random = 1;
202 for (int i = 0; i < 1000; i++) {
203 random = random * 1664525 + 1013904223;
204 testvals_int32.push_back(static_cast<int32_t>(random));
205 }
206
207 std::sort(testvals_int32.begin(), testvals_int32.end());
208
209 for (auto a : testvals_int32) {
210 FixedPoint<int32_t, 4> x;
211 x.raw() = a;
212 test_convert(x);
213 }
214
215 test_mul<0, 0>(testvals_int32);
216 test_mul<0, 1>(testvals_int32);
217 test_mul<2, 0>(testvals_int32);
218 test_mul<1, 1>(testvals_int32);
219 test_mul<4, 4>(testvals_int32);
220 test_mul<3, 5>(testvals_int32);
221 test_mul<7, 2>(testvals_int32);
222 test_mul<14, 15>(testvals_int32);
223
224 test_Rescale<0, 0>(testvals_int32);
225 test_Rescale<0, 1>(testvals_int32);
226 test_Rescale<2, 0>(testvals_int32);
227 test_Rescale<4, 4>(testvals_int32);
228 test_Rescale<4, 5>(testvals_int32);
229 test_Rescale<6, 3>(testvals_int32);
230 test_Rescale<13, 9>(testvals_int32);
231
232 test_ExactMulByPot<0, 0>(testvals_int32);
233 test_ExactMulByPot<0, 4>(testvals_int32);
234 test_ExactMulByPot<1, 4>(testvals_int32);
235 test_ExactMulByPot<3, 2>(testvals_int32);
236 test_ExactMulByPot<-4, 5>(testvals_int32);
237 test_ExactMulByPot<-2, 6>(testvals_int32);
238
239 test_exp_on_interval_between_negative_one_quarter_and_0_excl(testvals_int32);
240
241 test_exp_on_negative_values<1>(testvals_int32);
242 test_exp_on_negative_values<2>(testvals_int32);
243 test_exp_on_negative_values<3>(testvals_int32);
244 test_exp_on_negative_values<4>(testvals_int32);
245 test_exp_on_negative_values<5>(testvals_int32);
246 test_exp_on_negative_values<6>(testvals_int32);
247
248 test_one_minus_x_over_one_plus_x_for_x_in_0_1(testvals_int32);
249
250 test_tanh<1>(testvals_int32);
251 test_tanh<2>(testvals_int32);
252 test_tanh<3>(testvals_int32);
253 test_tanh<4>(testvals_int32);
254 test_tanh<5>(testvals_int32);
255 test_tanh<6>(testvals_int32);
256
257 #ifdef GEMMLOWP_NEON
258 test_int32x4(testvals_int32);
259 #endif // GEMMLOWP_NEON
260
261 std::cerr << "All tests passed." << std::endl;
262 }
263