// Copyright 2016 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15 #include <unistd.h>
16 #ifdef __APPLE__
17 #include <sys/time.h>
18 #endif
19
20 #include <cstdint>
21 #include <cstdlib>
22 #include <ctime>
23 #include <iomanip>
24 #include <iostream>
25 #include <map>
26 #include <memory>
27 #include <vector>
28
29 #include "multi_thread_gemm.h"
30 #include "quantized_mul_kernels.h"
31 #include "single_thread_gemm.h"
32 #include "streams.h"
33
// Quantization constants shared by the reference checks and the kernel setup.

// Added to every LHS value before multiplication in the reference checks.
#define LHS_OFFSET (-127)
// Added to every RHS value before multiplication in the reference checks.
#define RHS_OFFSET (-127)
// Multiplied by the depth and added to the accumulator (uint8 output only).
#define SUM_OFFSET (127)
// Factor applied to the accumulator before rounding (uint8 output only).
#define MUL_OFFSET (1)
// Right-shift applied to the accumulator; rounding is 1 << (SHIFT - 1).
#define SHIFT (7)
// Factor applied to the integer accumulator for the float output variants.
#define FLOAT_SCALE (0.333f)
40
41 using namespace gemmlowp::meta;
42
43 // Input, output & kernel setups.
44
45 typedef GemmParams<std::uint8_t, std::uint8_t, RowMajorWithSum, ColumnMajorWithSum,
46 QuantizedStaticPreprocessed, RowMajor>
47 ParamsColumnMajor;
48
49 typedef GemmParams<std::uint8_t, std::uint8_t, RowMajorWithSum, RowMajorWithSum,
50 QuantizedStaticPreprocessed, RowMajor>
51 ParamsRowMajor;
52
53 typedef GemmParams<std::uint8_t, float, RowMajorWithSum, ColumnMajorWithSum,
54 QuantizedStaticPreprocessedAsFloat, RowMajor>
55 ParamsColumnMajorAsFloat;
56
57 typedef GemmParams<std::uint8_t, float, RowMajorWithSum, RowMajorWithSum,
58 QuantizedStaticPreprocessedAsFloat, RowMajor>
59 ParamsRowMajorAsFloat;
60
61 typedef GemmParams<std::uint8_t, std::int32_t, RowMajorWithSum, ColumnMajorWithSum,
62 QuantizedStaticPreprocessedAsInt32, RowMajor>
63 ParamsColumnMajorAsInt32;
64
65 typedef GemmParams<std::uint8_t, std::int32_t, RowMajorWithSum, RowMajorWithSum,
66 QuantizedStaticPreprocessedAsInt32, RowMajor>
67 ParamsRowMajorAsInt32;
68
69 typedef gemmlowp::WorkersPool Pool;
70 typedef SimpleContext<gemmlowp::WorkersPool> Context;
71
72 #ifdef LHS_PACK
73 typedef GemmExecutorPackLHSCacheFriendly<> Executor;
74 #else
75 typedef GemmExecutorPackRHSCacheFriendly<> Executor;
76 #endif
77
// Testing helper functions.
79
// Fills the first rows * cols bytes of `data` with a deterministic
// pseudo-random sequence.
//
// The sequence starts at `seed`; each following value is
// (previous * seed_2 + seed) % 256, stored truncated to a byte. Nothing is
// written when rows * cols is not positive.
void prepare_test_data(std::uint8_t* data, std::int32_t rows, std::int32_t cols,
                       std::int32_t seed, std::int32_t seed_2) {
  const std::int32_t count = rows * cols;
  std::int32_t value = seed;
  for (std::int32_t i = 0; i < count; ++i) {
    data[i] = static_cast<std::uint8_t>(value);
    value = ((value * seed_2) + seed) % 256;
  }
}
88
// Zeroes the first rows * cols elements of `data`.
//
// CLEAR_TYPE must be assignable from the integer literal 0. Nothing is written
// when rows * cols is not positive.
template <typename CLEAR_TYPE>
void clear(int rows, int cols, CLEAR_TYPE* data) {
  const int count = rows * cols;
  for (int i = 0; i < count; ++i) {
    data[i] = 0;
  }
}
95
check_row_row(std::uint8_t * lhs,std::uint8_t * rhs,std::uint8_t * results,int rows,int cols,int depth)96 bool check_row_row(std::uint8_t* lhs, std::uint8_t* rhs, std::uint8_t* results, int rows,
97 int cols, int depth) {
98 int wrong = 0;
99 int rounding = (1 << (SHIFT - 1));
100 for (int i = 0; i < rows; ++i) {
101 for (int j = 0; j < cols; ++j) {
102 int expected = 0;
103 for (int k = 0; k < depth; ++k) {
104 expected += (static_cast<int>(lhs[depth * i + k]) + LHS_OFFSET) *
105 (static_cast<int>(rhs[depth * j + k]) + RHS_OFFSET);
106 }
107 expected += SUM_OFFSET * depth;
108 expected *= MUL_OFFSET;
109 expected += rounding;
110 expected = (expected >> SHIFT);
111 if (expected < 0) {
112 expected = 0;
113 } else if (expected > 255) {
114 expected = 255;
115 }
116 expected = static_cast<int>(static_cast<std::uint8_t>(expected));
117 int actual = static_cast<int>(results[i * cols + j]);
118 if (actual != expected) {
119 std::cout << "Wrong @" << i << "x" << j << " : " << actual
120 << " != " << expected << std::endl;
121 wrong++;
122 }
123 }
124 }
125 if (wrong != 0) {
126 std::cout << wrong << "/" << (rows * cols) << std::endl;
127 }
128 return wrong == 0;
129 }
130
check_row_col(std::uint8_t * lhs,std::uint8_t * rhs,std::uint8_t * results,int rows,int cols,int depth)131 bool check_row_col(std::uint8_t* lhs, std::uint8_t* rhs, std::uint8_t* results, int rows,
132 int cols, int depth) {
133 int wrong = 0;
134 int rounding = (1 << (SHIFT - 1));
135 for (int i = 0; i < rows; ++i) {
136 for (int j = 0; j < cols; ++j) {
137 int expected = 0;
138 for (int k = 0; k < depth; ++k) {
139 expected += (static_cast<int>(lhs[depth * i + k]) + LHS_OFFSET) *
140 (static_cast<int>(rhs[j + k * cols]) + RHS_OFFSET);
141 }
142 expected += SUM_OFFSET * depth;
143 expected *= MUL_OFFSET;
144 expected += rounding;
145 expected = (expected >> SHIFT);
146 if (expected < 0) {
147 expected = 0;
148 } else if (expected > 255) {
149 expected = 255;
150 }
151 expected = static_cast<int>(static_cast<std::uint8_t>(expected));
152 int actual = static_cast<int>(results[i * cols + j]);
153 if (actual != expected) {
154 wrong++;
155 }
156 }
157 }
158 return wrong == 0;
159 }
160
check_row_row_f(std::uint8_t * lhs,std::uint8_t * rhs,float * results,int rows,int cols,int depth)161 bool check_row_row_f(std::uint8_t* lhs, std::uint8_t* rhs, float* results, int rows,
162 int cols, int depth) {
163 int wrong = 0;
164 for (int i = 0; i < rows; ++i) {
165 for (int j = 0; j < cols; ++j) {
166 int expected = 0;
167 for (int k = 0; k < depth; ++k) {
168 expected += (static_cast<int>(lhs[depth * i + k]) + LHS_OFFSET) *
169 (static_cast<int>(rhs[depth * j + k]) + RHS_OFFSET);
170 }
171 float expected_float = static_cast<float>(expected) * FLOAT_SCALE;
172 float actual = results[i * cols + j];
173 if (actual != expected_float) {
174 wrong++;
175 }
176 }
177 }
178 return wrong == 0;
179 }
180
check_row_col_f(std::uint8_t * lhs,std::uint8_t * rhs,float * results,int rows,int cols,int depth)181 bool check_row_col_f(std::uint8_t* lhs, std::uint8_t* rhs, float* results, int rows,
182 int cols, int depth) {
183 int wrong = 0;
184 for (int i = 0; i < rows; ++i) {
185 for (int j = 0; j < cols; ++j) {
186 int expected = 0;
187 for (int k = 0; k < depth; ++k) {
188 expected += (static_cast<int>(lhs[depth * i + k]) + LHS_OFFSET) *
189 (static_cast<int>(rhs[j + k * cols]) + RHS_OFFSET);
190 }
191 float expected_float = static_cast<float>(expected) * FLOAT_SCALE;
192 float actual = results[i * cols + j];
193 if (actual != expected_float) {
194 wrong++;
195 }
196 }
197 }
198 return wrong == 0;
199 }
200
check_row_row_i32(std::uint8_t * lhs,std::uint8_t * rhs,std::int32_t * results,int rows,int cols,int depth)201 bool check_row_row_i32(std::uint8_t* lhs, std::uint8_t* rhs, std::int32_t* results, int rows,
202 int cols, int depth) {
203 int wrong = 0;
204 for (int i = 0; i < rows; ++i) {
205 for (int j = 0; j < cols; ++j) {
206 int expected = 0;
207 for (int k = 0; k < depth; ++k) {
208 expected += (static_cast<int>(lhs[depth * i + k]) + LHS_OFFSET) *
209 (static_cast<int>(rhs[depth * j + k]) + RHS_OFFSET);
210 }
211 int actual = results[i * cols + j];
212 if (actual != expected) {
213 wrong++;
214 }
215 }
216 }
217 return wrong == 0;
218 }
219
check_row_col_i32(std::uint8_t * lhs,std::uint8_t * rhs,std::int32_t * results,int rows,int cols,int depth)220 bool check_row_col_i32(std::uint8_t* lhs, std::uint8_t* rhs, std::int32_t* results, int rows,
221 int cols, int depth) {
222 int wrong = 0;
223 for (int i = 0; i < rows; ++i) {
224 for (int j = 0; j < cols; ++j) {
225 int expected = 0;
226 for (int k = 0; k < depth; ++k) {
227 expected += (static_cast<int>(lhs[depth * i + k]) + LHS_OFFSET) *
228 (static_cast<int>(rhs[j + k * cols]) + RHS_OFFSET);
229 }
230 int actual = results[i * cols + j];
231 if (actual != expected) {
232 wrong++;
233 }
234 }
235 }
236 return wrong == 0;
237 }
238
239 template <typename PARAMS, typename RESULT_TYPE>
setup_params(std::uint8_t * lhs,std::uint8_t * rhs,RESULT_TYPE * result,std::uint8_t * scratch,PARAMS * params)240 void setup_params(std::uint8_t* lhs, std::uint8_t* rhs, RESULT_TYPE* result,
241 std::uint8_t* scratch, PARAMS* params) {
242 params->lhs = lhs;
243 params->rhs = rhs;
244 params->result = result;
245 params->scratch = scratch;
246
247 params->left_stream.multiplicative_sum_offset = RHS_OFFSET;
248 params->left_stream.additive_sum_offset = 0;
249
250 params->right_stream.multiplicative_sum_offset = LHS_OFFSET;
251 params->right_stream.additive_sum_offset = 0;
252 }
253
setup_row_row(int m,int n,int k,ParamsRowMajor * params)254 void setup_row_row(int m, int n, int k, ParamsRowMajor* params) {
255 params->m = m;
256 params->n = n;
257 params->k = k;
258 params->left_stream.count = k;
259 params->left_stream.stride = k;
260 params->left_stream.additive_sum_offset =
261 SUM_OFFSET * k + k * LHS_OFFSET * RHS_OFFSET;
262 params->right_stream.count = k;
263 params->right_stream.stride = k;
264 params->fused_kernel.kernel.count = k;
265 params->fused_kernel.kernel.multiplicative_offset = MUL_OFFSET;
266 params->fused_kernel.kernel.rounding_offset = (1 << (SHIFT - 1));
267 params->fused_kernel.kernel.shift = -SHIFT;
268 params->fused_kernel.output_stream.stride = n;
269 }
270
setup_row_col(int m,int n,int k,ParamsColumnMajor * params)271 void setup_row_col(int m, int n, int k, ParamsColumnMajor* params) {
272 params->m = m;
273 params->n = n;
274 params->k = k;
275 params->left_stream.count = k;
276 params->left_stream.stride = k;
277 params->left_stream.additive_sum_offset =
278 SUM_OFFSET * k + k * LHS_OFFSET * RHS_OFFSET;
279 params->right_stream.count = k;
280 params->right_stream.stride = n;
281 params->fused_kernel.kernel.count = k;
282 params->fused_kernel.kernel.multiplicative_offset = MUL_OFFSET;
283 params->fused_kernel.kernel.rounding_offset = (1 << (SHIFT - 1));
284 params->fused_kernel.kernel.shift = -SHIFT;
285 params->fused_kernel.output_stream.stride = n;
286 }
287
setup_row_row_f(int m,int n,int k,ParamsRowMajorAsFloat * params)288 void setup_row_row_f(int m, int n, int k, ParamsRowMajorAsFloat* params) {
289 params->m = m;
290 params->n = n;
291 params->k = k;
292 params->left_stream.count = k;
293 params->left_stream.stride = k;
294 params->left_stream.additive_sum_offset = k * LHS_OFFSET * RHS_OFFSET;
295 params->right_stream.count = k;
296 params->right_stream.stride = k;
297 params->fused_kernel.kernel.count = k;
298 params->fused_kernel.kernel.scale = FLOAT_SCALE;
299 params->fused_kernel.output_stream.stride = n * sizeof(float);
300 }
301
setup_row_col_f(int m,int n,int k,ParamsColumnMajorAsFloat * params)302 void setup_row_col_f(int m, int n, int k, ParamsColumnMajorAsFloat* params) {
303 params->m = m;
304 params->n = n;
305 params->k = k;
306 params->left_stream.count = k;
307 params->left_stream.stride = k;
308 params->left_stream.additive_sum_offset = k * LHS_OFFSET * RHS_OFFSET;
309 params->right_stream.count = k;
310 params->right_stream.stride = n;
311 params->fused_kernel.kernel.count = k;
312 params->fused_kernel.kernel.scale = FLOAT_SCALE;
313 params->fused_kernel.output_stream.stride = n * sizeof(float);
314 }
315
setup_row_row_i32(int m,int n,int k,ParamsRowMajorAsInt32 * params)316 void setup_row_row_i32(int m, int n, int k, ParamsRowMajorAsInt32* params) {
317 params->m = m;
318 params->n = n;
319 params->k = k;
320 params->left_stream.count = k;
321 params->left_stream.stride = k;
322 params->left_stream.additive_sum_offset = k * LHS_OFFSET * RHS_OFFSET;
323 params->right_stream.count = k;
324 params->right_stream.stride = k;
325 params->fused_kernel.kernel.count = k;
326 params->fused_kernel.output_stream.stride = n * sizeof(std::int32_t);
327 }
328
setup_row_col_i32(int m,int n,int k,ParamsColumnMajorAsInt32 * params)329 void setup_row_col_i32(int m, int n, int k, ParamsColumnMajorAsInt32* params) {
330 params->m = m;
331 params->n = n;
332 params->k = k;
333 params->left_stream.count = k;
334 params->left_stream.stride = k;
335 params->left_stream.additive_sum_offset = k * LHS_OFFSET * RHS_OFFSET;
336 params->right_stream.count = k;
337 params->right_stream.stride = n;
338 params->fused_kernel.kernel.count = k;
339 params->fused_kernel.output_stream.stride = n * sizeof(std::int32_t);
340 }
341
main()342 int main() {
343 ParamsRowMajor params_row;
344 ParamsColumnMajor params_col;
345 ParamsRowMajorAsFloat params_row_f;
346 ParamsColumnMajorAsFloat params_col_f;
347 ParamsRowMajorAsInt32 params_row_i32;
348 ParamsColumnMajorAsInt32 params_col_i32;
349
350 std::unique_ptr<std::uint8_t> lhs(new std::uint8_t[1024 * 1024]);
351 std::unique_ptr<std::uint8_t> rhs(new std::uint8_t[1024 * 1024]);
352 std::unique_ptr<std::uint8_t> result(new std::uint8_t[1024 * 1024]);
353 std::unique_ptr<float> result_f(new float[1024 * 1024]);
354 std::unique_ptr<std::int32_t> result_i32(new std::int32_t[1024 * 1024]);
355 std::unique_ptr<std::uint8_t> scratch(new std::uint8_t[4048 * 1024]);
356
357 setup_params(lhs.get(), rhs.get(), result.get(), scratch.get(), ¶ms_row);
358 setup_params(lhs.get(), rhs.get(), result.get(), scratch.get(), ¶ms_col);
359 setup_params(lhs.get(), rhs.get(), result_f.get(), scratch.get(),
360 ¶ms_row_f);
361 setup_params(lhs.get(), rhs.get(), result_f.get(), scratch.get(),
362 ¶ms_col_f);
363 setup_params(lhs.get(), rhs.get(), result_i32.get(), scratch.get(),
364 ¶ms_row_i32);
365 setup_params(lhs.get(), rhs.get(), result_i32.get(), scratch.get(),
366 ¶ms_col_i32);
367
368 Pool pool;
369 Context context(4, &pool);
370
371 for (int i = 1; i < 16; ++i) {
372 for (int j = 1; j < 16; ++j) {
373 for (int k = 1; k < 24; ++k) {
374 prepare_test_data(lhs.get(), i, k, 11, 13);
375 prepare_test_data(rhs.get(), j, k, 13, 17);
376
377 clear(i, j, result.get());
378 setup_row_row(i, j, k, ¶ms_row);
379 Gemm<Executor, ParamsRowMajor, 2, 4, 8>(params_row);
380 if (!check_row_row(lhs.get(), rhs.get(), result.get(), i, j, k)) {
381 std::cout << "Row: " << i << "x" << j << "x" << k << " : ERROR"
382 << std::endl;
383 std::cout << "Exiting." << std::endl;
384 std::exit(1);
385 }
386
387 clear(i, j, result.get());
388 setup_row_col(i, j, k, ¶ms_col);
389 Gemm<Executor, ParamsColumnMajor, 2, 4, 8>(params_col);
390 if (!check_row_col(lhs.get(), rhs.get(), result.get(), i, j, k)) {
391 std::cout << "Column: " << i << "x" << j << "x" << k << " : ERROR"
392 << std::endl;
393 std::cout << "Exiting." << std::endl;
394 std::exit(1);
395 }
396
397 clear(i, j, result_f.get());
398 setup_row_row_f(i, j, k, ¶ms_row_f);
399 Gemm<Executor, ParamsRowMajorAsFloat, 2, 4, 8>(params_row_f);
400 if (!check_row_row_f(lhs.get(), rhs.get(), result_f.get(), i, j, k)) {
401 std::cout << "RowAsFloat: " << i << "x" << j << "x" << k << " : ERROR"
402 << std::endl;
403 std::cout << "Exiting." << std::endl;
404 std::exit(1);
405 }
406
407 clear(i, j, result_f.get());
408 setup_row_col_f(i, j, k, ¶ms_col_f);
409 Gemm<Executor, ParamsColumnMajorAsFloat, 2, 4, 8>(params_col_f);
410 if (!check_row_col_f(lhs.get(), rhs.get(), result_f.get(), i, j, k)) {
411 std::cout << "ColumnAsFloat: " << i << "x" << j << "x" << k
412 << " : ERROR" << std::endl;
413 std::cout << "Exiting." << std::endl;
414 std::exit(1);
415 }
416
417 clear(i, j, result_i32.get());
418 setup_row_row_i32(i, j, k, ¶ms_row_i32);
419 Gemm<Executor, ParamsRowMajorAsInt32, 2, 4, 8>(params_row_i32);
420 if (!check_row_row_i32(lhs.get(), rhs.get(), result_i32.get(), i, j,
421 k)) {
422 std::cout << "RowAsInt32: " << i << "x" << j << "x" << k << " : ERROR"
423 << std::endl;
424 std::cout << "Exiting." << std::endl;
425 std::exit(1);
426 }
427
428 clear(i, j, result_i32.get());
429 setup_row_col_i32(i, j, k, ¶ms_col_i32);
430 Gemm<Executor, ParamsColumnMajorAsInt32, 2, 4, 8>(params_col_i32);
431 if (!check_row_col_i32(lhs.get(), rhs.get(), result_i32.get(), i, j,
432 k)) {
433 std::cout << "ColumnAsInt32: " << i << "x" << j << "x" << k
434 << " : ERROR" << std::endl;
435 std::cout << "Exiting." << std::endl;
436 std::exit(1);
437 }
438 }
439 }
440 }
441
442 for (int i = 1; i < 1024; i += 211) {
443 for (int j = 1; j < 1024; j += 211) {
444 for (int k = 8; k < 1024; k += 111) {
445 prepare_test_data(lhs.get(), i, k, 11, 13);
446 prepare_test_data(rhs.get(), j, k, 13, 17);
447
448 clear(i, j, result.get());
449 setup_row_row(i, j, k, ¶ms_row);
450 MultiThreadGemm<Context, Executor, ParamsRowMajor, 2, 4, 8>(&context,
451 params_row);
452 if (!check_row_row(lhs.get(), rhs.get(), result.get(), i, j, k)) {
453 std::cout << "Row(MT): " << i << "x" << j << "x" << k << " : ERROR"
454 << std::endl;
455 std::cout << "Exiting." << std::endl;
456 std::exit(1);
457 }
458
459 clear(i, j, result.get());
460 setup_row_col(i, j, k, ¶ms_col);
461 MultiThreadGemm<Context, Executor, ParamsColumnMajor, 2, 4, 8>(
462 &context, params_col);
463 if (!check_row_col(lhs.get(), rhs.get(), result.get(), i, j, k)) {
464 std::cout << "Column(MT): " << i << "x" << j << "x" << k << " : ERROR"
465 << std::endl;
466 std::cout << "Exiting." << std::endl;
467 std::exit(1);
468 }
469
470 clear(i, j, result_f.get());
471 setup_row_row_f(i, j, k, ¶ms_row_f);
472 MultiThreadGemm<Context, Executor, ParamsRowMajorAsFloat, 2, 4, 8>(
473 &context, params_row_f);
474 if (!check_row_row_f(lhs.get(), rhs.get(), result_f.get(), i, j, k)) {
475 std::cout << "RowAsFloat(MT): " << i << "x" << j << "x" << k
476 << " : ERROR" << std::endl;
477 std::cout << "Exiting." << std::endl;
478 std::exit(1);
479 }
480
481 clear(i, j, result_f.get());
482 setup_row_col_f(i, j, k, ¶ms_col_f);
483 MultiThreadGemm<Context, Executor, ParamsColumnMajorAsFloat, 2, 4, 8>(
484 &context, params_col_f);
485 if (!check_row_col_f(lhs.get(), rhs.get(), result_f.get(), i, j, k)) {
486 std::cout << "ColumnAsFloat(MT): " << i << "x" << j << "x" << k
487 << " : ERROR" << std::endl;
488 std::cout << "Exiting." << std::endl;
489 std::exit(1);
490 }
491
492 clear(i, j, result_i32.get());
493 setup_row_row_i32(i, j, k, ¶ms_row_i32);
494 MultiThreadGemm<Context, Executor, ParamsRowMajorAsInt32, 2, 4, 8>(
495 &context, params_row_i32);
496 if (!check_row_row_i32(lhs.get(), rhs.get(), result_i32.get(), i, j,
497 k)) {
498 std::cout << "RowAsInt32(MT): " << i << "x" << j << "x" << k
499 << " : ERROR" << std::endl;
500 std::cout << "Exiting." << std::endl;
501 std::exit(1);
502 }
503
504 clear(i, j, result_i32.get());
505 setup_row_col_i32(i, j, k, ¶ms_col_i32);
506 MultiThreadGemm<Context, Executor, ParamsColumnMajorAsInt32, 2, 4, 8>(
507 &context, params_col_i32);
508 if (!check_row_col_i32(lhs.get(), rhs.get(), result_i32.get(), i, j,
509 k)) {
510 std::cout << "ColumnAsInt32(MT): " << i << "x" << j << "x" << k
511 << " : ERROR" << std::endl;
512 std::cout << "Exiting." << std::endl;
513 std::exit(1);
514 }
515 }
516 }
517 }
518
519 std::cout << "OK." << std::endl;
520 return 0;
521 }
522