• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2015 Google Inc. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <unistd.h>
16 #ifdef __APPLE__
17 #include <sys/time.h>
18 #endif
19 
20 #include <cstdint>
21 #include <cstdlib>
22 #include <ctime>
23 #include <iostream>
24 #include <map>
25 #include <vector>
26 
27 #include "../public/gemmlowp.h"
28 #include "../meta/multi_thread_gemm.h"
29 #include "test.h"
30 
31 #if defined(__arm__) && !defined(GEMMLOWP_NEON)
32 #warning "Building without NEON support on ARM, check your compiler setup!"
33 #endif
34 
// Returns the current wall-clock time in seconds, with sub-second precision,
// as a double (gettimeofday on Apple, clock_gettime(CLOCK_REALTIME) elsewhere).
// NOTE(review): both sources are wall clocks that may jump (e.g. NTP
// adjustments); if this is used to time intervals, a monotonic clock would be
// safer. Also note this overloads the C library's ::time(time_t*).
double time() {
#ifdef __APPLE__
  timeval now;
  gettimeofday(&now, nullptr);
  const double fraction = now.tv_usec * 1e-6;
#else
  timespec now;
  clock_gettime(CLOCK_REALTIME, &now);
  const double fraction = now.tv_nsec * 1e-9;
#endif
  return now.tv_sec + fraction;
}
46 
// Fills `data`, a row-major rows x cols byte matrix, with a deterministic
// pseudo-random sequence. Each element receives the low byte of the current
// state; the state starts at `seed` and advances as
//   state = (state * seed_2 + seed) % 256.
// Nothing is written when rows <= 0 or cols <= 0.
void prepare_test_data(std::uint8_t* data, std::int32_t rows, std::int32_t cols,
                       std::int32_t seed, std::int32_t seed_2) {
  std::int32_t state = seed;
  std::uint8_t* out = data;
  for (std::int32_t row = 0; row < rows; ++row) {
    for (std::int32_t col = 0; col < cols; ++col) {
      *out++ = static_cast<std::uint8_t>(state);
      state = (state * seed_2 + seed) % 256;
    }
  }
}
57 
// Output controls for check_result / check_result_f. With quiet == true (the
// default) only the per-call summary is printed; with quiet == false every
// element is reported, as '.'/'x' or, when verbose is also true, as
// "expected==actual" / "expected!=actual".
bool verbose{false};
bool quiet{true};
60 
check_result(std::uint8_t * left,std::uint8_t * right,std::uint8_t * result,std::int32_t rows,std::int32_t cols,std::int32_t depth,std::int32_t lhs_offset,std::int32_t rhs_offset,std::int32_t sum_offset,std::int32_t mul_offset,std::int32_t shift)61 void check_result(std::uint8_t* left, std::uint8_t* right, std::uint8_t* result,
62                   std::int32_t rows, std::int32_t cols, std::int32_t depth,
63                   std::int32_t lhs_offset, std::int32_t rhs_offset,
64                   std::int32_t sum_offset, std::int32_t mul_offset,
65                   std::int32_t shift) {
66   std::int32_t rounding = (1 << (shift - 1));
67   std::int32_t wrong = 0;
68   for (int i = 0; i < rows; ++i) {
69     for (int j = 0; j < cols; ++j) {
70       std::int32_t expected = 0;
71       for (int k = 0; k < depth; ++k) {
72         expected +=
73             (static_cast<std::int32_t>(left[depth * i + k]) + lhs_offset) *
74             (static_cast<std::int32_t>(right[depth * j + k]) + rhs_offset);
75       }
76       expected += sum_offset;
77       expected *= mul_offset;
78       expected += rounding;
79       expected = (expected >> shift);
80       if (expected < 0) {
81         expected = 0;
82       } else if (expected > 255) {
83         expected = 255;
84       }
85       expected = static_cast<std::int32_t>(static_cast<std::uint8_t>(expected));
86       std::int32_t actual = static_cast<std::int32_t>(result[i * cols + j]);
87       if (actual == expected) {
88         if (!quiet) {
89           if (verbose) {
90             std::cout << expected << "==" << actual << " ";
91           } else {
92             std::cout << ".";
93           }
94         }
95       } else {
96         if (!quiet) {
97           if (verbose) {
98             std::cout << expected << "!=" << actual << " ";
99           } else {
100             std::cout << "x";
101           }
102         }
103         wrong++;
104       }
105     }
106     if (!quiet) {
107       std::cout << std::endl;
108     }
109   }
110   if (wrong > 0) {
111     std::cout << "Wrong: " << wrong << std::endl;
112   } else {
113     std::cout << "." << std::flush;
114   }
115 }
116 
check_result_f(std::uint8_t * left,std::uint8_t * right,float * result,std::int32_t rows,std::int32_t cols,std::int32_t depth,std::int32_t lhs_offset,std::int32_t rhs_offset,float result_offset)117 void check_result_f(std::uint8_t* left, std::uint8_t* right, float* result,
118                     std::int32_t rows, std::int32_t cols, std::int32_t depth,
119                     std::int32_t lhs_offset, std::int32_t rhs_offset,
120                     float result_offset) {
121   std::int32_t wrong = 0;
122   for (int i = 0; i < rows; ++i) {
123     for (int j = 0; j < cols; ++j) {
124       std::int32_t expected = 0;
125       for (int k = 0; k < depth; ++k) {
126         expected +=
127             (static_cast<std::int32_t>(left[depth * i + k]) + lhs_offset) *
128             (static_cast<std::int32_t>(right[depth * j + k]) + rhs_offset);
129       }
130       float expected_float = static_cast<float>(expected) * result_offset;
131       float actual_float = result[i * cols + j];
132       if (actual_float == expected_float) {
133         if (!quiet) {
134           if (verbose) {
135             std::cout << expected_float << "==" << actual_float << " ";
136           } else {
137             std::cout << ".";
138           }
139         }
140       } else {
141         if (!quiet) {
142           if (verbose) {
143             std::cout << expected_float << "!=" << actual_float << " ";
144           } else {
145             std::cout << "x";
146           }
147         }
148         wrong++;
149       }
150     }
151     if (!quiet) {
152       std::cout << std::endl;
153     }
154   }
155   if (wrong > 0) {
156     std::cout << "Wrong: " << wrong << std::endl;
157   } else {
158     std::cout << "." << std::flush;
159   }
160 }
161 
// Sets the first rows * cols elements of `result` to T(0). Nothing is written
// when rows * cols <= 0.
template <typename T>
void clear(T* result, std::int32_t rows, std::int32_t cols) {
  T* cursor = result;
  for (int remaining = rows * cols; remaining > 0; --remaining) {
    *cursor++ = static_cast<T>(0);
  }
}
168 
test(std::uint8_t * scratch,std::uint8_t * lhs,std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,std::uint8_t * result,gemmlowp::WorkersPool * pool,std::int32_t pool_size)169 void test(std::uint8_t* scratch, std::uint8_t* lhs, std::uint8_t* rhs,
170           std::int32_t m, std::int32_t n, std::int32_t k, std::uint8_t* result,
171           gemmlowp::WorkersPool* pool, std::int32_t pool_size) {
172   prepare_test_data(lhs, m, k, 11, 13);
173   prepare_test_data(rhs, n, k, 177, 19);
174 
175   clear(result, m, n);
176   gemmlowp::meta::multi_thread_gemm_q8(pool, pool_size, scratch, lhs, rhs, m, n,
177                                        k, -127, -127, 127 * k, 1, 7, result);
178   check_result(lhs, rhs, result, m, n, k, -127, -127, 127 * k, 1, 7);
179 }
180 
test_f(std::uint8_t * scratch,std::uint8_t * lhs,std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,float * result,gemmlowp::WorkersPool * pool,std::int32_t pool_size)181 void test_f(std::uint8_t* scratch, std::uint8_t* lhs, std::uint8_t* rhs,
182             std::int32_t m, std::int32_t n, std::int32_t k, float* result,
183             gemmlowp::WorkersPool* pool, std::int32_t pool_size) {
184   prepare_test_data(lhs, m, k, 11, 13);
185   prepare_test_data(rhs, n, k, 177, 19);
186 
187   clear(result, m, n);
188   float scale = 1.0f / 1234567.8f;
189   gemmlowp::meta::multi_thread_gemm_f(pool, pool_size, scratch, lhs, rhs, m, n,
190                                       k, -127, -127, scale, result);
191   check_result_f(lhs, rhs, result, m, n, k, -127, -127, scale);
192 }
193 
main()194 int main() {
195   const std::int32_t min_n = 256;
196   const std::int32_t min_m = 256;
197   const std::int32_t min_k = 256;
198 
199   const std::int32_t max_n = 1024;
200   const std::int32_t max_m = 1024;
201   const std::int32_t max_k = 512;
202 
203   std::uint8_t* left = new std::uint8_t[max_m * max_k];
204   std::uint8_t* right = new std::uint8_t[max_n * max_k];
205   std::uint8_t* result = new std::uint8_t[max_m * max_n];
206   float* result_float = new float[max_m * max_n];
207   std::uint8_t* scratch = new std::uint8_t[1024 * 1024 * 64];
208 
209   gemmlowp::WorkersPool pool;
210   pool.CreateWorkers(3);
211 
212   std::cout << "Quantized 8 bit." << std::endl << std::flush;
213 
214   for (int m = min_m; m < max_m; m += 128) {
215     for (int n = min_n; n < max_n; n += 128) {
216       for (int k = min_k; k < max_k; k += 13) {
217         test(scratch, left, right, m, n, k, result, &pool, 4);
218       }
219     }
220   }
221 
222   std::cout << std::endl << "Floats." << std::endl << std::flush;
223 
224   for (int m = min_m; m < max_m; m += 128) {
225     for (int n = min_n; n < max_n; n += 128) {
226       for (int k = min_k; k < max_k; k += 13) {
227         test_f(scratch, left, right, m, n, k, result_float, &pool, 4);
228       }
229     }
230   }
231 
232   std::cout << std::endl << "Done." << std::endl << std::flush;
233 }
234