1 // Copyright 2015 Google Inc. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "test.h"
16
17 #include <unistd.h>
18 #include <cstdint>
19 #include <cstdlib>
20 #include <ctime>
21 #include <iostream>
22 #include <memory>
23 #include <string>
24 #include <vector>
25 #ifdef __APPLE__
26 #include <TargetConditionals.h>
27 #endif
28
29 #include "../eight_bit_int_gemm/eight_bit_int_gemm.h"
30 #include "../internal/kernel_reference.h"
31 #include "test_data.h"
32
33 namespace gemmlowp {
34
// Naive reference implementation of eight_bit_int_gemm::EightBitIntGemm that
// the optimized GEMM paths are checked against.
//
// For 0 <= i < m and 0 <= j < n it computes
//   acc    = sum over 0 <= l < k of (A(i,l) + a_offset) * (B(l,j) + b_offset)
//   C(i,j) = clamp(((acc + c_offset) * c_mult_int + rounding) >> c_shift,
//                  0, 255)
// where rounding = 2^(c_shift - 1), or 0 when c_shift == 0.
//
// Element addressing (lda/ldb/ldc are the leading dimensions):
//   A(i,l) = a[i + l * lda]   or, if transpose_a, a[i * lda + l]
//   B(l,j) = b[l + j * ldb]   or, if transpose_b, b[l * ldb + j]
//   C(i,j) = c[i + j * ldc]   or, if transpose_c, c[i * ldc + j]
//
// The products, the accumulation and the requantization arithmetic are done
// modulo 2^32 (in uint32_t) and then reinterpreted as int32_t. Where the old
// all-int32 code was well-defined this gives bit-identical results. Where an
// int32 intermediate would overflow (undefined behavior for signed types) it
// now wraps deterministically. The final shift is done in 64 bits, so the
// c_shift == 32 case allowed by the assert is well-defined as well.
void ReferenceEightBitIntGemm(bool transpose_a, bool transpose_b,
                              bool transpose_c, int m, int n, int k,
                              const uint8_t* a, int32_t a_offset, int lda,
                              const uint8_t* b, int32_t b_offset, int ldb,
                              uint8_t* c, int32_t c_offset, int32_t c_mult_int,
                              int32_t c_shift, int ldc) {
  assert((c_shift >= 0) && (c_shift <= 32));

  assert(a != nullptr);
  assert(b != nullptr);
  assert(c != nullptr);

  const int a_i_stride = transpose_a ? lda : 1;
  const int a_l_stride = transpose_a ? 1 : lda;
  const int b_j_stride = transpose_b ? 1 : ldb;
  const int b_l_stride = transpose_b ? ldb : 1;
  const int c_i_stride = transpose_c ? ldc : 1;
  const int c_j_stride = transpose_c ? 1 : ldc;

  // Round-to-nearest term added before the right shift. Computed unsigned so
  // that c_shift == 32 (a term of 2^31) does not overflow int.
  const std::uint32_t kRoundingTerm =
      (c_shift < 1) ? 0u : (std::uint32_t{1} << (c_shift - 1));

  for (int j = 0; j < n; j++) {
    for (int i = 0; i < m; i++) {
      // Accumulate modulo 2^32; the bit pattern equals the int32 sum whenever
      // that sum does not overflow.
      std::uint32_t total = 0;
      for (int l = 0; l < k; l++) {
        const int32_t a_as_int =
            static_cast<int32_t>(a[i * a_i_stride + l * a_l_stride]) + a_offset;
        const int32_t b_as_int =
            static_cast<int32_t>(b[j * b_j_stride + l * b_l_stride]) + b_offset;
        total += static_cast<std::uint32_t>(a_as_int) *
                 static_cast<std::uint32_t>(b_as_int);
      }
      // (total + c_offset) * c_mult_int + rounding, with 32-bit wraparound.
      const std::uint32_t unshifted_bits =
          (total + static_cast<std::uint32_t>(c_offset)) *
              static_cast<std::uint32_t>(c_mult_int) +
          kRoundingTerm;
      // Reinterpret as two's-complement int32 (a modular conversion on all
      // mainstream compilers; guaranteed by the standard since C++20).
      const int32_t unshifted = static_cast<int32_t>(unshifted_bits);
      // Shift in 64 bits: for c_shift <= 31 this equals the int32 arithmetic
      // shift, and c_shift == 32 is no longer undefined.
      std::int64_t output = static_cast<std::int64_t>(unshifted) >> c_shift;
      if (output > 255) {
        output = 255;
      }
      if (output < 0) {
        output = 0;
      }
      c[i * c_i_stride + j * c_j_stride] = static_cast<uint8_t>(output);
    }
  }
}
104
// The *GemmWrapper structs wrap the various Gemm entry points behind a uniform
// interface, so that the same testing code can exercise all of them.
107
108 template <typename Kernel, typename Scalar, typename tBitDepthParams>
109 struct SingleThreadGemmWrapper {
110 typedef tBitDepthParams BitDepthParams;
111
Namegemmlowp::SingleThreadGemmWrapper112 static const char* Name() {
113 static char buf[256];
114 snprintf(buf, sizeof(buf), "SingleThreadGemm, Kernel: %s", Kernel().Name());
115 return buf;
116 }
117
118 typedef SingleThreadGemmContext Context;
119
120 template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
Gemmgemmlowp::SingleThreadGemmWrapper121 static void Gemm(Context* context,
122 const MatrixMap<const Scalar, LhsOrder>& lhs,
123 const MatrixMap<const Scalar, RhsOrder>& rhs,
124 MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
125 int rhs_offset, int result_offset, int result_mult_int,
126 int result_shift) {
127 const OffsetColDup lhs_offset_vector(lhs_offset, lhs.rows());
128 const OffsetRowDup rhs_offset_vector(rhs_offset, rhs.cols());
129 SingleThreadGemm<typename Kernel::Format, Scalar, Scalar, BitDepthParams,
130 LhsOrder, RhsOrder, ResultOrder,
131 OffsetColDup, OffsetRowDup>(
132 context, Kernel(), lhs, rhs, result, lhs_offset_vector,
133 rhs_offset_vector,
134 MakeStandardOutputPipeline(result_offset, result_mult_int,
135 result_shift));
136 }
137 };
138
139 template <typename Kernel, typename Scalar, typename tBitDepthParams>
140 struct MultiThreadGemmWrapper {
141 typedef tBitDepthParams BitDepthParams;
142
Namegemmlowp::MultiThreadGemmWrapper143 static const char* Name() {
144 static char buf[256];
145 snprintf(buf, sizeof(buf), "MultiThreadGemm, Kernel: %s", Kernel().Name());
146 return buf;
147 }
148
149 typedef MultiThreadGemmContext Context;
150
151 template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
Gemmgemmlowp::MultiThreadGemmWrapper152 static void Gemm(Context* context,
153 const MatrixMap<const Scalar, LhsOrder>& lhs,
154 const MatrixMap<const Scalar, RhsOrder>& rhs,
155 MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
156 int rhs_offset, int result_offset, int result_mult_int,
157 int result_shift) {
158 const OffsetColDup lhs_offset_vector(lhs_offset, lhs.rows());
159 const OffsetRowDup rhs_offset_vector(rhs_offset, rhs.cols());
160 MultiThreadGemm<typename Kernel::Format, Scalar, Scalar, BitDepthParams,
161 LhsOrder, RhsOrder, ResultOrder,
162 OffsetColDup, OffsetRowDup>(
163 context, Kernel(), lhs, rhs, result, lhs_offset_vector,
164 rhs_offset_vector,
165 MakeStandardOutputPipeline(result_offset, result_mult_int,
166 result_shift));
167 }
168 };
169
170 template <typename Scalar, typename tBitDepthParams>
171 struct PublicGemmWrapper {
172 typedef tBitDepthParams BitDepthParams;
173
Namegemmlowp::PublicGemmWrapper174 static const char* Name() { return "public Gemm"; }
175
176 typedef GemmContext Context;
177
178 template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
Gemmgemmlowp::PublicGemmWrapper179 static void Gemm(Context* context,
180 const MatrixMap<const Scalar, LhsOrder>& lhs,
181 const MatrixMap<const Scalar, RhsOrder>& rhs,
182 MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
183 int rhs_offset, int result_offset, int result_mult_int,
184 int result_shift) {
185 gemmlowp::Gemm<uint8_t, BitDepthParams, LhsOrder, RhsOrder, ResultOrder>(
186 context, lhs, rhs, result, lhs_offset, rhs_offset, result_offset,
187 result_mult_int, result_shift);
188 }
189 };
190
191 template <eight_bit_int_gemm::BitDepthSetting BitDepth>
192 struct BitDepthParamsForSettings {};
193
194 template <>
195 struct BitDepthParamsForSettings<eight_bit_int_gemm::BitDepthSetting::A8B8>
196 : DefaultL8R8BitDepthParams {};
197
198 template <>
199 struct BitDepthParamsForSettings<eight_bit_int_gemm::BitDepthSetting::A5B7>
200 : DefaultL7R5BitDepthParams {};
201
202 template <typename Scalar, eight_bit_int_gemm::BitDepthSetting BitDepth>
203 struct EightBitIntGemmWrapper {
204 typedef BitDepthParamsForSettings<BitDepth> BitDepthParams;
205
Namegemmlowp::EightBitIntGemmWrapper206 static const char* Name() { return "EightBitIntGemm"; }
207
208 typedef void Context;
209
210 template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
Gemmgemmlowp::EightBitIntGemmWrapper211 static void Gemm(Context*, const MatrixMap<const Scalar, LhsOrder>& lhs,
212 const MatrixMap<const Scalar, RhsOrder>& rhs,
213 MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
214 int rhs_offset, int result_offset, int result_mult_int,
215 int result_shift) {
216 const bool transpose_c = ResultOrder == MapOrder::RowMajor;
217 const bool transpose_a = LhsOrder == MapOrder::RowMajor;
218 const bool transpose_b = RhsOrder == MapOrder::RowMajor;
219 eight_bit_int_gemm::EightBitIntGemm(
220 transpose_a, transpose_b, transpose_c, lhs.rows(), rhs.cols(),
221 lhs.cols(), lhs.data(), lhs_offset, lhs.stride(), rhs.data(),
222 rhs_offset, rhs.stride(), result->data(), result_offset,
223 result_mult_int, result_shift, result->stride(), BitDepth);
224 }
225 };
226
227 template <typename Scalar>
228 struct ReferenceEightBitIntGemmWrapper {
229 typedef DefaultL8R8BitDepthParams BitDepthParams;
230
Namegemmlowp::ReferenceEightBitIntGemmWrapper231 static const char* Name() { return "ReferenceEightBitIntGemm"; }
232
233 template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
Gemmgemmlowp::ReferenceEightBitIntGemmWrapper234 static void Gemm(bool transpose_a, bool transpose_b, bool transpose_c,
235 const MatrixMap<const Scalar, LhsOrder>& lhs,
236 const MatrixMap<const Scalar, RhsOrder>& rhs,
237 MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
238 int rhs_offset, int result_offset, int result_mult_int,
239 int result_shift) {
240 ReferenceEightBitIntGemm(transpose_a, transpose_b, transpose_c, lhs.rows(),
241 rhs.cols(), lhs.cols(), lhs.data(), lhs_offset,
242 lhs.stride(), rhs.data(), rhs_offset, rhs.stride(),
243 result->data(), result_offset, result_mult_int,
244 result_shift, result->stride());
245 }
246 };
247
OrderName(MapOrder order)248 const char* OrderName(MapOrder order) {
249 return order == MapOrder::ColMajor ? "ColMajor" : "RowMajor";
250 }
251
// Summary of how an actual GEMM result differs from the reference result.
// Every scalar field starts at zero and the histogram starts empty.
struct ResultStats {
  int count = 0;                 // number of entries compared
  int med_val = 0;               // median of the actual values
  float mean_signed_diff = 0.f;  // mean of (actual - expected)
  int med_signed_diff = 0;       // median of (actual - expected)
  int med_unsigned_diff = 0;     // median of |actual - expected|
  int max_unsigned_diff = 0;     // maximum of |actual - expected|

  // Histogram of |actual - expected| bucketed by power of two.
  std::vector<int> count_diff_by_pot_slice;
};
270
GetResultStats(const uint8_t * actual,const uint8_t * expected,size_t count,ResultStats * stats)271 void GetResultStats(const uint8_t* actual, const uint8_t* expected,
272 size_t count, ResultStats* stats) {
273 std::vector<uint8_t> results;
274 std::vector<int16_t> signed_diffs;
275 std::vector<uint8_t> unsigned_diffs;
276 int64_t signed_diffs_sum = 0;
277 for (size_t i = 0; i < count; i++) {
278 results.push_back(actual[i]);
279 int16_t signed_diff = actual[i] - expected[i];
280 signed_diffs.push_back(signed_diff);
281 unsigned_diffs.push_back(std::abs(signed_diff));
282 signed_diffs_sum += signed_diff;
283 }
284
285 std::sort(results.begin(), results.end());
286 std::sort(signed_diffs.begin(), signed_diffs.end());
287 std::sort(unsigned_diffs.begin(), unsigned_diffs.end());
288
289 const size_t middle = count / 2;
290
291 stats->count = count;
292 stats->med_val = results[middle];
293 stats->mean_signed_diff = float(signed_diffs_sum) / count;
294 stats->med_signed_diff = signed_diffs[middle];
295 stats->med_unsigned_diff = unsigned_diffs[middle];
296 stats->max_unsigned_diff = unsigned_diffs.back();
297
298 // Size 9 for 9 different POT values: 2^0, ..., 2^8
299 stats->count_diff_by_pot_slice.resize(9);
300 auto cur = unsigned_diffs.begin();
301 size_t checksum = 0;
302 for (int exponent = 0; exponent < 9; exponent++) {
303 int pot = 1 << exponent;
304 auto next = std::lower_bound(cur, unsigned_diffs.end(), pot);
305 checksum += stats->count_diff_by_pot_slice[exponent] = next - cur;
306 cur = next;
307 }
308 assert(checksum == count);
309 }
310
// Tolerances that a ResultStats must satisfy. A default-constructed instance
// has every tolerance at zero, i.e. it accepts only an exact match.
struct ResultStatsBounds {
  float mean_signed_diff = 0.f;  // bound on |mean of (actual - expected)|
  int med_signed_diff = 0;       // bound on |median of (actual - expected)|
  int med_unsigned_diff = 0;     // bound on median of |actual - expected|
  int max_unsigned_diff = 0;     // bound on max of |actual - expected|
};
323
CheckResultStatsBounds(const ResultStats & stats,const ResultStatsBounds & bounds)324 bool CheckResultStatsBounds(const ResultStats& stats,
325 const ResultStatsBounds& bounds) {
326 return stats.max_unsigned_diff <= bounds.max_unsigned_diff &&
327 stats.med_unsigned_diff <= bounds.med_unsigned_diff &&
328 std::abs(stats.med_signed_diff) <= bounds.med_signed_diff &&
329 std::abs(stats.mean_signed_diff) <= bounds.mean_signed_diff;
330 }
331
ReportResultStats(const ResultStats & stats,const ResultStatsBounds & bounds)332 void ReportResultStats(const ResultStats& stats,
333 const ResultStatsBounds& bounds) {
334 printf(" number of matrix entries: %d\n", stats.count);
335 printf(" median value: %d\n", stats.med_val);
336 printf(" median unsigned diff: %d (tolerating %d)\n",
337 stats.med_unsigned_diff, bounds.med_unsigned_diff);
338 printf(" max unsigned diff: %d (tolerating %d)\n", stats.max_unsigned_diff,
339 bounds.max_unsigned_diff);
340 printf(" median signed diff: %d (tolerating %d)\n", stats.med_signed_diff,
341 bounds.med_signed_diff);
342 printf(" mean signed diff: %.3g (tolerating %.3g)\n",
343 stats.mean_signed_diff, bounds.mean_signed_diff);
344
345 printf("No error: %.2f %% of entries\n",
346 100.f * stats.count_diff_by_pot_slice[0] / stats.count);
347 for (int exponent = 1; exponent < 9; exponent++) {
348 printf("Error in %d..%d range: %.2f %% of entries\n", 1 << (exponent - 1),
349 (1 << exponent) - 1,
350 100.f * stats.count_diff_by_pot_slice[exponent] / stats.count);
351 }
352 }
353
354 // Our approach to choosing result_shift values for testing, is bisection.
355 // This function takes an interval, [result_shift_min .. result_shift_max].
356 // If too much saturation occurred in either direction, it bisects accordingly,
357 // recursing until the interval contains only one value.
358 // The primary reason why we prefer this over computing optimal shift values,
359 // is that we actually want to exercise some saturation, as there is nontrivial
360 // code handling that in gemmlowp.
361 // Secondarily, this is faster than computing optimal shifts, since in 90% of
362 // cases the first-tried shift value 16 turns out to be good enough.
363 template <typename GemmWrapper, typename LhsType, typename RhsType,
364 typename ResultType>
test_gemm_impl(typename GemmWrapper::Context * context,const LhsType & lhs,const RhsType & rhs,ResultType * result,int lhs_offset,int rhs_offset,int result_offset,int result_mult_int,int result_shift_min,int result_shift_max)365 void test_gemm_impl(typename GemmWrapper::Context* context, const LhsType& lhs,
366 const RhsType& rhs, ResultType* result, int lhs_offset,
367 int rhs_offset, int result_offset, int result_mult_int,
368 int result_shift_min, int result_shift_max) {
369 const int rows = lhs.rows();
370 const int cols = rhs.cols();
371 Check(lhs.cols() == rhs.rows());
372 const int depth = lhs.cols();
373
374 const int result_shift = (result_shift_min + result_shift_max) / 2;
375
376 GemmWrapper::Gemm(context, lhs.const_map(), rhs.const_map(), &result->map(),
377 lhs_offset, rhs_offset, result_offset, result_mult_int,
378 result_shift);
379
380 typedef typename ResultType::Scalar Scalar;
381 static const MapOrder kLhsOrder = LhsType::kOrder;
382 static const MapOrder kRhsOrder = RhsType::kOrder;
383 static const MapOrder kResultOrder = ResultType::kOrder;
384 ResultType ref_result(rows, cols);
385 const bool transpose_c = kResultOrder == MapOrder::RowMajor;
386 const bool transpose_a = kLhsOrder == MapOrder::RowMajor;
387 const bool transpose_b = kRhsOrder == MapOrder::RowMajor;
388 ReferenceEightBitIntGemmWrapper<Scalar>::Gemm(
389 transpose_a, transpose_b, transpose_c, lhs.const_map(), rhs.const_map(),
390 &ref_result.map(), lhs_offset, rhs_offset, result_offset, result_mult_int,
391 result_shift);
392
393 typedef typename GemmWrapper::BitDepthParams BitDepthParams;
394
395 ResultStats stats;
396 GetResultStats(result->data(), ref_result.data(), rows * cols, &stats);
397
398 // Adjust shifts until we get meaningful results
399 int new_result_shift_min = result_shift_min;
400 int new_result_shift_max = result_shift_max;
401 bool retry = false;
402
403 if (stats.med_val < 32) {
404 new_result_shift_max = (result_shift_min + result_shift_max) / 2;
405 retry = true;
406 }
407
408 if (stats.med_val > 224) {
409 new_result_shift_min = (result_shift_min + result_shift_max) / 2;
410 retry = true;
411 }
412
413 if (retry) {
414 if (result_shift_min != result_shift_max) {
415 test_gemm_impl<GemmWrapper>(context, lhs, rhs, result, lhs_offset,
416 rhs_offset, result_offset, result_mult_int,
417 new_result_shift_min, new_result_shift_max);
418 }
419 return;
420 }
421
422 ResultStatsBounds bounds;
423
424 if (BitDepthParams::LhsBitDepth::kBits < 8 ||
425 BitDepthParams::RhsBitDepth::kBits < 8) {
426 // We have very lax requirements on unsigned diff.
427 // We have tighter requirements on signed diff (bias), but only
428 // if the matrix is large enough for things to average out.
429 // For very small sizes, we... basically don't test anything.
430 // The problem is that this test uses unrealistic combinations of
431 // result_mult_int
432 // and result_shift, resulting in potentially wild requantization artifacts
433 // on small GEMMs.
434 int adjust_for_small_sizes = 1000 / (rows * cols);
435 bounds.max_unsigned_diff =
436 std::max(stats.med_val / 2, adjust_for_small_sizes);
437 bounds.med_unsigned_diff =
438 std::max(stats.med_val / 8, adjust_for_small_sizes);
439 bounds.med_signed_diff = std::max(2, adjust_for_small_sizes);
440 bounds.mean_signed_diff = std::max(2, adjust_for_small_sizes);
441 }
442
443 // Check results
444 const bool good = CheckResultStatsBounds(stats, bounds);
445
446 printf(
447 "%s: %dx%dx%d %s x %s -> %s, %s, offsets %d/%d/%d, mult %d, shift %d\n",
448 good ? "PASS" : "FAIL", rows, depth, cols, OrderName(kLhsOrder),
449 OrderName(kRhsOrder), OrderName(kResultOrder), GemmWrapper::Name(),
450 lhs_offset, rhs_offset, result_offset, result_mult_int, result_shift);
451
452 if (!good) {
453 ReportResultStats(stats, bounds);
454
455 int bad_coeffs_printed = 0;
456 for (int c = 0; c < result->cols() && bad_coeffs_printed < 20; c++) {
457 for (int r = 0; r < result->rows() && bad_coeffs_printed < 20; r++) {
458 if (ref_result(r, c) != (*result)(r, c)) {
459 printf("bad coeff: at (%d, %d), expected %d, got %d\n", r, c,
460 ref_result(r, c), (*result)(r, c));
461 bad_coeffs_printed++;
462 }
463 }
464 }
465 }
466
467 Check(good);
468 }
469
470 template <typename GemmWrapper, typename LhsType, typename RhsType,
471 typename ResultType>
test_gemm(typename GemmWrapper::Context * context,const LhsType & lhs,const RhsType & rhs,ResultType * result,int lhs_offset,int rhs_offset,int result_offset,int result_mult_int)472 void test_gemm(typename GemmWrapper::Context* context, const LhsType& lhs,
473 const RhsType& rhs, ResultType* result, int lhs_offset,
474 int rhs_offset, int result_offset, int result_mult_int) {
475 test_gemm_impl<GemmWrapper>(context, lhs, rhs, result, lhs_offset, rhs_offset,
476 result_offset, result_mult_int, 0, 32);
477 }
478
// Which quantization-parameter sets test_gemm runs for each shape: the full
// list (All), or only the single generic case with nonzero offsets and
// multiplier (OnlyGenericCase).
enum class WhatParamsToTest { All, OnlyGenericCase };
483
484 template <typename GemmWrapper, MapOrder LhsOrder, MapOrder RhsOrder,
485 MapOrder ResultOrder>
test_gemm(typename GemmWrapper::Context * context,int rows,int depth,int cols,WhatParamsToTest params_to_test)486 void test_gemm(typename GemmWrapper::Context* context, int rows, int depth,
487 int cols, WhatParamsToTest params_to_test) {
488 typedef std::uint8_t Scalar;
489 typedef Matrix<Scalar, LhsOrder> LhsType;
490 LhsType lhs(rows, depth);
491 MakeRandom(&lhs, 8);
492 typedef Matrix<Scalar, RhsOrder> RhsType;
493 RhsType rhs(depth, cols);
494 MakeRandom(&rhs, 8);
495 typedef Matrix<Scalar, ResultOrder> ResultType;
496 ResultType result(rows, cols);
497 MakeZero(&result);
498
499 if (params_to_test == WhatParamsToTest::All) {
500 test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 0, 0, 1);
501 test_gemm<GemmWrapper>(context, lhs, rhs, &result, 10, 0, 0, 1);
502 test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 10, 0, 1);
503 test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 0, 10, 1);
504 test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 0, 0, 10);
505 test_gemm<GemmWrapper>(context, lhs, rhs, &result, 10, 10, 10, 10);
506 test_gemm<GemmWrapper>(context, lhs, rhs, &result, 256, 1, 17, 4);
507 }
508 test_gemm<GemmWrapper>(context, lhs, rhs, &result, -75, -91, 74980, 123);
509 }
510
// Which storage-order combinations test_gemm exercises: all eight
// (Lhs, Rhs, Result) combinations, or only RowMajor x ColMajor -> ColMajor.
enum class WhatOrdersToTest {
  All,
  OnlyRCC,
};
512
513 template <typename GemmWrapper>
test_gemm(typename GemmWrapper::Context * context,int rows,int depth,int cols,WhatParamsToTest params_to_test,WhatOrdersToTest orders_to_test)514 void test_gemm(typename GemmWrapper::Context* context, int rows, int depth,
515 int cols, WhatParamsToTest params_to_test,
516 WhatOrdersToTest orders_to_test) {
517 #define GEMMLOWP_ONE_TEST(LhsOrder, RhsOrder, ResultOrder) \
518 do { \
519 test_gemm<GemmWrapper, MapOrder::LhsOrder, MapOrder::RhsOrder, \
520 MapOrder::ResultOrder>(context, rows, depth, cols, \
521 params_to_test); \
522 } while (false)
523
524 if (orders_to_test == WhatOrdersToTest::All) {
525 GEMMLOWP_ONE_TEST(ColMajor, ColMajor, ColMajor);
526 GEMMLOWP_ONE_TEST(RowMajor, ColMajor, ColMajor);
527 GEMMLOWP_ONE_TEST(ColMajor, RowMajor, ColMajor);
528 GEMMLOWP_ONE_TEST(RowMajor, RowMajor, ColMajor);
529
530 GEMMLOWP_ONE_TEST(ColMajor, ColMajor, RowMajor);
531 GEMMLOWP_ONE_TEST(RowMajor, ColMajor, RowMajor);
532 GEMMLOWP_ONE_TEST(ColMajor, RowMajor, RowMajor);
533 GEMMLOWP_ONE_TEST(RowMajor, RowMajor, RowMajor);
534 } else {
535 GEMMLOWP_ONE_TEST(RowMajor, ColMajor, ColMajor);
536 }
537
538 #undef GEMMLOWP_ONE_TEST
539 }
540
541 template <typename Kernel>
test_gemm_kernel(MultiThreadGemmContext * context)542 void test_gemm_kernel(MultiThreadGemmContext* context) {
543 typedef MultiThreadGemmWrapper<Kernel, std::uint8_t,
544 DefaultL8R8BitDepthParams>
545 GemmWrapper;
546 test_gemm<GemmWrapper>(context, 1, 1, 1, WhatParamsToTest::OnlyGenericCase,
547 WhatOrdersToTest::OnlyRCC);
548 test_gemm<GemmWrapper>(context, 2, 2, 2, WhatParamsToTest::OnlyGenericCase,
549 WhatOrdersToTest::OnlyRCC);
550 test_gemm<GemmWrapper>(context, 3, 3, 3, WhatParamsToTest::OnlyGenericCase,
551 WhatOrdersToTest::OnlyRCC);
552 test_gemm<GemmWrapper>(context, 4, 4, 4, WhatParamsToTest::OnlyGenericCase,
553 WhatOrdersToTest::OnlyRCC);
554 test_gemm<GemmWrapper>(context, 5, 5, 5, WhatParamsToTest::OnlyGenericCase,
555 WhatOrdersToTest::OnlyRCC);
556 test_gemm<GemmWrapper>(context, 9, 11, 13, WhatParamsToTest::OnlyGenericCase,
557 WhatOrdersToTest::OnlyRCC);
558 test_gemm<GemmWrapper>(context, 50, 50, 50, WhatParamsToTest::All,
559 WhatOrdersToTest::OnlyRCC);
560 test_gemm<GemmWrapper>(context, 200, 200, 200,
561 WhatParamsToTest::OnlyGenericCase,
562 WhatOrdersToTest::All);
563 test_gemm<GemmWrapper>(context, 50, 5000, 50,
564 WhatParamsToTest::OnlyGenericCase,
565 WhatOrdersToTest::OnlyRCC);
566 }
567
568 template <typename GemmWrapper>
test_gemm(typename GemmWrapper::Context * context)569 void test_gemm(typename GemmWrapper::Context* context) {
570 test_gemm<GemmWrapper>(context, 1, 1, 1, WhatParamsToTest::All,
571 WhatOrdersToTest::OnlyRCC);
572 test_gemm<GemmWrapper>(context, 2, 1, 1, WhatParamsToTest::All,
573 WhatOrdersToTest::OnlyRCC);
574 test_gemm<GemmWrapper>(context, 1, 2, 1, WhatParamsToTest::All,
575 WhatOrdersToTest::OnlyRCC);
576 test_gemm<GemmWrapper>(context, 1, 1, 2, WhatParamsToTest::All,
577 WhatOrdersToTest::OnlyRCC);
578 test_gemm<GemmWrapper>(context, 2, 2, 2, WhatParamsToTest::All,
579 WhatOrdersToTest::OnlyRCC);
580 test_gemm<GemmWrapper>(context, 3, 3, 3, WhatParamsToTest::All,
581 WhatOrdersToTest::OnlyRCC);
582 test_gemm<GemmWrapper>(context, 4, 4, 4, WhatParamsToTest::All,
583 WhatOrdersToTest::OnlyRCC);
584 test_gemm<GemmWrapper>(context, 5, 5, 5, WhatParamsToTest::All,
585 WhatOrdersToTest::OnlyRCC);
586 test_gemm<GemmWrapper>(context, 6, 6, 6, WhatParamsToTest::All,
587 WhatOrdersToTest::OnlyRCC);
588 test_gemm<GemmWrapper>(context, 3, 5, 7, WhatParamsToTest::All,
589 WhatOrdersToTest::OnlyRCC);
590 test_gemm<GemmWrapper>(context, 7, 3, 5, WhatParamsToTest::All,
591 WhatOrdersToTest::OnlyRCC);
592 test_gemm<GemmWrapper>(context, 5, 7, 3, WhatParamsToTest::All,
593 WhatOrdersToTest::OnlyRCC);
594 test_gemm<GemmWrapper>(context, 8, 8, 8, WhatParamsToTest::All,
595 WhatOrdersToTest::OnlyRCC);
596 test_gemm<GemmWrapper>(context, 16, 16, 16, WhatParamsToTest::All,
597 WhatOrdersToTest::OnlyRCC);
598 test_gemm<GemmWrapper>(context, 32, 32, 32, WhatParamsToTest::All,
599 WhatOrdersToTest::OnlyRCC);
600 test_gemm<GemmWrapper>(context, 64, 64, 64, WhatParamsToTest::All,
601 WhatOrdersToTest::OnlyRCC);
602 test_gemm<GemmWrapper>(context, 128, 128, 128, WhatParamsToTest::All,
603 WhatOrdersToTest::OnlyRCC);
604
605 test_gemm<GemmWrapper>(context, 16, 17, 16, WhatParamsToTest::All,
606 WhatOrdersToTest::OnlyRCC);
607 test_gemm<GemmWrapper>(context, 37, 55, 73, WhatParamsToTest::All,
608 WhatOrdersToTest::OnlyRCC);
609 test_gemm<GemmWrapper>(context, 57, 87, 117, WhatParamsToTest::All,
610 WhatOrdersToTest::OnlyRCC);
611 test_gemm<GemmWrapper>(context, 93, 83, 73, WhatParamsToTest::All,
612 WhatOrdersToTest::OnlyRCC);
613 test_gemm<GemmWrapper>(context, 109, 89, 99, WhatParamsToTest::All,
614 WhatOrdersToTest::OnlyRCC);
615 test_gemm<GemmWrapper>(context, 78, 101, 82, WhatParamsToTest::All,
616 WhatOrdersToTest::OnlyRCC);
617
618 test_gemm<GemmWrapper>(context, 512, 512, 512,
619 WhatParamsToTest::OnlyGenericCase,
620 WhatOrdersToTest::OnlyRCC);
621 test_gemm<GemmWrapper>(context, 1024, 1024, 1024,
622 WhatParamsToTest::OnlyGenericCase,
623 WhatOrdersToTest::OnlyRCC);
624 test_gemm<GemmWrapper>(context, 567, 2345, 123,
625 WhatParamsToTest::OnlyGenericCase,
626 WhatOrdersToTest::OnlyRCC);
627 test_gemm<GemmWrapper>(context, 100, 5000, 100,
628 WhatParamsToTest::OnlyGenericCase,
629 WhatOrdersToTest::OnlyRCC);
630 test_gemm<GemmWrapper>(context, 1, 1, 1000, WhatParamsToTest::OnlyGenericCase,
631 WhatOrdersToTest::OnlyRCC);
632 test_gemm<GemmWrapper>(context, 1000, 1, 1, WhatParamsToTest::OnlyGenericCase,
633 WhatOrdersToTest::OnlyRCC);
634 test_gemm<GemmWrapper>(context, 1, 1000, 1, WhatParamsToTest::OnlyGenericCase,
635 WhatOrdersToTest::OnlyRCC);
636 test_gemm<GemmWrapper>(context, 1, 1000, 1000,
637 WhatParamsToTest::OnlyGenericCase,
638 WhatOrdersToTest::OnlyRCC);
639 test_gemm<GemmWrapper>(context, 1000, 1, 1000,
640 WhatParamsToTest::OnlyGenericCase,
641 WhatOrdersToTest::OnlyRCC);
642 test_gemm<GemmWrapper>(context, 1000, 1000, 1,
643 WhatParamsToTest::OnlyGenericCase,
644 WhatOrdersToTest::OnlyRCC);
645 test_gemm<GemmWrapper>(context, 777, 3456, 1,
646 WhatParamsToTest::OnlyGenericCase,
647 WhatOrdersToTest::OnlyRCC);
648 test_gemm<GemmWrapper>(context, 4567, 555, 1,
649 WhatParamsToTest::OnlyGenericCase,
650 WhatOrdersToTest::OnlyRCC);
651
652 // Test all storage orders
653 test_gemm<GemmWrapper>(context, 70, 90, 110, WhatParamsToTest::All,
654 WhatOrdersToTest::All);
655 test_gemm<GemmWrapper>(context, 300, 400, 500,
656 WhatParamsToTest::OnlyGenericCase,
657 WhatOrdersToTest::All);
658 }
659
660 template <typename GemmWrapper>
test_gemv(typename GemmWrapper::Context * context)661 void test_gemv(typename GemmWrapper::Context* context) {
662 test_gemm<GemmWrapper>(context, 2, 2, 1, WhatParamsToTest::All,
663 WhatOrdersToTest::OnlyRCC);
664 test_gemm<GemmWrapper>(context, 3, 3, 1, WhatParamsToTest::All,
665 WhatOrdersToTest::OnlyRCC);
666 test_gemm<GemmWrapper>(context, 4, 4, 1, WhatParamsToTest::All,
667 WhatOrdersToTest::OnlyRCC);
668 test_gemm<GemmWrapper>(context, 5, 5, 1, WhatParamsToTest::All,
669 WhatOrdersToTest::OnlyRCC);
670 test_gemm<GemmWrapper>(context, 6, 6, 1, WhatParamsToTest::All,
671 WhatOrdersToTest::OnlyRCC);
672 test_gemm<GemmWrapper>(context, 3, 5, 1, WhatParamsToTest::All,
673 WhatOrdersToTest::OnlyRCC);
674 test_gemm<GemmWrapper>(context, 7, 3, 1, WhatParamsToTest::All,
675 WhatOrdersToTest::OnlyRCC);
676 test_gemm<GemmWrapper>(context, 5, 7, 1, WhatParamsToTest::All,
677 WhatOrdersToTest::OnlyRCC);
678 test_gemm<GemmWrapper>(context, 8, 8, 1, WhatParamsToTest::All,
679 WhatOrdersToTest::OnlyRCC);
680 test_gemm<GemmWrapper>(context, 32, 32, 1, WhatParamsToTest::All,
681 WhatOrdersToTest::OnlyRCC);
682 test_gemm<GemmWrapper>(context, 128, 128, 1, WhatParamsToTest::All,
683 WhatOrdersToTest::OnlyRCC);
684 test_gemm<GemmWrapper>(context, 321, 123, 1, WhatParamsToTest::All,
685 WhatOrdersToTest::OnlyRCC);
686
687 // Test all storage orders
688 test_gemm<GemmWrapper>(context, 70, 90, 1, WhatParamsToTest::All,
689 WhatOrdersToTest::All);
690 test_gemm<GemmWrapper>(context, 300, 400, 1,
691 WhatParamsToTest::OnlyGenericCase,
692 WhatOrdersToTest::All);
693 }
694
GetBitDepthName(eight_bit_int_gemm::BitDepthSetting b)695 const char* GetBitDepthName(eight_bit_int_gemm::BitDepthSetting b) {
696 switch (b) {
697 case eight_bit_int_gemm::BitDepthSetting::A8B8:
698 return "Lhs: 8 bit, Rhs: 8 bit";
699 case eight_bit_int_gemm::BitDepthSetting::A5B7:
700 return "Lhs: 7 bit, Rhs: 5 bit";
701 default:
702 abort();
703 return nullptr;
704 }
705 }
706
707 // Runs a small set of hand-picked data for per-channel quantized data.
708 // This test case comes from a set of 2 2x2 convolution filters run over a 3x3
709 // image.
TestWithSmallDataPerChannelQuantization()710 void TestWithSmallDataPerChannelQuantization() {
711 const int m = 2;
712 const int n = 9;
713 const int k = 12;
714
715 // 12 x 2, columnwise.
716 const uint8_t a_data[] = {
717 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255,
718 64, 64, 64, 64, 64, 64, 0, 0, 0, 255, 255, 255
719 };
720 const int lda = k;
721 int a_offset[] = {0, -64};
722 MatrixMap<const std::uint8_t, MapOrder::RowMajor> lhs(a_data, m, k, lda);
723 const OffsetColMap lhs_offset(a_offset, m);
724
725 // 12 x 9, columnwise.
726 const uint8_t b_data[] = {
727 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255,
728 0, 0, 0, 0, 0, 0, 255, 255, 255, 0, 0, 0,
729 0, 0, 0, 127, 127, 127, 0, 0, 0, 127, 127, 127,
730 0, 0, 0, 255, 255, 255, 0, 0, 0, 0, 0, 0,
731 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0,
732 0, 0, 0, 127, 127, 127, 0, 0, 0, 127, 127, 127,
733 0, 0, 0, 0, 0, 0, 127, 127, 127, 127, 127, 127,
734 0, 0, 0, 0, 0, 0, 127, 127, 127, 127, 127, 127,
735 0, 0, 0, 127, 127, 127, 127, 127, 127, 127, 127, 127
736 };
737 const int ldb = k;
738 int b_offset = -127;
739 MatrixMap<const std::uint8_t, MapOrder::ColMajor> rhs(b_data, k, n, ldb);
740 const OffsetRowDup rhs_offset(b_offset, rhs.cols());
741
742 // 2 x 9, columnwise.
743 const uint8_t expected_c_data[] = {
744 255, 255,
745 0, 0,
746 127, 159,
747 0, 64,
748 0, 64,
749 127, 159,
750 127, 127,
751 127, 127,
752 127, 127
753 };
754 const int ldc = m;
755 int c_offset[] = {97155, 97346};
756 int c_mult_int[] = {2741, 2741};
757 const int c_shift = 21;
758
759 const int c_count = m * n;
760 std::unique_ptr<uint8_t[]> output_data(new uint8_t[c_count]);
761 MatrixMap<std::uint8_t, MapOrder::ColMajor> result(output_data.get(), m, n,
762 ldc);
763 const OffsetColMap result_offset(c_offset, m);
764 const OffsetColMap result_mult_int(c_mult_int, m);
765 const int result_shift = c_shift;
766
767 GemmContext gemm_context;
768 auto output_pipeline = MakeStandardOutputPipeline<VectorShape::Col>(
769 result_offset, result_mult_int, result_shift);
770 GemmWithOutputPipelinePC<uint8_t, uint8_t, DefaultL8R8BitDepthParams>(
771 &gemm_context, lhs, rhs, &result, lhs_offset, rhs_offset,
772 output_pipeline);
773
774 ResultStats stats;
775 GetResultStats(output_data.get(), expected_c_data, c_count, &stats);
776
777 ResultStatsBounds bounds;
778 const bool good = CheckResultStatsBounds(stats, bounds);
779 printf("TestWithSmallDataPerChannelQuantization: %s\n",
780 good ? "PASS" : "FAIL");
781 ReportResultStats(stats, bounds);
782 Check(good);
783 }
784
// Runs a larger set of hand-picked data for per-channel quantized data.
// This test case comes from a set of 22 3x3 convolution filters run over a 5x5
// image. It uses 7 distinct filters plus 15 copies of the first filter, so
// that the NEON code path that processes 16 rows at a time is covered.
TestWithLargeDataPerChannelQuantization()790 void TestWithLargeDataPerChannelQuantization() {
791 const int m = 22;
792 const int n = 25;
793 const int k = 27;
794
795 // 27 x 22, column-wise.
796 const uint8_t a_data[] = {
797 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255,
798 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
799 0, 0, 0, 0, 0, 0, 0, 0, 0, 127, 127, 127, 255, 255, 255,
800 127, 127, 127, 0, 0, 0, 0, 0, 0, 0, 0, 0,
801 0, 0, 0, 127, 127, 127, 0, 0, 0, 0, 0, 0, 255, 255, 255,
802 0, 0, 0, 0, 0, 0, 127, 127, 127, 0, 0, 0,
803 51, 51, 51, 51, 51, 51, 51, 51, 51, 0, 0, 0, 255, 255, 255,
804 0, 0, 0, 51, 51, 51, 51, 51, 51, 51, 51, 51,
805 51, 51, 51, 0, 0, 0, 51, 51, 51, 51, 51, 51, 255, 255, 255,
806 51, 51, 51, 51, 51, 51, 0, 0, 0, 51, 51, 51,
807 0, 0, 0, 64, 64, 64, 0, 0, 0, 64, 64, 64, 255, 255, 255,
808 64, 64, 64, 0, 0, 0, 64, 64, 64, 0, 0, 0,
809 36, 36, 36, 0, 0, 0, 36, 36, 36, 0, 0, 0, 255, 255, 255,
810 0, 0, 0, 36, 36, 36, 0, 0, 0, 36, 36, 36,
811 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255,
812 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
813 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255,
814 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
815 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255,
816 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
817 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255,
818 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
819 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255,
820 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
821 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255,
822 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
823 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255,
824 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
825 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255,
826 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
827 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255,
828 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
829 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255,
830 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
831 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255,
832 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
833 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255,
834 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
835 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255,
836 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
837 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255,
838 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
839 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255,
840 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
841 };
842 const int lda = k;
843 int a_offset[] = {
844 0, 0, 0, -51, -51, 0, -36, 0, 0, 0,
845 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
846 0, 0
847 };
848 MatrixMap<const std::uint8_t, MapOrder::RowMajor> lhs(a_data, m, k, lda);
849 const OffsetColMap lhs_offset(a_offset, m);
850
851 // 27 x 25, column-wise.
852 const uint8_t b_data[] = {
853 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 119, 119, 119,
854 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119,
855 127, 127, 127, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119,
856 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
857 127, 127, 127, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119,
858 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
859 127, 127, 127, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119,
860 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
861 127, 127, 127, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119,
862 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127,
863 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119,
864 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119,
865 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
866 119, 119, 119, 119, 119, 119, 119, 119, 119, 136, 136, 136,
867 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
868 119, 119, 119, 119, 119, 119, 136, 136, 136, 119, 119, 119,
869 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
870 119, 119, 119, 136, 136, 136, 119, 119, 119, 119, 119, 119,
871 119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119,
872 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127,
873 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119,
874 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119,
875 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
876 136, 136, 136, 119, 119, 119, 119, 119, 119, 119, 119, 119,
877 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 136, 136, 136,
878 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
879 119, 119, 119, 119, 119, 119, 119, 119, 119, 136, 136, 136, 119, 119, 119,
880 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
881 119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119,
882 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127,
883 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119,
884 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119,
885 119, 119, 119, 119, 119, 119, 136, 136, 136, 119, 119, 119, 119, 119, 119,
886 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
887 119, 119, 119, 136, 136, 136, 119, 119, 119, 119, 119, 119, 119, 119, 119,
888 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
889 136, 136, 136, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
890 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
891 119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119,
892 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127,
893 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119,
894 119, 119, 119, 127, 127, 127, 127, 127, 127, 127, 127, 127,
895 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
896 119, 119, 119, 127, 127, 127, 127, 127, 127, 127, 127, 127,
897 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
898 119, 119, 119, 127, 127, 127, 127, 127, 127, 127, 127, 127,
899 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
900 119, 119, 119, 127, 127, 127, 127, 127, 127, 127, 127, 127,
901 119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119,
902 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127
903 };
904 const int ldb = k;
905 int b_offset = -127;
906 MatrixMap<const std::uint8_t, MapOrder::ColMajor> rhs(b_data, k, n, ldb);
907 const OffsetRowDup rhs_offset(b_offset, rhs.cols());
908
909 // 22 x 25, column-wise.
910 const uint8_t expected_c_data[] = {
911 7, 37, 37, 67, 67, 39, 79, 7, 7, 7, 7, 7, 7, 7, 7,
912 7, 7, 7, 7, 7, 7, 7,
913 7, 7, 37, 87, 67, 23, 91, 7, 7, 7, 7, 7, 7, 7, 7,
914 7, 7, 7, 7, 7, 7, 7,
915 7, 7, 37, 87, 67, 23, 91, 7, 7, 7, 7, 7, 7, 7, 7,
916 7, 7, 7, 7, 7, 7, 7,
917 7, 7, 37, 87, 67, 23, 91, 7, 7, 7, 7, 7, 7, 7, 7,
918 7, 7, 7, 7, 7, 7, 7,
919 7, 37, 37, 67, 67, 39, 79, 7, 7, 7, 7, 7, 7, 7, 7,
920 7, 7, 7, 7, 7, 7, 7,
921 7, 37, 7, 67, 87, 23, 91, 7, 7, 7, 7, 7, 7, 7, 7,
922 7, 7, 7, 7, 7, 7, 7,
923 7, 7, 7, 87, 87, 7, 103, 7, 7, 7, 7, 7, 7, 7, 7,
924 7, 7, 7, 7, 7, 7, 7,
925 7, 7, 71, 87, 45, 41, 77, 7, 7, 7, 7, 7, 7, 7, 7,
926 7, 7, 7, 7, 7, 7, 7,
927 7, 7, 7, 87, 87, 7, 103, 7, 7, 7, 7, 7, 7, 7, 7,
928 7, 7, 7, 7, 7, 7, 7,
929 7, 37, 7, 67, 87, 23, 91, 7, 7, 7, 7, 7, 7, 7, 7,
930 7, 7, 7, 7, 7, 7, 7,
931 7, 37, 7, 67, 87, 23, 91, 7, 7, 7, 7, 7, 7, 7, 7,
932 7, 7, 7, 7, 7, 7, 7,
933 7, 71, 7, 45, 87, 41, 77, 7, 7, 7, 7, 7, 7, 7, 7,
934 7, 7, 7, 7, 7, 7, 7,
935 255, 135, 135, 255, 255, 143, 255, 255, 255, 255, 255, 255, 255, 255, 255,
936 255, 255, 255, 255, 255, 255, 255,
937 7, 71, 7, 45, 87, 41, 77, 7, 7, 7, 7, 7, 7, 7, 7,
938 7, 7, 7, 7, 7, 7, 7,
939 7, 37, 7, 67, 87, 23, 91, 7, 7, 7, 7, 7, 7, 7, 7,
940 7, 7, 7, 7, 7, 7, 7,
941 7, 37, 7, 67, 87, 23, 91, 7, 7, 7, 7, 7, 7, 7, 7,
942 7, 7, 7, 7, 7, 7, 7,
943 7, 7, 7, 87, 87, 7, 103, 7, 7, 7, 7, 7, 7, 7, 7,
944 7, 7, 7, 7, 7, 7, 7,
945 7, 7, 71, 87, 45, 41, 77, 7, 7, 7, 7, 7, 7, 7, 7,
946 7, 7, 7, 7, 7, 7, 7,
947 7, 7, 7, 87, 87, 7, 103, 7, 7, 7, 7, 7, 7, 7, 7,
948 7, 7, 7, 7, 7, 7, 7,
949 7, 37, 7, 67, 87, 23, 91, 7, 7, 7, 7, 7, 7, 7, 7,
950 7, 7, 7, 7, 7, 7, 7,
951 7, 37, 37, 67, 67, 39, 79, 7, 7, 7, 7, 7, 7, 7, 7,
952 7, 7, 7, 7, 7, 7, 7,
953 7, 7, 37, 87, 67, 23, 91, 7, 7, 7, 7, 7, 7, 7, 7,
954 7, 7, 7, 7, 7, 7, 7,
955 7, 7, 37, 87, 67, 23, 91, 7, 7, 7, 7, 7, 7, 7, 7,
956 7, 7, 7, 7, 7, 7, 7,
957 7, 7, 37, 87, 67, 23, 91, 7, 7, 7, 7, 7, 7, 7, 7,
958 7, 7, 7, 7, 7, 7, 7,
959 7, 37, 37, 67, 67, 39, 79, 7, 7, 7, 7, 7, 7, 7, 7,
960 7, 7, 7, 7, 7, 7, 7,
961 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99,
962 99, 99, 99, 99, 99, 99, 99,
963 111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111,
964 111, 111, 111, 111, 111, 111, 111,
965 };
966 const int ldc = m;
967 int c_offset[] = {
968 6477, 12954, 12954, 7793, 7793, 12954, 9282, 6477, 6477, 6477,
969 6477, 6477, 6477, 6477, 6477, 6477, 6477, 6477, 6477, 6477,
970 6477, 6477,
971 };
972 int c_mult_int[] = {
973 41121, 20560, 20560, 34267, 34267, 21937, 28784, 41121, 41121, 41121,
974 41121, 41121, 41121, 41121, 41121, 41121, 41121, 41121, 41121, 41121,
975 41121, 41121,
976 };
977 const int c_shift = 21;
978
979 const int c_count = m * n;
980 std::unique_ptr<uint8_t[]> output_data(new uint8_t[c_count]);
981 MatrixMap<std::uint8_t, MapOrder::ColMajor> result(output_data.get(), m, n,
982 ldc);
983 const OffsetColMap result_offset(c_offset, m);
984 const OffsetColMap result_mult_int(c_mult_int, m);
985 const int result_shift = c_shift;
986
987 GemmContext gemm_context;
988 auto output_pipeline = MakeStandardOutputPipeline<VectorShape::Col>(
989 result_offset, result_mult_int, result_shift);
990 GemmWithOutputPipelinePC<uint8_t, uint8_t, DefaultL8R8BitDepthParams>(
991 &gemm_context, lhs, rhs, &result, lhs_offset, rhs_offset,
992 output_pipeline);
993
994 ResultStats stats;
995 GetResultStats(output_data.get(), expected_c_data, c_count, &stats);
996
997 ResultStatsBounds bounds;
998 const bool good = CheckResultStatsBounds(stats, bounds);
999 printf("TestWithLargeDataPerChannelQuantization: %s\n",
1000 good ? "PASS" : "FAIL");
1001 ReportResultStats(stats, bounds);
1002 Check(good);
1003 }
1004
1005 // Runs a small set of hand-calculated data through the implementation.
TestWithSmallData()1006 void TestWithSmallData() {
1007 const int m = 4;
1008 const int n = 2;
1009 const int k = 3;
1010 // Matrix A (LHS) is:
1011 // | 7 | 10 | 13 | 16 |
1012 // | 8 | 11 | 14 | 17 |
1013 // | 9 | 12 | 15 | 18 |
1014 const uint8_t a_data[] = {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18};
1015 // Matrix B (RHS) is:
1016 // | 1 | 3 | 5 |
1017 // | 2 | 4 | 6 |
1018 const uint8_t b_data[] = {1, 2, 3, 4, 5, 6};
1019 // Here are the results we expect, from hand calculations:
1020 // (1 * 7) + (3 * 8) + (5 * 9) = 76
1021 // (2 * 7) + (4 * 8) + (6 * 9) = 100
1022 // (1 * 10) + (3 * 11) + (5 * 12) = 103
1023 // (2 * 10) + (4 * 11) + (6 * 12) = 136
1024 // (1 * 13) + (3 * 14) + (5 * 15) = 130
1025 // (2 * 13) + (4 * 14) + (6 * 15) = 172
1026 // (1 * 16) + (3 * 17) + (5 * 18) = 157
1027 // (2 * 16) + (4 * 17) + (6 * 18) = 208
1028 // That means matrix C should be:
1029 // | 76 | 103 | 130 | 157 |
1030 // | 100 | 136 | 172 | 208 |
1031 const uint8_t expected_data[] = {76, 100, 103, 136, 130, 172, 157, 208};
1032
1033 const int c_count = m * n;
1034 std::unique_ptr<uint8_t[]> output_data(new uint8_t[c_count]);
1035
1036 const bool is_a_transposed = true;
1037 const bool is_b_transposed = true;
1038 const bool is_c_transposed = true;
1039 const int lda = k;
1040 const int ldb = n;
1041 const int ldc = n;
1042
1043 const int a_offset = 0;
1044 const int b_offset = 0;
1045 const int c_offset = 0;
1046 const int c_mult = 1;
1047 const int c_shift = 0;
1048
1049 gemmlowp::eight_bit_int_gemm::EightBitIntGemm(
1050 is_a_transposed, is_b_transposed, is_c_transposed, m, n, k, a_data,
1051 a_offset, lda, b_data, b_offset, ldb, output_data.get(), c_offset, c_mult,
1052 c_shift, ldc, eight_bit_int_gemm::BitDepthSetting::A8B8);
1053
1054 ResultStats stats;
1055 GetResultStats(output_data.get(), expected_data, c_count, &stats);
1056
1057 ResultStatsBounds bounds;
1058 const bool good = CheckResultStatsBounds(stats, bounds);
1059 printf("TestWithSmallData: %s\n", good ? "PASS" : "FAIL");
1060 ReportResultStats(stats, bounds);
1061 Check(good);
1062 }
1063
1064 // This is the most realistic test of how we'll be using the low-precision GEMM
1065 // function in applications. It takes in large input matrices that have been
1066 // captured from an actual neural network run.
TestWithRealData(eight_bit_int_gemm::BitDepthSetting BitDepth,int tolerance_median,int tolerance_max)1067 void TestWithRealData(eight_bit_int_gemm::BitDepthSetting BitDepth,
1068 int tolerance_median, int tolerance_max) {
1069 std::unique_ptr<uint8_t[]> output_data(new uint8_t[test_data::c_count]);
1070 gemmlowp::eight_bit_int_gemm::EightBitIntGemm(
1071 test_data::is_a_transposed, test_data::is_b_transposed,
1072 test_data::is_c_transposed, test_data::m, test_data::n, test_data::k,
1073 test_data::a_data, test_data::a_offset, test_data::k, test_data::b_data,
1074 test_data::b_offset, test_data::k, output_data.get(), test_data::c_offset,
1075 test_data::c_mult_int, test_data::c_shift, test_data::m, BitDepth);
1076
1077 ResultStats stats;
1078 GetResultStats(output_data.get(), test_data::expected_c_data,
1079 test_data::c_count, &stats);
1080
1081 ResultStatsBounds bounds;
1082 if (BitDepth == eight_bit_int_gemm::BitDepthSetting::A5B7) {
1083 bounds.med_unsigned_diff = tolerance_median;
1084 bounds.max_unsigned_diff = tolerance_max;
1085 bounds.med_signed_diff = 0;
1086 bounds.mean_signed_diff = 0.2f;
1087 }
1088
1089 const bool good = CheckResultStatsBounds(stats, bounds);
1090 printf("TestWithRealData: %s with %s\n", good ? "PASS" : "FAIL",
1091 GetBitDepthName(BitDepth));
1092 ReportResultStats(stats, bounds);
1093 Check(good);
1094 }
1095
1096 template <MapOrder ResultOrder>
TestOutputStages(int rows,int depth,int cols,int result_offset,int result_mult_int,int result_shift)1097 void TestOutputStages(int rows, int depth, int cols, int result_offset,
1098 int result_mult_int, int result_shift) {
1099 Matrix<std::uint8_t, MapOrder::RowMajor> lhs(rows, depth);
1100 Matrix<std::uint8_t, MapOrder::ColMajor> rhs(depth, cols);
1101 Matrix<std::int32_t, ResultOrder> result_raw_int32(rows, cols);
1102 MakeRandom(&lhs, 8);
1103 MakeRandom(&rhs, 8);
1104 const int lhs_offset = 12;
1105 const int rhs_offset = -34;
1106
1107 // Test an empty pipeline, i.e. returning raw int32 accumulators.
1108 auto empty_pipeline = std::make_tuple();
1109 GemmContext context;
1110 GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1111 &context, lhs.const_map(), rhs.const_map(), &result_raw_int32, lhs_offset,
1112 rhs_offset, empty_pipeline);
1113
1114 for (int r = 0; r < rows; r++) {
1115 for (int c = 0; c < cols; c++) {
1116 std::int32_t expected = 0;
1117 for (int d = 0; d < depth; d++) {
1118 std::int32_t lhs_val =
1119 static_cast<std::int32_t>(lhs(r, d)) + lhs_offset;
1120 std::int32_t rhs_val =
1121 static_cast<std::int32_t>(rhs(d, c)) + rhs_offset;
1122 expected += lhs_val * rhs_val;
1123 }
1124 Check(expected == result_raw_int32(r, c));
1125 }
1126 }
1127
1128 // Test a pipeline with only the quantize-down stage, still returning
1129 // unclamped (but scaled) int32's
1130 OutputStageQuantizeDownInt32ToUint8Scale quantize_down_stage;
1131 quantize_down_stage.result_offset = result_offset;
1132 quantize_down_stage.result_mult_int = result_mult_int;
1133 quantize_down_stage.result_shift = result_shift;
1134 auto quantize_down_pipeline = std::make_tuple(quantize_down_stage);
1135 Matrix<std::int32_t, ResultOrder> result_quantized_down_int32(rows, cols);
1136 GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1137 &context, lhs.const_map(), rhs.const_map(), &result_quantized_down_int32,
1138 lhs_offset, rhs_offset, quantize_down_pipeline);
1139
1140 std::uint64_t sum = 0;
1141 for (int r = 0; r < rows; r++) {
1142 for (int c = 0; c < cols; c++) {
1143 std::int32_t raw = result_raw_int32(r, c);
1144 const std::int32_t rounding =
1145 (result_shift < 1) ? 0 : (1 << (result_shift - 1));
1146 std::int32_t expected =
1147 ((raw + result_offset) * result_mult_int + rounding) >> result_shift;
1148 Check(expected == result_quantized_down_int32(r, c));
1149 sum += expected;
1150 }
1151 }
1152 std::uint64_t avg = sum / (rows * cols);
1153 // Test that the average quantized-down value falls reasonably in the
1154 // middle of the [0..255] range. Otherwise, the multiplier / shift need to be
1155 // adjusted.
1156 Check(avg >= 64 && avg <= 192);
1157
1158 // Test the familiar default pipeline consisting of quantize-down and
1159 // clamp-and-cast-to-uint8.
1160 OutputStageSaturatingCastToUint8 saturating_cast_stage;
1161 auto quantize_down_and_saturating_cast_pipeline =
1162 std::make_tuple(quantize_down_stage, saturating_cast_stage);
1163 Matrix<std::uint8_t, ResultOrder> result_quantized_down_saturated_uint8(rows,
1164 cols);
1165 GemmWithOutputPipeline<std::uint8_t, std::uint8_t, DefaultL8R8BitDepthParams>(
1166 &context, lhs.const_map(), rhs.const_map(),
1167 &result_quantized_down_saturated_uint8, lhs_offset, rhs_offset,
1168 quantize_down_and_saturating_cast_pipeline);
1169
1170 for (int r = 0; r < rows; r++) {
1171 for (int c = 0; c < cols; c++) {
1172 std::int32_t quantized = result_quantized_down_int32(r, c);
1173 std::uint8_t expected = std::min(std::max(quantized, 0), 255);
1174 Check(expected == result_quantized_down_saturated_uint8(r, c));
1175 }
1176 }
1177
1178 // Test a bias-addition with row-vector
1179 std::vector<std::int32_t> row_vector_data(cols);
1180 for (int i = 0; i < cols; i++) {
1181 row_vector_data[i] = (Random() % 1000) - 500;
1182 }
1183 typedef VectorMap<std::int32_t, VectorShape::Row> RowVectorMap;
1184 RowVectorMap row_vector_map(row_vector_data.data(), cols);
1185 OutputStageBiasAddition<RowVectorMap> row_bias_addition_stage;
1186 row_bias_addition_stage.bias_vector = row_vector_map;
1187 auto row_bias_addition_pipeline = std::make_tuple(row_bias_addition_stage);
1188 Matrix<std::int32_t, ResultOrder> result_of_row_bias_addition(rows, cols);
1189 GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1190 &context, lhs.const_map(), rhs.const_map(), &result_of_row_bias_addition,
1191 lhs_offset, rhs_offset, row_bias_addition_pipeline);
1192 for (int r = 0; r < rows; r++) {
1193 for (int c = 0; c < cols; c++) {
1194 std::int32_t expected = result_raw_int32(r, c) + row_vector_data[c];
1195 Check(expected == result_of_row_bias_addition(r, c));
1196 }
1197 }
1198
1199 // Test a bias-addition with column-vector
1200 std::vector<std::int32_t> col_vector_data(rows);
1201 for (int i = 0; i < rows; i++) {
1202 col_vector_data[i] = (Random() % 1000) - 500;
1203 }
1204 typedef VectorMap<std::int32_t, VectorShape::Col> ColVectorMap;
1205 ColVectorMap col_vector_map(col_vector_data.data(), rows);
1206 OutputStageBiasAddition<ColVectorMap> col_bias_addition_stage;
1207 col_bias_addition_stage.bias_vector = col_vector_map;
1208 auto col_bias_addition_pipeline = std::make_tuple(col_bias_addition_stage);
1209 Matrix<std::int32_t, ResultOrder> result_of_col_bias_addition(rows, cols);
1210 GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1211 &context, lhs.const_map(), rhs.const_map(), &result_of_col_bias_addition,
1212 lhs_offset, rhs_offset, col_bias_addition_pipeline);
1213 for (int r = 0; r < rows; r++) {
1214 for (int c = 0; c < cols; c++) {
1215 std::int32_t expected = result_raw_int32(r, c) + col_vector_data[r];
1216 Check(expected == result_of_col_bias_addition(r, c));
1217 }
1218 }
1219
1220 // Test a clamp
1221 OutputStageClamp clamp_stage;
1222 // Determine min and max of raw int32 accumulators
1223 std::int32_t raw_min = std::numeric_limits<std::int32_t>::max();
1224 std::int32_t raw_max = std::numeric_limits<std::int32_t>::min();
1225 for (int r = 0; r < rows; r++) {
1226 for (int c = 0; c < cols; c++) {
1227 raw_min = std::min(raw_min, result_raw_int32(r, c));
1228 raw_max = std::max(raw_max, result_raw_int32(r, c));
1229 }
1230 }
1231 // Pick some interesting clamp min/max bounds
1232 clamp_stage.min = static_cast<std::int32_t>(raw_min * 0.7 + raw_max * 0.3);
1233 clamp_stage.max = static_cast<std::int32_t>(raw_min * 0.3 + raw_max * 0.7);
1234 assert(raw_min <= clamp_stage.min && clamp_stage.min <= clamp_stage.max &&
1235 clamp_stage.max <= raw_max);
1236 auto clamp_pipeline = std::make_tuple(clamp_stage);
1237 Matrix<std::int32_t, ResultOrder> result_clamped(rows, cols);
1238 GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1239 &context, lhs.const_map(), rhs.const_map(), &result_clamped, lhs_offset,
1240 rhs_offset, clamp_pipeline);
1241 for (int r = 0; r < rows; r++) {
1242 for (int c = 0; c < cols; c++) {
1243 std::int32_t raw = result_raw_int32(r, c);
1244 std::int32_t expected =
1245 std::min(std::max(raw, clamp_stage.min), clamp_stage.max);
1246 Check(expected == result_clamped(r, c));
1247 }
1248 }
1249
1250 // Test tanh
1251 OutputStageTanh tanh_stage;
1252 const std::int32_t real_zero_as_int32 = (raw_max + raw_min) / 2;
1253 const std::int32_t real_amplitude_as_int32 = (raw_max - raw_min) / 16;
1254 tanh_stage.real_zero_as_int32 = real_zero_as_int32;
1255 tanh_stage.real_amplitude_as_int32 = real_amplitude_as_int32;
1256 auto tanh_pipeline = std::make_tuple(tanh_stage);
1257 Matrix<std::int32_t, ResultOrder> result_tanh(rows, cols);
1258 GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1259 &context, lhs.const_map(), rhs.const_map(), &result_tanh, lhs_offset,
1260 rhs_offset, tanh_pipeline);
1261 for (int r = 0; r < rows; r++) {
1262 for (int c = 0; c < cols; c++) {
1263 std::int32_t raw = result_raw_int32(r, c);
1264 double real_input =
1265 double(raw - real_zero_as_int32) / real_amplitude_as_int32;
1266 double expected = std::tanh(real_input);
1267 std::int32_t actual_int32 = result_tanh(r, c);
1268 double actual =
1269 double(actual_int32 - real_zero_as_int32) / real_amplitude_as_int32;
1270 Check(std::abs(expected - actual) < 2e-4);
1271 }
1272 }
1273
1274 // Test a pipeline with bias and clamp
1275 auto bias_clamp_pipeline =
1276 std::make_tuple(col_bias_addition_stage, clamp_stage);
1277 Matrix<std::int32_t, ResultOrder> result_biased_clamped(rows, cols);
1278 GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1279 &context, lhs.const_map(), rhs.const_map(), &result_biased_clamped,
1280 lhs_offset, rhs_offset, bias_clamp_pipeline);
1281 for (int r = 0; r < rows; r++) {
1282 for (int c = 0; c < cols; c++) {
1283 std::int32_t raw = result_raw_int32(r, c);
1284 std::int32_t biased = raw + col_vector_data[r];
1285 std::int32_t expected =
1286 std::min(std::max(biased, clamp_stage.min), clamp_stage.max);
1287 Check(expected == result_biased_clamped(r, c));
1288 }
1289 }
1290
1291 // Test a full pipeline with bias and clamp and quantization down to 8bit
1292 // result
1293 auto bias_clamp_quantize_cast_pipeline =
1294 std::make_tuple(col_bias_addition_stage, clamp_stage, quantize_down_stage,
1295 saturating_cast_stage);
1296 Matrix<std::uint8_t, ResultOrder> result_biased_clamped_quantized_casted(
1297 rows, cols);
1298 GemmWithOutputPipeline<std::uint8_t, std::uint8_t, DefaultL8R8BitDepthParams>(
1299 &context, lhs.const_map(), rhs.const_map(),
1300 &result_biased_clamped_quantized_casted, lhs_offset, rhs_offset,
1301 bias_clamp_quantize_cast_pipeline);
1302 for (int r = 0; r < rows; r++) {
1303 for (int c = 0; c < cols; c++) {
1304 const std::int32_t rounding =
1305 (result_shift < 1) ? 0 : (1 << (result_shift - 1));
1306 std::int32_t quantized =
1307 ((result_biased_clamped(r, c) + result_offset) * result_mult_int +
1308 rounding) >>
1309 result_shift;
1310 std::uint8_t expected = std::min(std::max(quantized, 0), 255);
1311 Check(expected == result_biased_clamped_quantized_casted(r, c));
1312 }
1313 }
1314
1315 printf("TestOutputStages: PASS with ResultOrder=%s\n",
1316 OrderName(ResultOrder));
1317 }
1318
1319 #ifndef GEMMLOWP_SKIP_EXHAUSTIVE_TESTS
TestExhaustively()1320 void TestExhaustively() {
1321 GemmContext context;
1322
1323 // Test the internal GEMM interfaces
1324 test_gemm<SingleThreadGemmWrapper<
1325 DefaultKernel<KernelFamily::Gemm, DefaultL8R8BitDepthParams>,
1326 std::uint8_t, DefaultL8R8BitDepthParams>>(&context);
1327
1328 test_gemm<MultiThreadGemmWrapper<
1329 DefaultKernel<KernelFamily::Gemm, DefaultL8R8BitDepthParams>,
1330 std::uint8_t, DefaultL8R8BitDepthParams>>(&context);
1331
1332 // Test the public GEMM interfaces
1333 test_gemm<PublicGemmWrapper<uint8_t, DefaultL8R8BitDepthParams>>(&context);
1334
1335 test_gemm<EightBitIntGemmWrapper<uint8_t,
1336 eight_bit_int_gemm::BitDepthSetting::A8B8>>(
1337 &context);
1338
1339 // Test GEMV cases (internal interfaces)
1340 test_gemv<SingleThreadGemmWrapper<
1341 DefaultKernel<KernelFamily::Gemv, DefaultL8R8BitDepthParams>,
1342 std::uint8_t, DefaultL8R8BitDepthParams>>(&context);
1343
1344 test_gemv<MultiThreadGemmWrapper<
1345 DefaultKernel<KernelFamily::Gemv, DefaultL8R8BitDepthParams>,
1346 std::uint8_t, DefaultL8R8BitDepthParams>>(&context);
1347
1348 // Test GEMV cases (public interfaces)
1349 test_gemv<PublicGemmWrapper<uint8_t, DefaultL8R8BitDepthParams>>(&context);
1350
1351 test_gemv<EightBitIntGemmWrapper<uint8_t,
1352 eight_bit_int_gemm::BitDepthSetting::A8B8>>(
1353 &context);
1354
1355 // Test other bit depths
1356 // L7R5
1357 test_gemm<SingleThreadGemmWrapper<
1358 DefaultKernel<KernelFamily::Gemm, DefaultL7R5BitDepthParams>,
1359 std::uint8_t, DefaultL7R5BitDepthParams>>(&context);
1360
1361 test_gemv<SingleThreadGemmWrapper<
1362 DefaultKernel<KernelFamily::Gemv, DefaultL7R5BitDepthParams>,
1363 std::uint8_t, DefaultL7R5BitDepthParams>>(&context);
1364
1365 test_gemm<EightBitIntGemmWrapper<std::uint8_t,
1366 eight_bit_int_gemm::BitDepthSetting::A5B7>>(
1367 &context);
1368
1369 // Test specific kernels with various different formats,
1370 // to exercises corner cases especially in the packing code.
1371 test_gemm_kernel<
1372 ReferenceKernel<KernelFormat<KernelSideFormat<CellFormat<1, 1>, 1>,
1373 KernelSideFormat<CellFormat<1, 1>, 1>>>>(
1374 &context);
1375
1376 test_gemm_kernel<
1377 ReferenceKernel<KernelFormat<KernelSideFormat<CellFormat<4, 2>, 1>,
1378 KernelSideFormat<CellFormat<4, 2>, 2>>>>(
1379 &context);
1380
1381 test_gemm_kernel<
1382 ReferenceKernel<KernelFormat<KernelSideFormat<CellFormat<4, 2>, 4>,
1383 KernelSideFormat<CellFormat<4, 2>, 5>>>>(
1384 &context);
1385
1386 test_gemm_kernel<ReferenceKernel<KernelFormat<
1387 KernelSideFormat<CellFormat<3, 4, CellOrder::DepthMajor>, 2>,
1388 KernelSideFormat<CellFormat<5, 4, CellOrder::DepthMajor>, 3>>>>(&context);
1389
1390 test_gemm_kernel<ReferenceKernel<KernelFormat<
1391 KernelSideFormat<CellFormat<3, 4, CellOrder::WidthMajor>, 2>,
1392 KernelSideFormat<CellFormat<5, 4, CellOrder::WidthMajor>, 3>>>>(&context);
1393
1394 test_gemm_kernel<ReferenceKernel<KernelFormat<
1395 KernelSideFormat<CellFormat<5, 2, CellOrder::WidthMajor>, 3>,
1396 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2>>>>(&context);
1397
1398 test_gemm_kernel<ReferenceKernel<KernelFormat<
1399 KernelSideFormat<CellFormat<5, 2, CellOrder::DepthMajor>, 3>,
1400 KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 2>>>>(&context);
1401
1402 test_gemm_kernel<ReferenceKernel<KernelFormat<
1403 KernelSideFormat<CellFormat<8, 8, CellOrder::Diagonal>, 2>,
1404 KernelSideFormat<CellFormat<3, 8, CellOrder::WidthMajor>, 1>>>>(&context);
1405
1406 test_gemm_kernel<ReferenceKernel<KernelFormat<
1407 KernelSideFormat<CellFormat<1, 4, CellOrder::DepthMajor>, 1>,
1408 KernelSideFormat<CellFormat<4, 4, CellOrder::Diagonal>, 1>>>>(&context);
1409 }
1410 #endif // not GEMMLOWP_SKIP_EXHAUSTIVE_TESTS
1411
test()1412 void test() {
1413 #ifdef GEMMLOWP_TEST_PROFILE
1414 RegisterCurrentThreadForProfiling();
1415 StartProfiling();
1416 #endif
1417
1418 // Run a first quick test against hand-calculated data.
1419 TestWithSmallData();
1420
1421 #ifndef GEMMLOWP_SKIP_EXHAUSTIVE_TESTS
1422 TestExhaustively();
1423 #endif
1424
1425 // Run against actual data from a network evaluation.
1426 TestWithRealData(eight_bit_int_gemm::BitDepthSetting::A8B8, 0, 0);
1427 TestWithRealData(eight_bit_int_gemm::BitDepthSetting::A5B7, 2, 10);
1428
1429 // Test non-default output pipelines with various combinations of
1430 // output stages.
1431 TestOutputStages<MapOrder::RowMajor>(63, 10, 127, 5, 17, 14);
1432 TestOutputStages<MapOrder::ColMajor>(63, 10, 127, 5, 17, 14);
1433
1434 // Test per channel quantization.
1435 TestWithSmallDataPerChannelQuantization();
1436 TestWithLargeDataPerChannelQuantization();
1437 #ifdef GEMMLOWP_TEST_PROFILE
1438 FinishProfiling();
1439 #endif
1440
1441 std::cerr << "All tests passed." << std::endl;
1442
1443 // We have been testing the eight_bit_int_gemm, so we should free its
1444 // persistent
1445 // resources now to avoid having leak-checking tools report leaks.
1446 eight_bit_int_gemm::FreePersistentResources();
1447 }
1448
1449 } // end namespace gemmlowp
1450
1451 // For iOS, we need to define our own main(), so skip it here.
1452 #if !(defined(__APPLE__) && (TARGET_OS_IPHONE || TARGET_IPHONE_SIMULATOR))
main()1453 int main() { gemmlowp::test(); }
1454 #endif
1455