• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #define GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS
2 
3 #include "test.h"
4 
5 #include "../internal/fixedpoint.h"
6 
7 using namespace gemmlowp;
8 
9 template <int tIntegerBits>
test_convert(FixedPoint<int32_t,tIntegerBits> x)10 void test_convert(FixedPoint<int32_t, tIntegerBits> x) {
11   typedef FixedPoint<int32_t, tIntegerBits> F;
12   F y = ToFixedPoint<int32_t, tIntegerBits>(ToDouble(x));
13   Check(y == x);
14 }
15 
16 template <int tIntegerBits_a, int tIntegerBits_b>
test_Rescale(FixedPoint<int32_t,tIntegerBits_a> a)17 void test_Rescale(FixedPoint<int32_t, tIntegerBits_a> a) {
18   FixedPoint<int32_t, tIntegerBits_b> actual = Rescale<tIntegerBits_b>(a);
19   FixedPoint<int32_t, tIntegerBits_b> expected =
20       ToFixedPoint<int32_t, tIntegerBits_b>(ToDouble(a));
21   Check(actual == expected);
22 }
23 
24 template <int tIntegerBits_a, int tIntegerBits_b>
test_Rescale(const std::vector<int32_t> & testvals_int32)25 void test_Rescale(const std::vector<int32_t>& testvals_int32) {
26   for (auto a : testvals_int32) {
27     FixedPoint<int32_t, tIntegerBits_a> aq;
28     aq.raw() = a;
29     test_Rescale<tIntegerBits_a, tIntegerBits_b>(aq);
30   }
31 }
32 
33 template <int tIntegerBits_a, int tIntegerBits_b>
test_mul(FixedPoint<int32_t,tIntegerBits_a> a,FixedPoint<int32_t,tIntegerBits_b> b)34 void test_mul(FixedPoint<int32_t, tIntegerBits_a> a,
35               FixedPoint<int32_t, tIntegerBits_b> b) {
36   static const int IntegerBits_ab = tIntegerBits_a + tIntegerBits_b;
37   FixedPoint<int32_t, IntegerBits_ab> ab;
38   ab = a * b;
39   double a_double = ToDouble(a);
40   double b_double = ToDouble(b);
41   double ab_double = a_double * b_double;
42   FixedPoint<int32_t, IntegerBits_ab> expected =
43       ToFixedPoint<int32_t, IntegerBits_ab>(ab_double);
44   int64_t diff = int64_t(ab.raw()) - int64_t(expected.raw());
45   Check(std::abs(diff) <= 1);
46 }
47 
48 template <int tIntegerBits_a, int tIntegerBits_b>
test_mul(const std::vector<int32_t> & testvals_int32)49 void test_mul(const std::vector<int32_t>& testvals_int32) {
50   for (auto a : testvals_int32) {
51     for (auto b : testvals_int32) {
52       FixedPoint<int32_t, tIntegerBits_a> aq;
53       FixedPoint<int32_t, tIntegerBits_b> bq;
54       aq.raw() = a;
55       bq.raw() = b;
56       test_mul(aq, bq);
57     }
58   }
59 }
60 
61 template <int tExponent, int tIntegerBits_a>
test_ExactMulByPot(FixedPoint<int32_t,tIntegerBits_a> a)62 void test_ExactMulByPot(FixedPoint<int32_t, tIntegerBits_a> a) {
63   double x = ToDouble(a) * std::pow(2.0, tExponent);
64   double y = ToDouble(ExactMulByPot<tExponent>(a));
65   Check(x == y);
66 }
67 
68 template <int tExponent, int tIntegerBits_a>
test_ExactMulByPot(const std::vector<int32_t> & testvals_int32)69 void test_ExactMulByPot(const std::vector<int32_t>& testvals_int32) {
70   for (auto a : testvals_int32) {
71     FixedPoint<int32_t, tIntegerBits_a> aq;
72     aq.raw() = a;
73     test_ExactMulByPot<tExponent, tIntegerBits_a>(aq);
74   }
75 }
76 
test_exp_on_interval_between_negative_one_quarter_and_0_excl(FixedPoint<int32_t,0> a)77 void test_exp_on_interval_between_negative_one_quarter_and_0_excl(
78     FixedPoint<int32_t, 0> a) {
79   double a_double = ToDouble(a);
80   double expected = std::exp(a_double);
81   double actual =
82       ToDouble(exp_on_interval_between_negative_one_quarter_and_0_excl(a));
83   double error = expected - actual;
84   Check(std::abs(error) < 3e-7);
85 }
86 
test_exp_on_interval_between_negative_one_quarter_and_0_excl(const std::vector<int32_t> & testvals_int32)87 void test_exp_on_interval_between_negative_one_quarter_and_0_excl(
88     const std::vector<int32_t>& testvals_int32) {
89   for (auto a : testvals_int32) {
90     typedef FixedPoint<int32_t, 0> F;
91     F aq = SaturatingRoundingMultiplyByPOT<-3>(F::FromRaw(a)) -
92            F::ConstantPOT<-3>();
93     test_exp_on_interval_between_negative_one_quarter_and_0_excl(aq);
94   }
95 }
96 
97 template <int tIntegerBits>
test_exp_on_negative_values(FixedPoint<int32_t,tIntegerBits> a)98 void test_exp_on_negative_values(FixedPoint<int32_t, tIntegerBits> a) {
99   double a_double = ToDouble(a);
100   double expected = std::exp(a_double);
101   double actual = ToDouble(exp_on_negative_values(a));
102   double error = expected - actual;
103   Check(std::abs(error) < 3e-7);
104 }
105 
106 template <int tIntegerBits>
test_exp_on_negative_values(const std::vector<int32_t> & testvals_int32)107 void test_exp_on_negative_values(const std::vector<int32_t>& testvals_int32) {
108   for (auto a : testvals_int32) {
109     if (a < 0) {
110       FixedPoint<int32_t, tIntegerBits> aq;
111       aq.raw() = a;
112       test_exp_on_negative_values(aq);
113     }
114   }
115 }
116 
test_one_minus_x_over_one_plus_x_for_x_in_0_1(FixedPoint<int32_t,0> a)117 void test_one_minus_x_over_one_plus_x_for_x_in_0_1(FixedPoint<int32_t, 0> a) {
118   double a_double = ToDouble(a);
119   double expected = (1 - a_double) / (1 + a_double);
120   FixedPoint<int32_t, 0> retval = one_minus_x_over_one_plus_x_for_x_in_0_1(a);
121   double actual = ToDouble(retval);
122   double error = expected - actual;
123   Check(std::abs(error) < 6e-9);
124 }
125 
test_one_minus_x_over_one_plus_x_for_x_in_0_1(const std::vector<int32_t> & testvals_int32)126 void test_one_minus_x_over_one_plus_x_for_x_in_0_1(
127     const std::vector<int32_t>& testvals_int32) {
128   for (auto a : testvals_int32) {
129     if (a > 0) {
130       FixedPoint<int32_t, 0> aq;
131       aq.raw() = a;
132       test_one_minus_x_over_one_plus_x_for_x_in_0_1(aq);
133     }
134   }
135 }
136 
137 template <int tIntegerBits>
test_tanh(FixedPoint<int32_t,tIntegerBits> a)138 void test_tanh(FixedPoint<int32_t, tIntegerBits> a) {
139   double a_double = ToDouble(a);
140   double expected = std::tanh(a_double);
141   double actual = ToDouble(tanh(a));
142   double error = expected - actual;
143   Check(std::abs(error) < 1.5e-7);
144 }
145 
146 template <int tIntegerBits>
test_tanh(const std::vector<int32_t> & testvals_int32)147 void test_tanh(const std::vector<int32_t>& testvals_int32) {
148   for (auto a : testvals_int32) {
149     FixedPoint<int32_t, tIntegerBits> aq;
150     aq.raw() = a;
151     test_tanh(aq);
152   }
153 }
154 
#ifdef GEMMLOWP_NEON
// Checks that tanh on the NEON int32x4_t path produces bit-identical raw
// results to the scalar int32_t path, for FixedPoint values with 4 integer
// bits. Only the first n - (n % 4) test values are used, so that every
// 4-lane load/store stays inside both buffers.
void test_int32x4(const std::vector<int32_t>& testvals_int32) {
  const size_t n = testvals_int32.size();
  const size_t n4 = n - (n % 4);
  std::vector<int32_t> results_int32(n4);
  std::vector<int32_t> results_int32x4(n4);

  // Scalar reference results, one element at a time.
  for (size_t i = 0; i < n4; i++) {
    results_int32[i] =
        tanh(FixedPoint<int32_t, 4>::FromRaw(testvals_int32[i])).raw();
  }

  // Vector results. Each vld1q_s32/vst1q_s32 touches four consecutive
  // elements, so the index must advance by 4. Advancing by 1 made the last
  // three iterations read past the end of testvals_int32 and write past the
  // end of results_int32x4.
  for (size_t i = 0; i < n4; i += 4) {
    vst1q_s32(
        &results_int32x4[i],
        tanh(FixedPoint<int32x4_t, 4>::FromRaw(vld1q_s32(&testvals_int32[i])))
            .raw());
  }

  for (size_t i = 0; i < n4; i++) {
    Check(results_int32[i] == results_int32x4[i]);
  }
}
#endif  // GEMMLOWP_NEON
178 
main()179 int main() {
180   std::vector<int32_t> testvals_int32;
181 
182   for (int i = 0; i < 31; i++) {
183     testvals_int32.push_back((1 << i) - 2);
184     testvals_int32.push_back((1 << i) - 1);
185     testvals_int32.push_back((1 << i));
186     testvals_int32.push_back((1 << i) + 1);
187     testvals_int32.push_back((1 << i) + 2);
188     testvals_int32.push_back(-(1 << i) - 2);
189     testvals_int32.push_back(-(1 << i) - 1);
190     testvals_int32.push_back(-(1 << i));
191     testvals_int32.push_back(-(1 << i) + 1);
192     testvals_int32.push_back(-(1 << i) + 2);
193   }
194   testvals_int32.push_back(std::numeric_limits<int32_t>::min());
195   testvals_int32.push_back(std::numeric_limits<int32_t>::min() + 1);
196   testvals_int32.push_back(std::numeric_limits<int32_t>::min() + 2);
197   testvals_int32.push_back(std::numeric_limits<int32_t>::max() - 2);
198   testvals_int32.push_back(std::numeric_limits<int32_t>::max() - 1);
199   testvals_int32.push_back(std::numeric_limits<int32_t>::max());
200 
201   uint32_t random = 1;
202   for (int i = 0; i < 1000; i++) {
203     random = random * 1664525 + 1013904223;
204     testvals_int32.push_back(static_cast<int32_t>(random));
205   }
206 
207   std::sort(testvals_int32.begin(), testvals_int32.end());
208 
209   for (auto a : testvals_int32) {
210     FixedPoint<int32_t, 4> x;
211     x.raw() = a;
212     test_convert(x);
213   }
214 
215   test_mul<0, 0>(testvals_int32);
216   test_mul<0, 1>(testvals_int32);
217   test_mul<2, 0>(testvals_int32);
218   test_mul<1, 1>(testvals_int32);
219   test_mul<4, 4>(testvals_int32);
220   test_mul<3, 5>(testvals_int32);
221   test_mul<7, 2>(testvals_int32);
222   test_mul<14, 15>(testvals_int32);
223 
224   test_Rescale<0, 0>(testvals_int32);
225   test_Rescale<0, 1>(testvals_int32);
226   test_Rescale<2, 0>(testvals_int32);
227   test_Rescale<4, 4>(testvals_int32);
228   test_Rescale<4, 5>(testvals_int32);
229   test_Rescale<6, 3>(testvals_int32);
230   test_Rescale<13, 9>(testvals_int32);
231 
232   test_ExactMulByPot<0, 0>(testvals_int32);
233   test_ExactMulByPot<0, 4>(testvals_int32);
234   test_ExactMulByPot<1, 4>(testvals_int32);
235   test_ExactMulByPot<3, 2>(testvals_int32);
236   test_ExactMulByPot<-4, 5>(testvals_int32);
237   test_ExactMulByPot<-2, 6>(testvals_int32);
238 
239   test_exp_on_interval_between_negative_one_quarter_and_0_excl(testvals_int32);
240 
241   test_exp_on_negative_values<1>(testvals_int32);
242   test_exp_on_negative_values<2>(testvals_int32);
243   test_exp_on_negative_values<3>(testvals_int32);
244   test_exp_on_negative_values<4>(testvals_int32);
245   test_exp_on_negative_values<5>(testvals_int32);
246   test_exp_on_negative_values<6>(testvals_int32);
247 
248   test_one_minus_x_over_one_plus_x_for_x_in_0_1(testvals_int32);
249 
250   test_tanh<1>(testvals_int32);
251   test_tanh<2>(testvals_int32);
252   test_tanh<3>(testvals_int32);
253   test_tanh<4>(testvals_int32);
254   test_tanh<5>(testvals_int32);
255   test_tanh<6>(testvals_int32);
256 
257 #ifdef GEMMLOWP_NEON
258   test_int32x4(testvals_int32);
259 #endif  // GEMMLOWP_NEON
260 
261   std::cerr << "All tests passed." << std::endl;
262 }
263