• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2015 Google Inc. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "test.h"
16 
17 #include <unistd.h>
18 #include <cstdint>
19 #include <cstdlib>
20 #include <ctime>
21 #include <iostream>
22 #include <memory>
23 #include <string>
24 #include <vector>
25 #ifdef __APPLE__
26 #include <TargetConditionals.h>
27 #endif
28 
29 #include "../eight_bit_int_gemm/eight_bit_int_gemm.h"
30 #include "../internal/kernel_reference.h"
31 #include "test_data.h"
32 
33 namespace gemmlowp {
34 
// Straightforward, unoptimized 8-bit GEMM used as the ground truth that the
// optimized implementations are compared against.
//
// For every i in [0, m) and j in [0, n) it accumulates, in 32-bit,
//   total = sum over l in [0, k) of
//           (a[i, l] + a_offset) * (b[l, j] + b_offset)
// and stores, clamped to [0, 255],
//   (((total + c_offset) * c_mult_int) + rounding) >> c_shift
// where rounding is 2^(c_shift - 1), or 0 when c_shift is 0.
//
// Element addressing:
//   a[i, l] is a[i * lda + l] if transpose_a, else a[i + l * lda]
//   b[l, j] is b[j + l * ldb] if transpose_b, else b[j * ldb + l]
//   c[i, j] is c[i * ldc + j] if transpose_c, else c[i + j * ldc]
//
// a, b and c must be non-null and c_shift must lie in [0, 32] (asserted).
void ReferenceEightBitIntGemm(bool transpose_a, bool transpose_b,
                              bool transpose_c, int m, int n, int k,
                              const uint8_t* a, int32_t a_offset, int lda,
                              const uint8_t* b, int32_t b_offset, int ldb,
                              uint8_t* c, int32_t c_offset, int32_t c_mult_int,
                              int32_t c_shift, int ldc) {
  assert((c_shift >= 0) && (c_shift <= 32));

  assert(a != nullptr);
  assert(b != nullptr);
  assert(c != nullptr);

  const int a_i_stride = transpose_a ? lda : 1;
  const int a_l_stride = transpose_a ? 1 : lda;
  const int b_j_stride = transpose_b ? 1 : ldb;
  const int b_l_stride = transpose_b ? ldb : 1;
  const int c_i_stride = transpose_c ? ldc : 1;
  const int c_j_stride = transpose_c ? 1 : ldc;

  // The rounding term and the final shift are evaluated in 64-bit: with
  // c_shift == 32 (allowed by the assert above) both `1 << 31` and a 32-bit
  // shift by 32 would be undefined, and adding the rounding term to a value
  // close to INT32_MAX would overflow.
  const std::int64_t kRoundingTerm =
      (c_shift < 1) ? 0 : (std::int64_t{1} << (c_shift - 1));

  for (int j = 0; j < n; j++) {
    for (int i = 0; i < m; i++) {
      int32_t total = 0;
      for (int l = 0; l < k; l++) {
        const int a_index = i * a_i_stride + l * a_l_stride;
        const int32_t a_as_int = static_cast<int32_t>(a[a_index]) + a_offset;
        const int b_index = j * b_j_stride + l * b_l_stride;
        const int32_t b_as_int = static_cast<int32_t>(b[b_index]) + b_offset;
        total += a_as_int * b_as_int;
      }
      // The offset/multiplier step deliberately stays in 32-bit, exactly as
      // before; only the rounding add and the shift are widened.
      const int32_t scaled = (total + c_offset) * c_mult_int;
      std::int64_t output =
          (static_cast<std::int64_t>(scaled) + kRoundingTerm) >> c_shift;
      if (output > 255) {
        output = 255;
      }
      if (output < 0) {
        output = 0;
      }
      const int c_index = i * c_i_stride + j * c_j_stride;
      c[c_index] = static_cast<uint8_t>(output);
    }
  }
}
104 
// The *GemmWrapper structs wrap the various Gemm functions in a uniform
// interface, so that the same testing code can be used to test all of them.
107 
108 template <typename Kernel, typename Scalar, typename tBitDepthParams>
109 struct SingleThreadGemmWrapper {
110   typedef tBitDepthParams BitDepthParams;
111 
Namegemmlowp::SingleThreadGemmWrapper112   static const char* Name() {
113     static char buf[256];
114     snprintf(buf, sizeof(buf), "SingleThreadGemm, Kernel: %s", Kernel().Name());
115     return buf;
116   }
117 
118   typedef SingleThreadGemmContext Context;
119 
120   template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
Gemmgemmlowp::SingleThreadGemmWrapper121   static void Gemm(Context* context,
122                    const MatrixMap<const Scalar, LhsOrder>& lhs,
123                    const MatrixMap<const Scalar, RhsOrder>& rhs,
124                    MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
125                    int rhs_offset, int result_offset, int result_mult_int,
126                    int result_shift) {
127     const OffsetColDup lhs_offset_vector(lhs_offset, lhs.rows());
128     const OffsetRowDup rhs_offset_vector(rhs_offset, rhs.cols());
129     SingleThreadGemm<typename Kernel::Format, Scalar, Scalar, BitDepthParams,
130                      LhsOrder, RhsOrder, ResultOrder,
131                      OffsetColDup, OffsetRowDup>(
132         context, Kernel(), lhs, rhs, result, lhs_offset_vector,
133         rhs_offset_vector,
134         MakeStandardOutputPipeline(result_offset, result_mult_int,
135                                    result_shift));
136   }
137 };
138 
139 template <typename Kernel, typename Scalar, typename tBitDepthParams>
140 struct MultiThreadGemmWrapper {
141   typedef tBitDepthParams BitDepthParams;
142 
Namegemmlowp::MultiThreadGemmWrapper143   static const char* Name() {
144     static char buf[256];
145     snprintf(buf, sizeof(buf), "MultiThreadGemm, Kernel: %s", Kernel().Name());
146     return buf;
147   }
148 
149   typedef MultiThreadGemmContext Context;
150 
151   template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
Gemmgemmlowp::MultiThreadGemmWrapper152   static void Gemm(Context* context,
153                    const MatrixMap<const Scalar, LhsOrder>& lhs,
154                    const MatrixMap<const Scalar, RhsOrder>& rhs,
155                    MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
156                    int rhs_offset, int result_offset, int result_mult_int,
157                    int result_shift) {
158     const OffsetColDup lhs_offset_vector(lhs_offset, lhs.rows());
159     const OffsetRowDup rhs_offset_vector(rhs_offset, rhs.cols());
160     MultiThreadGemm<typename Kernel::Format, Scalar, Scalar, BitDepthParams,
161                     LhsOrder, RhsOrder, ResultOrder,
162                     OffsetColDup, OffsetRowDup>(
163         context, Kernel(), lhs, rhs, result, lhs_offset_vector,
164         rhs_offset_vector,
165         MakeStandardOutputPipeline(result_offset, result_mult_int,
166                                    result_shift));
167   }
168 };
169 
170 template <typename Scalar, typename tBitDepthParams>
171 struct PublicGemmWrapper {
172   typedef tBitDepthParams BitDepthParams;
173 
Namegemmlowp::PublicGemmWrapper174   static const char* Name() { return "public Gemm"; }
175 
176   typedef GemmContext Context;
177 
178   template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
Gemmgemmlowp::PublicGemmWrapper179   static void Gemm(Context* context,
180                    const MatrixMap<const Scalar, LhsOrder>& lhs,
181                    const MatrixMap<const Scalar, RhsOrder>& rhs,
182                    MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
183                    int rhs_offset, int result_offset, int result_mult_int,
184                    int result_shift) {
185     gemmlowp::Gemm<uint8_t, BitDepthParams, LhsOrder, RhsOrder, ResultOrder>(
186         context, lhs, rhs, result, lhs_offset, rhs_offset, result_offset,
187         result_mult_int, result_shift);
188   }
189 };
190 
191 template <eight_bit_int_gemm::BitDepthSetting BitDepth>
192 struct BitDepthParamsForSettings {};
193 
194 template <>
195 struct BitDepthParamsForSettings<eight_bit_int_gemm::BitDepthSetting::A8B8>
196     : DefaultL8R8BitDepthParams {};
197 
198 template <>
199 struct BitDepthParamsForSettings<eight_bit_int_gemm::BitDepthSetting::A5B7>
200     : DefaultL7R5BitDepthParams {};
201 
202 template <typename Scalar, eight_bit_int_gemm::BitDepthSetting BitDepth>
203 struct EightBitIntGemmWrapper {
204   typedef BitDepthParamsForSettings<BitDepth> BitDepthParams;
205 
Namegemmlowp::EightBitIntGemmWrapper206   static const char* Name() { return "EightBitIntGemm"; }
207 
208   typedef void Context;
209 
210   template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
Gemmgemmlowp::EightBitIntGemmWrapper211   static void Gemm(Context*, const MatrixMap<const Scalar, LhsOrder>& lhs,
212                    const MatrixMap<const Scalar, RhsOrder>& rhs,
213                    MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
214                    int rhs_offset, int result_offset, int result_mult_int,
215                    int result_shift) {
216     const bool transpose_c = ResultOrder == MapOrder::RowMajor;
217     const bool transpose_a = LhsOrder == MapOrder::RowMajor;
218     const bool transpose_b = RhsOrder == MapOrder::RowMajor;
219     eight_bit_int_gemm::EightBitIntGemm(
220         transpose_a, transpose_b, transpose_c, lhs.rows(), rhs.cols(),
221         lhs.cols(), lhs.data(), lhs_offset, lhs.stride(), rhs.data(),
222         rhs_offset, rhs.stride(), result->data(), result_offset,
223         result_mult_int, result_shift, result->stride(), BitDepth);
224   }
225 };
226 
227 template <typename Scalar>
228 struct ReferenceEightBitIntGemmWrapper {
229   typedef DefaultL8R8BitDepthParams BitDepthParams;
230 
Namegemmlowp::ReferenceEightBitIntGemmWrapper231   static const char* Name() { return "ReferenceEightBitIntGemm"; }
232 
233   template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
Gemmgemmlowp::ReferenceEightBitIntGemmWrapper234   static void Gemm(bool transpose_a, bool transpose_b, bool transpose_c,
235                    const MatrixMap<const Scalar, LhsOrder>& lhs,
236                    const MatrixMap<const Scalar, RhsOrder>& rhs,
237                    MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
238                    int rhs_offset, int result_offset, int result_mult_int,
239                    int result_shift) {
240     ReferenceEightBitIntGemm(transpose_a, transpose_b, transpose_c, lhs.rows(),
241                              rhs.cols(), lhs.cols(), lhs.data(), lhs_offset,
242                              lhs.stride(), rhs.data(), rhs_offset, rhs.stride(),
243                              result->data(), result_offset, result_mult_int,
244                              result_shift, result->stride());
245   }
246 };
247 
OrderName(MapOrder order)248 const char* OrderName(MapOrder order) {
249   return order == MapOrder::ColMajor ? "ColMajor" : "RowMajor";
250 }
251 
// Summary statistics comparing an actual result buffer against an expected
// one; filled in by GetResultStats(). "Median" below always means the element
// at index count / 2 of the sorted values.
struct ResultStats {
  // Number of entries compared.
  int count = 0;
  // Median of the actual values.
  int med_val = 0;
  // Mean of (actual - expected).
  float mean_signed_diff = 0;
  // Median of (actual - expected).
  int med_signed_diff = 0;
  // Median of |actual - expected|.
  int med_unsigned_diff = 0;
  // Largest |actual - expected|.
  int max_unsigned_diff = 0;

  // Histogram of |actual - expected| by power-of-two slice: entry 0 counts
  // exact matches, entry e >= 1 counts diffs in [2^(e-1), 2^e - 1].
  std::vector<int> count_diff_by_pot_slice;
};

// Computes the statistics of `actual` versus `expected` into *stats.
//
// Both buffers must hold at least `count` entries. When `count` is 0 the
// buffers are not read, every scalar statistic is reset to 0 and the histogram
// holds 9 zero entries (the medians and the mean are undefined for an empty
// input, and computing them would index empty vectors and divide by zero).
void GetResultStats(const uint8_t* actual, const uint8_t* expected,
                    size_t count, ResultStats* stats) {
  // Size 9 for 9 different POT values: 2^0, ..., 2^8
  const int kNumPotSlices = 9;

  if (count == 0) {
    *stats = ResultStats();
    stats->count_diff_by_pot_slice.assign(kNumPotSlices, 0);
    return;
  }

  std::vector<uint8_t> results(actual, actual + count);
  std::vector<int16_t> signed_diffs;
  std::vector<uint8_t> unsigned_diffs;
  signed_diffs.reserve(count);
  unsigned_diffs.reserve(count);
  int64_t signed_diffs_sum = 0;
  for (size_t i = 0; i < count; i++) {
    // The difference of two uint8 values lies in [-255, 255].
    const int16_t signed_diff = static_cast<int16_t>(actual[i] - expected[i]);
    signed_diffs.push_back(signed_diff);
    unsigned_diffs.push_back(static_cast<uint8_t>(std::abs(signed_diff)));
    signed_diffs_sum += signed_diff;
  }

  std::sort(results.begin(), results.end());
  std::sort(signed_diffs.begin(), signed_diffs.end());
  std::sort(unsigned_diffs.begin(), unsigned_diffs.end());

  const size_t middle = count / 2;

  stats->count = static_cast<int>(count);
  stats->med_val = results[middle];
  stats->mean_signed_diff = static_cast<float>(signed_diffs_sum) / count;
  stats->med_signed_diff = signed_diffs[middle];
  stats->med_unsigned_diff = unsigned_diffs[middle];
  stats->max_unsigned_diff = unsigned_diffs.back();

  // unsigned_diffs is sorted, so each slice is the run of entries between two
  // consecutive lower bounds.
  stats->count_diff_by_pot_slice.resize(kNumPotSlices);
  auto cur = unsigned_diffs.begin();
  size_t checksum = 0;
  for (int exponent = 0; exponent < kNumPotSlices; exponent++) {
    const int pot = 1 << exponent;
    const auto next = std::lower_bound(cur, unsigned_diffs.end(), pot);
    const auto slice_count = next - cur;
    stats->count_diff_by_pot_slice[exponent] = static_cast<int>(slice_count);
    checksum += static_cast<size_t>(slice_count);
    cur = next;
  }
  // Every entry must have landed in exactly one slice.
  assert(checksum == count);
  (void)checksum;  // Only read by the assert above.
}
310 
// Tolerances that a ResultStats is checked against. Every bound defaults to 0.
struct ResultStatsBounds {
  // Bound on the absolute value of the mean signed diff.
  float mean_signed_diff = 0;
  // Bound on the absolute value of the median signed diff.
  int med_signed_diff = 0;
  // Bound on the median unsigned diff.
  int med_unsigned_diff = 0;
  // Bound on the maximum unsigned diff.
  int max_unsigned_diff = 0;
};
323 
CheckResultStatsBounds(const ResultStats & stats,const ResultStatsBounds & bounds)324 bool CheckResultStatsBounds(const ResultStats& stats,
325                             const ResultStatsBounds& bounds) {
326   return stats.max_unsigned_diff <= bounds.max_unsigned_diff &&
327          stats.med_unsigned_diff <= bounds.med_unsigned_diff &&
328          std::abs(stats.med_signed_diff) <= bounds.med_signed_diff &&
329          std::abs(stats.mean_signed_diff) <= bounds.mean_signed_diff;
330 }
331 
ReportResultStats(const ResultStats & stats,const ResultStatsBounds & bounds)332 void ReportResultStats(const ResultStats& stats,
333                        const ResultStatsBounds& bounds) {
334   printf("    number of matrix entries: %d\n", stats.count);
335   printf("    median value: %d\n", stats.med_val);
336   printf("    median unsigned diff: %d (tolerating %d)\n",
337          stats.med_unsigned_diff, bounds.med_unsigned_diff);
338   printf("    max unsigned diff: %d (tolerating %d)\n", stats.max_unsigned_diff,
339          bounds.max_unsigned_diff);
340   printf("    median signed diff: %d (tolerating %d)\n", stats.med_signed_diff,
341          bounds.med_signed_diff);
342   printf("    mean signed diff: %.3g (tolerating %.3g)\n",
343          stats.mean_signed_diff, bounds.mean_signed_diff);
344 
345   printf("No error: %.2f %% of entries\n",
346          100.f * stats.count_diff_by_pot_slice[0] / stats.count);
347   for (int exponent = 1; exponent < 9; exponent++) {
348     printf("Error in %d..%d range: %.2f %% of entries\n", 1 << (exponent - 1),
349            (1 << exponent) - 1,
350            100.f * stats.count_diff_by_pot_slice[exponent] / stats.count);
351   }
352 }
353 
354 // Our approach to choosing result_shift values for testing, is bisection.
355 // This function takes an interval, [result_shift_min .. result_shift_max].
356 // If too much saturation occurred in either direction, it bisects accordingly,
357 // recursing until the interval contains only one value.
358 // The primary reason why we prefer this over computing optimal shift values,
359 // is that we actually want to exercise some saturation, as there is nontrivial
360 // code handling that in gemmlowp.
361 // Secondarily, this is faster than computing optimal shifts, since in 90% of
362 // cases the first-tried shift value 16 turns out to be good enough.
363 template <typename GemmWrapper, typename LhsType, typename RhsType,
364           typename ResultType>
test_gemm_impl(typename GemmWrapper::Context * context,const LhsType & lhs,const RhsType & rhs,ResultType * result,int lhs_offset,int rhs_offset,int result_offset,int result_mult_int,int result_shift_min,int result_shift_max)365 void test_gemm_impl(typename GemmWrapper::Context* context, const LhsType& lhs,
366                     const RhsType& rhs, ResultType* result, int lhs_offset,
367                     int rhs_offset, int result_offset, int result_mult_int,
368                     int result_shift_min, int result_shift_max) {
369   const int rows = lhs.rows();
370   const int cols = rhs.cols();
371   Check(lhs.cols() == rhs.rows());
372   const int depth = lhs.cols();
373 
374   const int result_shift = (result_shift_min + result_shift_max) / 2;
375 
376   GemmWrapper::Gemm(context, lhs.const_map(), rhs.const_map(), &result->map(),
377                     lhs_offset, rhs_offset, result_offset, result_mult_int,
378                     result_shift);
379 
380   typedef typename ResultType::Scalar Scalar;
381   static const MapOrder kLhsOrder = LhsType::kOrder;
382   static const MapOrder kRhsOrder = RhsType::kOrder;
383   static const MapOrder kResultOrder = ResultType::kOrder;
384   ResultType ref_result(rows, cols);
385   const bool transpose_c = kResultOrder == MapOrder::RowMajor;
386   const bool transpose_a = kLhsOrder == MapOrder::RowMajor;
387   const bool transpose_b = kRhsOrder == MapOrder::RowMajor;
388   ReferenceEightBitIntGemmWrapper<Scalar>::Gemm(
389       transpose_a, transpose_b, transpose_c, lhs.const_map(), rhs.const_map(),
390       &ref_result.map(), lhs_offset, rhs_offset, result_offset, result_mult_int,
391       result_shift);
392 
393   typedef typename GemmWrapper::BitDepthParams BitDepthParams;
394 
395   ResultStats stats;
396   GetResultStats(result->data(), ref_result.data(), rows * cols, &stats);
397 
398   // Adjust shifts until we get meaningful results
399   int new_result_shift_min = result_shift_min;
400   int new_result_shift_max = result_shift_max;
401   bool retry = false;
402 
403   if (stats.med_val < 32) {
404     new_result_shift_max = (result_shift_min + result_shift_max) / 2;
405     retry = true;
406   }
407 
408   if (stats.med_val > 224) {
409     new_result_shift_min = (result_shift_min + result_shift_max) / 2;
410     retry = true;
411   }
412 
413   if (retry) {
414     if (result_shift_min != result_shift_max) {
415       test_gemm_impl<GemmWrapper>(context, lhs, rhs, result, lhs_offset,
416                                   rhs_offset, result_offset, result_mult_int,
417                                   new_result_shift_min, new_result_shift_max);
418     }
419     return;
420   }
421 
422   ResultStatsBounds bounds;
423 
424   if (BitDepthParams::LhsBitDepth::kBits < 8 ||
425       BitDepthParams::RhsBitDepth::kBits < 8) {
426     // We have very lax requirements on unsigned diff.
427     // We have tighter requirements on signed diff (bias), but only
428     // if the matrix is large enough for things to average out.
429     // For very small sizes, we... basically don't test anything.
430     // The problem is that this test uses unrealistic combinations of
431     // result_mult_int
432     // and result_shift, resulting in potentially wild requantization artifacts
433     // on small GEMMs.
434     int adjust_for_small_sizes = 1000 / (rows * cols);
435     bounds.max_unsigned_diff =
436         std::max(stats.med_val / 2, adjust_for_small_sizes);
437     bounds.med_unsigned_diff =
438         std::max(stats.med_val / 8, adjust_for_small_sizes);
439     bounds.med_signed_diff = std::max(2, adjust_for_small_sizes);
440     bounds.mean_signed_diff = std::max(2, adjust_for_small_sizes);
441   }
442 
443   // Check results
444   const bool good = CheckResultStatsBounds(stats, bounds);
445 
446   printf(
447       "%s: %dx%dx%d %s x %s -> %s, %s, offsets %d/%d/%d, mult %d, shift %d\n",
448       good ? "PASS" : "FAIL", rows, depth, cols, OrderName(kLhsOrder),
449       OrderName(kRhsOrder), OrderName(kResultOrder), GemmWrapper::Name(),
450       lhs_offset, rhs_offset, result_offset, result_mult_int, result_shift);
451 
452   if (!good) {
453     ReportResultStats(stats, bounds);
454 
455     int bad_coeffs_printed = 0;
456     for (int c = 0; c < result->cols() && bad_coeffs_printed < 20; c++) {
457       for (int r = 0; r < result->rows() && bad_coeffs_printed < 20; r++) {
458         if (ref_result(r, c) != (*result)(r, c)) {
459           printf("bad coeff: at (%d, %d), expected %d, got %d\n", r, c,
460                  ref_result(r, c), (*result)(r, c));
461           bad_coeffs_printed++;
462         }
463       }
464     }
465   }
466 
467   Check(good);
468 }
469 
470 template <typename GemmWrapper, typename LhsType, typename RhsType,
471           typename ResultType>
test_gemm(typename GemmWrapper::Context * context,const LhsType & lhs,const RhsType & rhs,ResultType * result,int lhs_offset,int rhs_offset,int result_offset,int result_mult_int)472 void test_gemm(typename GemmWrapper::Context* context, const LhsType& lhs,
473                const RhsType& rhs, ResultType* result, int lhs_offset,
474                int rhs_offset, int result_offset, int result_mult_int) {
475   test_gemm_impl<GemmWrapper>(context, lhs, rhs, result, lhs_offset, rhs_offset,
476                               result_offset, result_mult_int, 0, 32);
477 }
478 
// Selects which offset/multiplier combinations test_gemm exercises.
enum class WhatParamsToTest {
  All,              // The special-case combinations plus the generic one.
  OnlyGenericCase,  // Only the generic combination.
};
483 
484 template <typename GemmWrapper, MapOrder LhsOrder, MapOrder RhsOrder,
485           MapOrder ResultOrder>
test_gemm(typename GemmWrapper::Context * context,int rows,int depth,int cols,WhatParamsToTest params_to_test)486 void test_gemm(typename GemmWrapper::Context* context, int rows, int depth,
487                int cols, WhatParamsToTest params_to_test) {
488   typedef std::uint8_t Scalar;
489   typedef Matrix<Scalar, LhsOrder> LhsType;
490   LhsType lhs(rows, depth);
491   MakeRandom(&lhs, 8);
492   typedef Matrix<Scalar, RhsOrder> RhsType;
493   RhsType rhs(depth, cols);
494   MakeRandom(&rhs, 8);
495   typedef Matrix<Scalar, ResultOrder> ResultType;
496   ResultType result(rows, cols);
497   MakeZero(&result);
498 
499   if (params_to_test == WhatParamsToTest::All) {
500     test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 0, 0, 1);
501     test_gemm<GemmWrapper>(context, lhs, rhs, &result, 10, 0, 0, 1);
502     test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 10, 0, 1);
503     test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 0, 10, 1);
504     test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 0, 0, 10);
505     test_gemm<GemmWrapper>(context, lhs, rhs, &result, 10, 10, 10, 10);
506     test_gemm<GemmWrapper>(context, lhs, rhs, &result, 256, 1, 17, 4);
507   }
508   test_gemm<GemmWrapper>(context, lhs, rhs, &result, -75, -91, 74980, 123);
509 }
510 
// Selects which storage-order combinations test_gemm exercises.
enum class WhatOrdersToTest {
  All,      // All eight Lhs/Rhs/Result storage-order combinations.
  OnlyRCC,  // Only RowMajor Lhs, ColMajor Rhs, ColMajor Result.
};
512 
513 template <typename GemmWrapper>
test_gemm(typename GemmWrapper::Context * context,int rows,int depth,int cols,WhatParamsToTest params_to_test,WhatOrdersToTest orders_to_test)514 void test_gemm(typename GemmWrapper::Context* context, int rows, int depth,
515                int cols, WhatParamsToTest params_to_test,
516                WhatOrdersToTest orders_to_test) {
517 #define GEMMLOWP_ONE_TEST(LhsOrder, RhsOrder, ResultOrder)         \
518   do {                                                             \
519     test_gemm<GemmWrapper, MapOrder::LhsOrder, MapOrder::RhsOrder, \
520               MapOrder::ResultOrder>(context, rows, depth, cols,   \
521                                      params_to_test);              \
522   } while (false)
523 
524   if (orders_to_test == WhatOrdersToTest::All) {
525     GEMMLOWP_ONE_TEST(ColMajor, ColMajor, ColMajor);
526     GEMMLOWP_ONE_TEST(RowMajor, ColMajor, ColMajor);
527     GEMMLOWP_ONE_TEST(ColMajor, RowMajor, ColMajor);
528     GEMMLOWP_ONE_TEST(RowMajor, RowMajor, ColMajor);
529 
530     GEMMLOWP_ONE_TEST(ColMajor, ColMajor, RowMajor);
531     GEMMLOWP_ONE_TEST(RowMajor, ColMajor, RowMajor);
532     GEMMLOWP_ONE_TEST(ColMajor, RowMajor, RowMajor);
533     GEMMLOWP_ONE_TEST(RowMajor, RowMajor, RowMajor);
534   } else {
535     GEMMLOWP_ONE_TEST(RowMajor, ColMajor, ColMajor);
536   }
537 
538 #undef GEMMLOWP_ONE_TEST
539 }
540 
541 template <typename Kernel>
test_gemm_kernel(MultiThreadGemmContext * context)542 void test_gemm_kernel(MultiThreadGemmContext* context) {
543   typedef MultiThreadGemmWrapper<Kernel, std::uint8_t,
544                                  DefaultL8R8BitDepthParams>
545       GemmWrapper;
546   test_gemm<GemmWrapper>(context, 1, 1, 1, WhatParamsToTest::OnlyGenericCase,
547                          WhatOrdersToTest::OnlyRCC);
548   test_gemm<GemmWrapper>(context, 2, 2, 2, WhatParamsToTest::OnlyGenericCase,
549                          WhatOrdersToTest::OnlyRCC);
550   test_gemm<GemmWrapper>(context, 3, 3, 3, WhatParamsToTest::OnlyGenericCase,
551                          WhatOrdersToTest::OnlyRCC);
552   test_gemm<GemmWrapper>(context, 4, 4, 4, WhatParamsToTest::OnlyGenericCase,
553                          WhatOrdersToTest::OnlyRCC);
554   test_gemm<GemmWrapper>(context, 5, 5, 5, WhatParamsToTest::OnlyGenericCase,
555                          WhatOrdersToTest::OnlyRCC);
556   test_gemm<GemmWrapper>(context, 9, 11, 13, WhatParamsToTest::OnlyGenericCase,
557                          WhatOrdersToTest::OnlyRCC);
558   test_gemm<GemmWrapper>(context, 50, 50, 50, WhatParamsToTest::All,
559                          WhatOrdersToTest::OnlyRCC);
560   test_gemm<GemmWrapper>(context, 200, 200, 200,
561                          WhatParamsToTest::OnlyGenericCase,
562                          WhatOrdersToTest::All);
563   test_gemm<GemmWrapper>(context, 50, 5000, 50,
564                          WhatParamsToTest::OnlyGenericCase,
565                          WhatOrdersToTest::OnlyRCC);
566 }
567 
568 template <typename GemmWrapper>
test_gemm(typename GemmWrapper::Context * context)569 void test_gemm(typename GemmWrapper::Context* context) {
570   test_gemm<GemmWrapper>(context, 1, 1, 1, WhatParamsToTest::All,
571                          WhatOrdersToTest::OnlyRCC);
572   test_gemm<GemmWrapper>(context, 2, 1, 1, WhatParamsToTest::All,
573                          WhatOrdersToTest::OnlyRCC);
574   test_gemm<GemmWrapper>(context, 1, 2, 1, WhatParamsToTest::All,
575                          WhatOrdersToTest::OnlyRCC);
576   test_gemm<GemmWrapper>(context, 1, 1, 2, WhatParamsToTest::All,
577                          WhatOrdersToTest::OnlyRCC);
578   test_gemm<GemmWrapper>(context, 2, 2, 2, WhatParamsToTest::All,
579                          WhatOrdersToTest::OnlyRCC);
580   test_gemm<GemmWrapper>(context, 3, 3, 3, WhatParamsToTest::All,
581                          WhatOrdersToTest::OnlyRCC);
582   test_gemm<GemmWrapper>(context, 4, 4, 4, WhatParamsToTest::All,
583                          WhatOrdersToTest::OnlyRCC);
584   test_gemm<GemmWrapper>(context, 5, 5, 5, WhatParamsToTest::All,
585                          WhatOrdersToTest::OnlyRCC);
586   test_gemm<GemmWrapper>(context, 6, 6, 6, WhatParamsToTest::All,
587                          WhatOrdersToTest::OnlyRCC);
588   test_gemm<GemmWrapper>(context, 3, 5, 7, WhatParamsToTest::All,
589                          WhatOrdersToTest::OnlyRCC);
590   test_gemm<GemmWrapper>(context, 7, 3, 5, WhatParamsToTest::All,
591                          WhatOrdersToTest::OnlyRCC);
592   test_gemm<GemmWrapper>(context, 5, 7, 3, WhatParamsToTest::All,
593                          WhatOrdersToTest::OnlyRCC);
594   test_gemm<GemmWrapper>(context, 8, 8, 8, WhatParamsToTest::All,
595                          WhatOrdersToTest::OnlyRCC);
596   test_gemm<GemmWrapper>(context, 16, 16, 16, WhatParamsToTest::All,
597                          WhatOrdersToTest::OnlyRCC);
598   test_gemm<GemmWrapper>(context, 32, 32, 32, WhatParamsToTest::All,
599                          WhatOrdersToTest::OnlyRCC);
600   test_gemm<GemmWrapper>(context, 64, 64, 64, WhatParamsToTest::All,
601                          WhatOrdersToTest::OnlyRCC);
602   test_gemm<GemmWrapper>(context, 128, 128, 128, WhatParamsToTest::All,
603                          WhatOrdersToTest::OnlyRCC);
604 
605   test_gemm<GemmWrapper>(context, 16, 17, 16, WhatParamsToTest::All,
606                          WhatOrdersToTest::OnlyRCC);
607   test_gemm<GemmWrapper>(context, 37, 55, 73, WhatParamsToTest::All,
608                          WhatOrdersToTest::OnlyRCC);
609   test_gemm<GemmWrapper>(context, 57, 87, 117, WhatParamsToTest::All,
610                          WhatOrdersToTest::OnlyRCC);
611   test_gemm<GemmWrapper>(context, 93, 83, 73, WhatParamsToTest::All,
612                          WhatOrdersToTest::OnlyRCC);
613   test_gemm<GemmWrapper>(context, 109, 89, 99, WhatParamsToTest::All,
614                          WhatOrdersToTest::OnlyRCC);
615   test_gemm<GemmWrapper>(context, 78, 101, 82, WhatParamsToTest::All,
616                          WhatOrdersToTest::OnlyRCC);
617 
618   test_gemm<GemmWrapper>(context, 512, 512, 512,
619                          WhatParamsToTest::OnlyGenericCase,
620                          WhatOrdersToTest::OnlyRCC);
621   test_gemm<GemmWrapper>(context, 1024, 1024, 1024,
622                          WhatParamsToTest::OnlyGenericCase,
623                          WhatOrdersToTest::OnlyRCC);
624   test_gemm<GemmWrapper>(context, 567, 2345, 123,
625                          WhatParamsToTest::OnlyGenericCase,
626                          WhatOrdersToTest::OnlyRCC);
627   test_gemm<GemmWrapper>(context, 100, 5000, 100,
628                          WhatParamsToTest::OnlyGenericCase,
629                          WhatOrdersToTest::OnlyRCC);
630   test_gemm<GemmWrapper>(context, 1, 1, 1000, WhatParamsToTest::OnlyGenericCase,
631                          WhatOrdersToTest::OnlyRCC);
632   test_gemm<GemmWrapper>(context, 1000, 1, 1, WhatParamsToTest::OnlyGenericCase,
633                          WhatOrdersToTest::OnlyRCC);
634   test_gemm<GemmWrapper>(context, 1, 1000, 1, WhatParamsToTest::OnlyGenericCase,
635                          WhatOrdersToTest::OnlyRCC);
636   test_gemm<GemmWrapper>(context, 1, 1000, 1000,
637                          WhatParamsToTest::OnlyGenericCase,
638                          WhatOrdersToTest::OnlyRCC);
639   test_gemm<GemmWrapper>(context, 1000, 1, 1000,
640                          WhatParamsToTest::OnlyGenericCase,
641                          WhatOrdersToTest::OnlyRCC);
642   test_gemm<GemmWrapper>(context, 1000, 1000, 1,
643                          WhatParamsToTest::OnlyGenericCase,
644                          WhatOrdersToTest::OnlyRCC);
645   test_gemm<GemmWrapper>(context, 777, 3456, 1,
646                          WhatParamsToTest::OnlyGenericCase,
647                          WhatOrdersToTest::OnlyRCC);
648   test_gemm<GemmWrapper>(context, 4567, 555, 1,
649                          WhatParamsToTest::OnlyGenericCase,
650                          WhatOrdersToTest::OnlyRCC);
651 
652   // Test all storage orders
653   test_gemm<GemmWrapper>(context, 70, 90, 110, WhatParamsToTest::All,
654                          WhatOrdersToTest::All);
655   test_gemm<GemmWrapper>(context, 300, 400, 500,
656                          WhatParamsToTest::OnlyGenericCase,
657                          WhatOrdersToTest::All);
658 }
659 
660 template <typename GemmWrapper>
test_gemv(typename GemmWrapper::Context * context)661 void test_gemv(typename GemmWrapper::Context* context) {
662   test_gemm<GemmWrapper>(context, 2, 2, 1, WhatParamsToTest::All,
663                          WhatOrdersToTest::OnlyRCC);
664   test_gemm<GemmWrapper>(context, 3, 3, 1, WhatParamsToTest::All,
665                          WhatOrdersToTest::OnlyRCC);
666   test_gemm<GemmWrapper>(context, 4, 4, 1, WhatParamsToTest::All,
667                          WhatOrdersToTest::OnlyRCC);
668   test_gemm<GemmWrapper>(context, 5, 5, 1, WhatParamsToTest::All,
669                          WhatOrdersToTest::OnlyRCC);
670   test_gemm<GemmWrapper>(context, 6, 6, 1, WhatParamsToTest::All,
671                          WhatOrdersToTest::OnlyRCC);
672   test_gemm<GemmWrapper>(context, 3, 5, 1, WhatParamsToTest::All,
673                          WhatOrdersToTest::OnlyRCC);
674   test_gemm<GemmWrapper>(context, 7, 3, 1, WhatParamsToTest::All,
675                          WhatOrdersToTest::OnlyRCC);
676   test_gemm<GemmWrapper>(context, 5, 7, 1, WhatParamsToTest::All,
677                          WhatOrdersToTest::OnlyRCC);
678   test_gemm<GemmWrapper>(context, 8, 8, 1, WhatParamsToTest::All,
679                          WhatOrdersToTest::OnlyRCC);
680   test_gemm<GemmWrapper>(context, 32, 32, 1, WhatParamsToTest::All,
681                          WhatOrdersToTest::OnlyRCC);
682   test_gemm<GemmWrapper>(context, 128, 128, 1, WhatParamsToTest::All,
683                          WhatOrdersToTest::OnlyRCC);
684   test_gemm<GemmWrapper>(context, 321, 123, 1, WhatParamsToTest::All,
685                          WhatOrdersToTest::OnlyRCC);
686 
687   // Test all storage orders
688   test_gemm<GemmWrapper>(context, 70, 90, 1, WhatParamsToTest::All,
689                          WhatOrdersToTest::All);
690   test_gemm<GemmWrapper>(context, 300, 400, 1,
691                          WhatParamsToTest::OnlyGenericCase,
692                          WhatOrdersToTest::All);
693 }
694 
GetBitDepthName(eight_bit_int_gemm::BitDepthSetting b)695 const char* GetBitDepthName(eight_bit_int_gemm::BitDepthSetting b) {
696   switch (b) {
697     case eight_bit_int_gemm::BitDepthSetting::A8B8:
698       return "Lhs: 8 bit, Rhs: 8 bit";
699     case eight_bit_int_gemm::BitDepthSetting::A5B7:
700       return "Lhs: 7 bit, Rhs: 5 bit";
701     default:
702       abort();
703       return nullptr;
704   }
705 }
706 
707 // Runs a small set of hand-picked data for per-channel quantized data.
708 // This test case comes from a set of 2 2x2 convolution filters run over a 3x3
709 // image.
TestWithSmallDataPerChannelQuantization()710 void TestWithSmallDataPerChannelQuantization() {
711   const int m = 2;
712   const int n = 9;
713   const int k = 12;
714 
715   // 12 x 2, columnwise.
716   const uint8_t a_data[] = {
717      0,  0,  0,  0,  0,  0, 0, 0, 0, 255, 255, 255,
718     64, 64, 64, 64, 64, 64, 0, 0, 0, 255, 255, 255
719   };
720   const int lda = k;
721   int a_offset[] = {0, -64};
722   MatrixMap<const std::uint8_t, MapOrder::RowMajor> lhs(a_data, m, k, lda);
723   const OffsetColMap lhs_offset(a_offset, m);
724 
725   // 12 x 9, columnwise.
726   const uint8_t b_data[] = {
727       0,   0,   0,   0,   0,   0,   0,   0,   0, 255, 255, 255,
728       0,   0,   0,   0,   0,   0, 255, 255, 255,   0,   0,   0,
729       0,   0,   0, 127, 127, 127,   0,   0,   0, 127, 127, 127,
730       0,   0,   0, 255, 255, 255,   0,   0,   0,   0,   0,   0,
731     255, 255, 255,   0,   0,   0,   0,   0,   0,   0,   0,   0,
732       0,   0,   0, 127, 127, 127,   0,   0,   0, 127, 127, 127,
733       0,   0,   0,   0,   0,   0, 127, 127, 127, 127, 127, 127,
734       0,   0,   0,   0,   0,   0, 127, 127, 127, 127, 127, 127,
735       0,   0,   0, 127, 127, 127, 127, 127, 127, 127, 127, 127
736   };
737   const int ldb = k;
738   int b_offset = -127;
739   MatrixMap<const std::uint8_t, MapOrder::ColMajor> rhs(b_data, k, n, ldb);
740   const OffsetRowDup rhs_offset(b_offset, rhs.cols());
741 
742   // 2 x 9, columnwise.
743   const uint8_t expected_c_data[] = {
744     255, 255,
745       0,   0,
746     127, 159,
747       0,  64,
748       0,  64,
749     127, 159,
750     127, 127,
751     127, 127,
752     127, 127
753   };
754   const int ldc = m;
755   int c_offset[] = {97155, 97346};
756   int c_mult_int[] = {2741, 2741};
757   const int c_shift = 21;
758 
759   const int c_count = m * n;
760   std::unique_ptr<uint8_t[]> output_data(new uint8_t[c_count]);
761   MatrixMap<std::uint8_t, MapOrder::ColMajor> result(output_data.get(), m, n,
762                                                      ldc);
763   const OffsetColMap result_offset(c_offset, m);
764   const OffsetColMap result_mult_int(c_mult_int, m);
765   const int result_shift = c_shift;
766 
767   GemmContext gemm_context;
768   auto output_pipeline = MakeStandardOutputPipeline<VectorShape::Col>(
769       result_offset, result_mult_int, result_shift);
770   GemmWithOutputPipelinePC<uint8_t, uint8_t, DefaultL8R8BitDepthParams>(
771       &gemm_context, lhs, rhs, &result, lhs_offset, rhs_offset,
772       output_pipeline);
773 
774   ResultStats stats;
775   GetResultStats(output_data.get(), expected_c_data, c_count, &stats);
776 
777   ResultStatsBounds bounds;
778   const bool good = CheckResultStatsBounds(stats, bounds);
779   printf("TestWithSmallDataPerChannelQuantization: %s\n",
780          good ? "PASS" : "FAIL");
781   ReportResultStats(stats, bounds);
782   Check(good);
783 }
784 
785 // Runs a larger set of hand-picked data for per-channel quantized data.
786 // This test case comes from a set of 22 3x3 convolution filters run over a 5x5
787 // image.  Right now, I have 7 different filters and 15 copies of the first
788 // filter to make sure NEON code path that processes 16 rows at a time is
789 // covered.
TestWithLargeDataPerChannelQuantization()790 void TestWithLargeDataPerChannelQuantization() {
791   const int m = 22;
792   const int n = 25;
793   const int k = 27;
794 
795   // 27 x 22, column-wise.
796   const uint8_t a_data[] = {
797      0,  0,  0,   0,   0,   0,  0,  0,  0,   0,   0,   0, 255, 255, 255,
798          0,   0,   0,  0,  0,  0,   0,   0,   0,  0,  0,  0,
799      0,  0,  0,   0,   0,   0,  0,  0,  0, 127, 127, 127, 255, 255, 255,
800        127, 127, 127,  0,  0,  0,   0,   0,   0,  0,  0,  0,
801      0,  0,  0, 127, 127, 127,  0,  0,  0,   0,   0,   0, 255, 255, 255,
802          0,   0,   0,  0,  0,  0, 127, 127, 127,  0,  0,  0,
803     51, 51, 51,  51,  51,  51, 51, 51, 51,   0,   0,   0, 255, 255, 255,
804          0,   0,   0, 51, 51, 51,  51,  51,  51, 51, 51, 51,
805     51, 51, 51,   0,   0,   0, 51, 51, 51,  51,  51,  51, 255, 255, 255,
806         51,  51,  51, 51, 51, 51,   0,   0,   0, 51, 51, 51,
807      0,  0,  0,  64,  64,  64,  0,  0,  0,  64,  64,  64, 255, 255, 255,
808         64,  64,  64,  0,  0,  0,  64,  64,  64,  0,  0,  0,
809     36, 36, 36,   0,   0,   0, 36, 36, 36,   0,   0,   0, 255, 255, 255,
810          0,   0,   0, 36, 36, 36,   0,   0,   0, 36, 36, 36,
811      0,  0,  0,   0,   0,   0,  0,  0,  0,   0,   0,   0, 255, 255, 255,
812          0,   0,   0,  0,  0,  0,   0,   0,   0,  0,  0,  0,
813      0,  0,  0,   0,   0,   0,  0,  0,  0,   0,   0,   0, 255, 255, 255,
814          0,   0,   0,  0,  0,  0,   0,   0,   0,  0,  0,  0,
815      0,  0,  0,   0,   0,   0,  0,  0,  0,   0,   0,   0, 255, 255, 255,
816          0,   0,   0,  0,  0,  0,   0,   0,   0,  0,  0,  0,
817      0,  0,  0,   0,   0,   0,  0,  0,  0,   0,   0,   0, 255, 255, 255,
818          0,   0,   0,  0,  0,  0,   0,   0,   0,  0,  0,  0,
819      0,  0,  0,   0,   0,   0,  0,  0,  0,   0,   0,   0, 255, 255, 255,
820          0,   0,   0,  0,  0,  0,   0,   0,   0,  0,  0,  0,
821      0,  0,  0,   0,   0,   0,  0,  0,  0,   0,   0,   0, 255, 255, 255,
822          0,   0,   0,  0,  0,  0,   0,   0,   0,  0,  0,  0,
823      0,  0,  0,   0,   0,   0,  0,  0,  0,   0,   0,   0, 255, 255, 255,
824          0,   0,   0,  0,  0,  0,   0,   0,   0,  0,  0,  0,
825      0,  0,  0,   0,   0,   0,  0,  0,  0,   0,   0,   0, 255, 255, 255,
826          0,   0,   0,  0,  0,  0,   0,   0,   0,  0,  0,  0,
827      0,  0,  0,   0,   0,   0,  0,  0,  0,   0,   0,   0, 255, 255, 255,
828          0,   0,   0,  0,  0,  0,   0,   0,   0,  0,  0,  0,
829      0,  0,  0,   0,   0,   0,  0,  0,  0,   0,   0,   0, 255, 255, 255,
830          0,   0,   0,  0,  0,  0,   0,   0,   0,  0,  0,  0,
831      0,  0,  0,   0,   0,   0,  0,  0,  0,   0,   0,   0, 255, 255, 255,
832          0,   0,   0,  0,  0,  0,   0,   0,   0,  0,  0,  0,
833      0,  0,  0,   0,   0,   0,  0,  0,  0,   0,   0,   0, 255, 255, 255,
834          0,   0,   0,  0,  0,  0,   0,   0,   0,  0,  0,  0,
835      0,  0,  0,   0,   0,   0,  0,  0,  0,   0,   0,   0, 255, 255, 255,
836          0,   0,   0,  0,  0,  0,   0,   0,   0,  0,  0,  0,
837      0,  0,  0,   0,   0,   0,  0,  0,  0,   0,   0,   0, 255, 255, 255,
838          0,   0,   0,  0,  0,  0,   0,   0,   0,  0,  0,  0,
839      0,  0,  0,   0,   0,   0,  0,  0,  0,   0,   0,   0, 255, 255, 255,
840          0,   0,   0,  0,  0,  0,   0,   0,   0,  0,  0,  0,
841   };
842   const int lda = k;
843   int a_offset[] = {
844       0, 0, 0, -51, -51, 0, -36, 0, 0, 0,
845       0, 0, 0,   0,   0, 0,   0, 0, 0, 0,
846       0, 0
847   };
848   MatrixMap<const std::uint8_t, MapOrder::RowMajor> lhs(a_data, m, k, lda);
849   const OffsetColMap lhs_offset(a_offset, m);
850 
851   // 27 x 25, column-wise.
852   const uint8_t b_data[] = {
853     127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 119, 119, 119,
854          119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119,
855     127, 127, 127, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119,
856          119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
857     127, 127, 127, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119,
858          119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
859     127, 127, 127, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119,
860          119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
861     127, 127, 127, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119,
862          127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127,
863     127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119,
864          119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119,
865     119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
866          119, 119, 119, 119, 119, 119, 119, 119, 119, 136, 136, 136,
867     119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
868          119, 119, 119, 119, 119, 119, 136, 136, 136, 119, 119, 119,
869     119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
870          119, 119, 119, 136, 136, 136, 119, 119, 119, 119, 119, 119,
871     119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119,
872          127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127,
873     127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119,
874          119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119,
875     119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
876          136, 136, 136, 119, 119, 119, 119, 119, 119, 119, 119, 119,
877     119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 136, 136, 136,
878          119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
879     119, 119, 119, 119, 119, 119, 119, 119, 119, 136, 136, 136, 119, 119, 119,
880          119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
881     119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119,
882          127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127,
883     127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119,
884          119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119,
885     119, 119, 119, 119, 119, 119, 136, 136, 136, 119, 119, 119, 119, 119, 119,
886          119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
887     119, 119, 119, 136, 136, 136, 119, 119, 119, 119, 119, 119, 119, 119, 119,
888          119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
889     136, 136, 136, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
890          119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
891     119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119,
892          127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127,
893     127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119,
894          119, 119, 119, 127, 127, 127, 127, 127, 127, 127, 127, 127,
895     119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
896          119, 119, 119, 127, 127, 127, 127, 127, 127, 127, 127, 127,
897     119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
898          119, 119, 119, 127, 127, 127, 127, 127, 127, 127, 127, 127,
899     119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
900          119, 119, 119, 127, 127, 127, 127, 127, 127, 127, 127, 127,
901     119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119,
902          127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127
903   };
904   const int ldb = k;
905   int b_offset = -127;
906   MatrixMap<const std::uint8_t, MapOrder::ColMajor> rhs(b_data, k, n, ldb);
907   const OffsetRowDup rhs_offset(b_offset, rhs.cols());
908 
909   // 22 x 25, column-wise.
910   const uint8_t expected_c_data[] = {
911       7,  37,  37,  67,  67,  39,  79,   7,   7,   7,   7,   7,   7,   7,   7,
912            7,   7,   7,   7,   7,   7,   7,
913       7,   7,  37,  87,  67,  23,  91,   7,   7,   7,   7,   7,   7,   7,   7,
914            7,   7,   7,   7,   7,   7,   7,
915       7,   7,  37,  87,  67,  23,  91,   7,   7,   7,   7,   7,   7,   7,   7,
916            7,   7,   7,   7,   7,   7,   7,
917       7,   7,  37,  87,  67,  23,  91,   7,   7,   7,   7,   7,   7,   7,   7,
918            7,   7,   7,   7,   7,   7,   7,
919       7,  37,  37,  67,  67,  39,  79,   7,   7,   7,   7,   7,   7,   7,   7,
920            7,   7,   7,   7,   7,   7,   7,
921       7,  37,   7,  67,  87,  23,  91,   7,   7,   7,   7,   7,   7,   7,   7,
922            7,   7,   7,   7,   7,   7,   7,
923       7,   7,   7,  87,  87,   7, 103,   7,   7,   7,   7,   7,   7,   7,   7,
924            7,   7,   7,   7,   7,   7,   7,
925       7,   7,  71,  87,  45,  41,  77,   7,   7,   7,   7,   7,   7,   7,   7,
926            7,   7,   7,   7,   7,   7,   7,
927       7,   7,   7,  87,  87,   7, 103,   7,   7,   7,   7,   7,   7,   7,   7,
928            7,   7,   7,   7,   7,   7,   7,
929       7,  37,   7,  67,  87,  23,  91,   7,   7,   7,   7,   7,   7,   7,   7,
930            7,   7,   7,   7,   7,   7,   7,
931       7,  37,   7,  67,  87,  23,  91,   7,   7,   7,   7,   7,   7,   7,   7,
932            7,   7,   7,   7,   7,   7,   7,
933       7,  71,   7,  45,  87,  41,  77,   7,   7,   7,   7,   7,   7,   7,   7,
934            7,   7,   7,   7,   7,   7,   7,
935     255, 135, 135, 255, 255, 143, 255, 255, 255, 255, 255, 255, 255, 255, 255,
936          255, 255, 255, 255, 255, 255, 255,
937       7,  71,   7,  45,  87,  41,  77,   7,   7,   7,   7,   7,   7,   7,   7,
938            7,   7,   7,   7,   7,   7,   7,
939       7,  37,   7,  67,  87,  23,  91,   7,   7,   7,   7,   7,   7,   7,   7,
940            7,   7,   7,   7,   7,   7,   7,
941       7,  37,   7,  67,  87,  23,  91,   7,   7,   7,   7,   7,   7,   7,   7,
942            7,   7,   7,   7,   7,   7,   7,
943       7,   7,   7,  87,  87,   7, 103,   7,   7,   7,   7,   7,   7,   7,   7,
944            7,   7,   7,   7,   7,   7,   7,
945       7,   7,  71,  87,  45,  41,  77,   7,   7,   7,   7,   7,   7,   7,   7,
946            7,   7,   7,   7,   7,   7,   7,
947       7,   7,   7,  87,  87,   7, 103,   7,   7,   7,   7,   7,   7,   7,   7,
948            7,   7,   7,   7,   7,   7,   7,
949       7,  37,   7,  67,  87,  23,  91,   7,   7,   7,   7,   7,   7,   7,   7,
950            7,   7,   7,   7,   7,   7,   7,
951       7,  37,  37,  67,  67,  39,  79,   7,   7,   7,   7,   7,   7,   7,   7,
952            7,   7,   7,   7,   7,   7,   7,
953       7,   7,  37,  87,  67,  23,  91,   7,   7,   7,   7,   7,   7,   7,   7,
954            7,   7,   7,   7,   7,   7,   7,
955       7,   7,  37,  87,  67,  23,  91,   7,   7,   7,   7,   7,   7,   7,   7,
956            7,   7,   7,   7,   7,   7,   7,
957       7,   7,  37,  87,  67,  23,  91,   7,   7,   7,   7,   7,   7,   7,   7,
958            7,   7,   7,   7,   7,   7,   7,
959       7,  37,  37,  67,  67,  39,  79,   7,   7,   7,   7,   7,   7,   7,   7,
960            7,   7,   7,   7,   7,   7,   7,
961      99,  99,  99,  99,  99,  99,  99,  99,  99,  99,  99,  99,  99,  99,  99,
962           99,  99,  99,  99,  99,  99,  99,
963     111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111,
964          111, 111, 111, 111, 111, 111, 111,
965   };
966   const int ldc = m;
967   int c_offset[] = {
968       6477, 12954, 12954, 7793, 7793, 12954, 9282, 6477, 6477, 6477,
969       6477,  6477,  6477, 6477, 6477,  6477, 6477, 6477, 6477, 6477,
970       6477,  6477,
971   };
972   int c_mult_int[] = {
973       41121, 20560, 20560, 34267, 34267, 21937, 28784, 41121, 41121, 41121,
974       41121, 41121, 41121, 41121, 41121, 41121, 41121, 41121, 41121, 41121,
975       41121, 41121,
976   };
977   const int c_shift = 21;
978 
979   const int c_count = m * n;
980   std::unique_ptr<uint8_t[]> output_data(new uint8_t[c_count]);
981   MatrixMap<std::uint8_t, MapOrder::ColMajor> result(output_data.get(), m, n,
982                                                      ldc);
983   const OffsetColMap result_offset(c_offset, m);
984   const OffsetColMap result_mult_int(c_mult_int, m);
985   const int result_shift = c_shift;
986 
987   GemmContext gemm_context;
988   auto output_pipeline = MakeStandardOutputPipeline<VectorShape::Col>(
989       result_offset, result_mult_int, result_shift);
990   GemmWithOutputPipelinePC<uint8_t, uint8_t, DefaultL8R8BitDepthParams>(
991       &gemm_context, lhs, rhs, &result, lhs_offset, rhs_offset,
992       output_pipeline);
993 
994   ResultStats stats;
995   GetResultStats(output_data.get(), expected_c_data, c_count, &stats);
996 
997   ResultStatsBounds bounds;
998   const bool good = CheckResultStatsBounds(stats, bounds);
999   printf("TestWithLargeDataPerChannelQuantization: %s\n",
1000          good ? "PASS" : "FAIL");
1001   ReportResultStats(stats, bounds);
1002   Check(good);
1003 }
1004 
1005 // Runs a small set of hand-calculated data through the implementation.
TestWithSmallData()1006 void TestWithSmallData() {
1007   const int m = 4;
1008   const int n = 2;
1009   const int k = 3;
1010   // Matrix A (LHS) is:
1011   // |  7 | 10 | 13 | 16 |
1012   // |  8 | 11 | 14 | 17 |
1013   // |  9 | 12 | 15 | 18 |
1014   const uint8_t a_data[] = {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18};
1015   // Matrix B (RHS) is:
1016   // |  1 |  3 |  5 |
1017   // |  2 |  4 |  6 |
1018   const uint8_t b_data[] = {1, 2, 3, 4, 5, 6};
1019   // Here are the results we expect, from hand calculations:
1020   // (1 * 7) + (3 * 8) + (5 * 9) = 76
1021   // (2 * 7) + (4 * 8) + (6 * 9) = 100
1022   // (1 * 10) + (3 * 11) + (5 * 12) = 103
1023   // (2 * 10) + (4 * 11) + (6 * 12) = 136
1024   // (1 * 13) + (3 * 14) + (5 * 15) = 130
1025   // (2 * 13) + (4 * 14) + (6 * 15) = 172
1026   // (1 * 16) + (3 * 17) + (5 * 18) = 157
1027   // (2 * 16) + (4 * 17) + (6 * 18) = 208
1028   // That means matrix C should be:
1029   // |  76 | 103 | 130 | 157 |
1030   // | 100 | 136 | 172 | 208 |
1031   const uint8_t expected_data[] = {76, 100, 103, 136, 130, 172, 157, 208};
1032 
1033   const int c_count = m * n;
1034   std::unique_ptr<uint8_t[]> output_data(new uint8_t[c_count]);
1035 
1036   const bool is_a_transposed = true;
1037   const bool is_b_transposed = true;
1038   const bool is_c_transposed = true;
1039   const int lda = k;
1040   const int ldb = n;
1041   const int ldc = n;
1042 
1043   const int a_offset = 0;
1044   const int b_offset = 0;
1045   const int c_offset = 0;
1046   const int c_mult = 1;
1047   const int c_shift = 0;
1048 
1049   gemmlowp::eight_bit_int_gemm::EightBitIntGemm(
1050       is_a_transposed, is_b_transposed, is_c_transposed, m, n, k, a_data,
1051       a_offset, lda, b_data, b_offset, ldb, output_data.get(), c_offset, c_mult,
1052       c_shift, ldc, eight_bit_int_gemm::BitDepthSetting::A8B8);
1053 
1054   ResultStats stats;
1055   GetResultStats(output_data.get(), expected_data, c_count, &stats);
1056 
1057   ResultStatsBounds bounds;
1058   const bool good = CheckResultStatsBounds(stats, bounds);
1059   printf("TestWithSmallData: %s\n", good ? "PASS" : "FAIL");
1060   ReportResultStats(stats, bounds);
1061   Check(good);
1062 }
1063 
1064 // This is the most realistic test of how we'll be using the low-precision GEMM
1065 // function in applications. It takes in large input matrices that have been
1066 // captured from an actual neural network run.
TestWithRealData(eight_bit_int_gemm::BitDepthSetting BitDepth,int tolerance_median,int tolerance_max)1067 void TestWithRealData(eight_bit_int_gemm::BitDepthSetting BitDepth,
1068                       int tolerance_median, int tolerance_max) {
1069   std::unique_ptr<uint8_t[]> output_data(new uint8_t[test_data::c_count]);
1070   gemmlowp::eight_bit_int_gemm::EightBitIntGemm(
1071       test_data::is_a_transposed, test_data::is_b_transposed,
1072       test_data::is_c_transposed, test_data::m, test_data::n, test_data::k,
1073       test_data::a_data, test_data::a_offset, test_data::k, test_data::b_data,
1074       test_data::b_offset, test_data::k, output_data.get(), test_data::c_offset,
1075       test_data::c_mult_int, test_data::c_shift, test_data::m, BitDepth);
1076 
1077   ResultStats stats;
1078   GetResultStats(output_data.get(), test_data::expected_c_data,
1079                  test_data::c_count, &stats);
1080 
1081   ResultStatsBounds bounds;
1082   if (BitDepth == eight_bit_int_gemm::BitDepthSetting::A5B7) {
1083     bounds.med_unsigned_diff = tolerance_median;
1084     bounds.max_unsigned_diff = tolerance_max;
1085     bounds.med_signed_diff = 0;
1086     bounds.mean_signed_diff = 0.2f;
1087   }
1088 
1089   const bool good = CheckResultStatsBounds(stats, bounds);
1090   printf("TestWithRealData: %s with %s\n", good ? "PASS" : "FAIL",
1091          GetBitDepthName(BitDepth));
1092   ReportResultStats(stats, bounds);
1093   Check(good);
1094 }
1095 
1096 template <MapOrder ResultOrder>
TestOutputStages(int rows,int depth,int cols,int result_offset,int result_mult_int,int result_shift)1097 void TestOutputStages(int rows, int depth, int cols, int result_offset,
1098                       int result_mult_int, int result_shift) {
1099   Matrix<std::uint8_t, MapOrder::RowMajor> lhs(rows, depth);
1100   Matrix<std::uint8_t, MapOrder::ColMajor> rhs(depth, cols);
1101   Matrix<std::int32_t, ResultOrder> result_raw_int32(rows, cols);
1102   MakeRandom(&lhs, 8);
1103   MakeRandom(&rhs, 8);
1104   const int lhs_offset = 12;
1105   const int rhs_offset = -34;
1106 
1107   // Test an empty pipeline, i.e. returning raw int32 accumulators.
1108   auto empty_pipeline = std::make_tuple();
1109   GemmContext context;
1110   GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1111       &context, lhs.const_map(), rhs.const_map(), &result_raw_int32, lhs_offset,
1112       rhs_offset, empty_pipeline);
1113 
1114   for (int r = 0; r < rows; r++) {
1115     for (int c = 0; c < cols; c++) {
1116       std::int32_t expected = 0;
1117       for (int d = 0; d < depth; d++) {
1118         std::int32_t lhs_val =
1119             static_cast<std::int32_t>(lhs(r, d)) + lhs_offset;
1120         std::int32_t rhs_val =
1121             static_cast<std::int32_t>(rhs(d, c)) + rhs_offset;
1122         expected += lhs_val * rhs_val;
1123       }
1124       Check(expected == result_raw_int32(r, c));
1125     }
1126   }
1127 
1128   // Test a pipeline with only the quantize-down stage, still returning
1129   // unclamped (but scaled) int32's
1130   OutputStageQuantizeDownInt32ToUint8Scale quantize_down_stage;
1131   quantize_down_stage.result_offset = result_offset;
1132   quantize_down_stage.result_mult_int = result_mult_int;
1133   quantize_down_stage.result_shift = result_shift;
1134   auto quantize_down_pipeline = std::make_tuple(quantize_down_stage);
1135   Matrix<std::int32_t, ResultOrder> result_quantized_down_int32(rows, cols);
1136   GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1137       &context, lhs.const_map(), rhs.const_map(), &result_quantized_down_int32,
1138       lhs_offset, rhs_offset, quantize_down_pipeline);
1139 
1140   std::uint64_t sum = 0;
1141   for (int r = 0; r < rows; r++) {
1142     for (int c = 0; c < cols; c++) {
1143       std::int32_t raw = result_raw_int32(r, c);
1144       const std::int32_t rounding =
1145           (result_shift < 1) ? 0 : (1 << (result_shift - 1));
1146       std::int32_t expected =
1147           ((raw + result_offset) * result_mult_int + rounding) >> result_shift;
1148       Check(expected == result_quantized_down_int32(r, c));
1149       sum += expected;
1150     }
1151   }
1152   std::uint64_t avg = sum / (rows * cols);
1153   // Test that the average quantized-down value falls reasonably in the
1154   // middle of the [0..255] range. Otherwise, the multiplier / shift need to be
1155   // adjusted.
1156   Check(avg >= 64 && avg <= 192);
1157 
1158   // Test the familiar default pipeline consisting of quantize-down and
1159   // clamp-and-cast-to-uint8.
1160   OutputStageSaturatingCastToUint8 saturating_cast_stage;
1161   auto quantize_down_and_saturating_cast_pipeline =
1162       std::make_tuple(quantize_down_stage, saturating_cast_stage);
1163   Matrix<std::uint8_t, ResultOrder> result_quantized_down_saturated_uint8(rows,
1164                                                                           cols);
1165   GemmWithOutputPipeline<std::uint8_t, std::uint8_t, DefaultL8R8BitDepthParams>(
1166       &context, lhs.const_map(), rhs.const_map(),
1167       &result_quantized_down_saturated_uint8, lhs_offset, rhs_offset,
1168       quantize_down_and_saturating_cast_pipeline);
1169 
1170   for (int r = 0; r < rows; r++) {
1171     for (int c = 0; c < cols; c++) {
1172       std::int32_t quantized = result_quantized_down_int32(r, c);
1173       std::uint8_t expected = std::min(std::max(quantized, 0), 255);
1174       Check(expected == result_quantized_down_saturated_uint8(r, c));
1175     }
1176   }
1177 
1178   // Test a bias-addition with row-vector
1179   std::vector<std::int32_t> row_vector_data(cols);
1180   for (int i = 0; i < cols; i++) {
1181     row_vector_data[i] = (Random() % 1000) - 500;
1182   }
1183   typedef VectorMap<std::int32_t, VectorShape::Row> RowVectorMap;
1184   RowVectorMap row_vector_map(row_vector_data.data(), cols);
1185   OutputStageBiasAddition<RowVectorMap> row_bias_addition_stage;
1186   row_bias_addition_stage.bias_vector = row_vector_map;
1187   auto row_bias_addition_pipeline = std::make_tuple(row_bias_addition_stage);
1188   Matrix<std::int32_t, ResultOrder> result_of_row_bias_addition(rows, cols);
1189   GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1190       &context, lhs.const_map(), rhs.const_map(), &result_of_row_bias_addition,
1191       lhs_offset, rhs_offset, row_bias_addition_pipeline);
1192   for (int r = 0; r < rows; r++) {
1193     for (int c = 0; c < cols; c++) {
1194       std::int32_t expected = result_raw_int32(r, c) + row_vector_data[c];
1195       Check(expected == result_of_row_bias_addition(r, c));
1196     }
1197   }
1198 
1199   // Test a bias-addition with column-vector
1200   std::vector<std::int32_t> col_vector_data(rows);
1201   for (int i = 0; i < rows; i++) {
1202     col_vector_data[i] = (Random() % 1000) - 500;
1203   }
1204   typedef VectorMap<std::int32_t, VectorShape::Col> ColVectorMap;
1205   ColVectorMap col_vector_map(col_vector_data.data(), rows);
1206   OutputStageBiasAddition<ColVectorMap> col_bias_addition_stage;
1207   col_bias_addition_stage.bias_vector = col_vector_map;
1208   auto col_bias_addition_pipeline = std::make_tuple(col_bias_addition_stage);
1209   Matrix<std::int32_t, ResultOrder> result_of_col_bias_addition(rows, cols);
1210   GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1211       &context, lhs.const_map(), rhs.const_map(), &result_of_col_bias_addition,
1212       lhs_offset, rhs_offset, col_bias_addition_pipeline);
1213   for (int r = 0; r < rows; r++) {
1214     for (int c = 0; c < cols; c++) {
1215       std::int32_t expected = result_raw_int32(r, c) + col_vector_data[r];
1216       Check(expected == result_of_col_bias_addition(r, c));
1217     }
1218   }
1219 
1220   // Test a clamp
1221   OutputStageClamp clamp_stage;
1222   // Determine min and max of raw int32 accumulators
1223   std::int32_t raw_min = std::numeric_limits<std::int32_t>::max();
1224   std::int32_t raw_max = std::numeric_limits<std::int32_t>::min();
1225   for (int r = 0; r < rows; r++) {
1226     for (int c = 0; c < cols; c++) {
1227       raw_min = std::min(raw_min, result_raw_int32(r, c));
1228       raw_max = std::max(raw_max, result_raw_int32(r, c));
1229     }
1230   }
1231   // Pick some interesting clamp min/max bounds
1232   clamp_stage.min = static_cast<std::int32_t>(raw_min * 0.7 + raw_max * 0.3);
1233   clamp_stage.max = static_cast<std::int32_t>(raw_min * 0.3 + raw_max * 0.7);
1234   assert(raw_min <= clamp_stage.min && clamp_stage.min <= clamp_stage.max &&
1235          clamp_stage.max <= raw_max);
1236   auto clamp_pipeline = std::make_tuple(clamp_stage);
1237   Matrix<std::int32_t, ResultOrder> result_clamped(rows, cols);
1238   GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1239       &context, lhs.const_map(), rhs.const_map(), &result_clamped, lhs_offset,
1240       rhs_offset, clamp_pipeline);
1241   for (int r = 0; r < rows; r++) {
1242     for (int c = 0; c < cols; c++) {
1243       std::int32_t raw = result_raw_int32(r, c);
1244       std::int32_t expected =
1245           std::min(std::max(raw, clamp_stage.min), clamp_stage.max);
1246       Check(expected == result_clamped(r, c));
1247     }
1248   }
1249 
1250   // Test tanh
1251   OutputStageTanh tanh_stage;
1252   const std::int32_t real_zero_as_int32 = (raw_max + raw_min) / 2;
1253   const std::int32_t real_amplitude_as_int32 = (raw_max - raw_min) / 16;
1254   tanh_stage.real_zero_as_int32 = real_zero_as_int32;
1255   tanh_stage.real_amplitude_as_int32 = real_amplitude_as_int32;
1256   auto tanh_pipeline = std::make_tuple(tanh_stage);
1257   Matrix<std::int32_t, ResultOrder> result_tanh(rows, cols);
1258   GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1259       &context, lhs.const_map(), rhs.const_map(), &result_tanh, lhs_offset,
1260       rhs_offset, tanh_pipeline);
1261   for (int r = 0; r < rows; r++) {
1262     for (int c = 0; c < cols; c++) {
1263       std::int32_t raw = result_raw_int32(r, c);
1264       double real_input =
1265           double(raw - real_zero_as_int32) / real_amplitude_as_int32;
1266       double expected = std::tanh(real_input);
1267       std::int32_t actual_int32 = result_tanh(r, c);
1268       double actual =
1269           double(actual_int32 - real_zero_as_int32) / real_amplitude_as_int32;
1270       Check(std::abs(expected - actual) < 2e-4);
1271     }
1272   }
1273 
1274   // Test a pipeline with bias and clamp
1275   auto bias_clamp_pipeline =
1276       std::make_tuple(col_bias_addition_stage, clamp_stage);
1277   Matrix<std::int32_t, ResultOrder> result_biased_clamped(rows, cols);
1278   GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1279       &context, lhs.const_map(), rhs.const_map(), &result_biased_clamped,
1280       lhs_offset, rhs_offset, bias_clamp_pipeline);
1281   for (int r = 0; r < rows; r++) {
1282     for (int c = 0; c < cols; c++) {
1283       std::int32_t raw = result_raw_int32(r, c);
1284       std::int32_t biased = raw + col_vector_data[r];
1285       std::int32_t expected =
1286           std::min(std::max(biased, clamp_stage.min), clamp_stage.max);
1287       Check(expected == result_biased_clamped(r, c));
1288     }
1289   }
1290 
1291   // Test a full pipeline with bias and clamp and quantization down to 8bit
1292   // result
1293   auto bias_clamp_quantize_cast_pipeline =
1294       std::make_tuple(col_bias_addition_stage, clamp_stage, quantize_down_stage,
1295                       saturating_cast_stage);
1296   Matrix<std::uint8_t, ResultOrder> result_biased_clamped_quantized_casted(
1297       rows, cols);
1298   GemmWithOutputPipeline<std::uint8_t, std::uint8_t, DefaultL8R8BitDepthParams>(
1299       &context, lhs.const_map(), rhs.const_map(),
1300       &result_biased_clamped_quantized_casted, lhs_offset, rhs_offset,
1301       bias_clamp_quantize_cast_pipeline);
1302   for (int r = 0; r < rows; r++) {
1303     for (int c = 0; c < cols; c++) {
1304       const std::int32_t rounding =
1305           (result_shift < 1) ? 0 : (1 << (result_shift - 1));
1306       std::int32_t quantized =
1307           ((result_biased_clamped(r, c) + result_offset) * result_mult_int +
1308            rounding) >>
1309           result_shift;
1310       std::uint8_t expected = std::min(std::max(quantized, 0), 255);
1311       Check(expected == result_biased_clamped_quantized_casted(r, c));
1312     }
1313   }
1314 
1315   printf("TestOutputStages: PASS with ResultOrder=%s\n",
1316          OrderName(ResultOrder));
1317 }
1318 
1319 #ifndef GEMMLOWP_SKIP_EXHAUSTIVE_TESTS
TestExhaustively()1320 void TestExhaustively() {
1321   GemmContext context;
1322 
1323   // Test the internal GEMM interfaces
1324   test_gemm<SingleThreadGemmWrapper<
1325       DefaultKernel<KernelFamily::Gemm, DefaultL8R8BitDepthParams>,
1326       std::uint8_t, DefaultL8R8BitDepthParams>>(&context);
1327 
1328   test_gemm<MultiThreadGemmWrapper<
1329       DefaultKernel<KernelFamily::Gemm, DefaultL8R8BitDepthParams>,
1330       std::uint8_t, DefaultL8R8BitDepthParams>>(&context);
1331 
1332   // Test the public GEMM interfaces
1333   test_gemm<PublicGemmWrapper<uint8_t, DefaultL8R8BitDepthParams>>(&context);
1334 
1335   test_gemm<EightBitIntGemmWrapper<uint8_t,
1336                                    eight_bit_int_gemm::BitDepthSetting::A8B8>>(
1337       &context);
1338 
1339   // Test GEMV cases (internal interfaces)
1340   test_gemv<SingleThreadGemmWrapper<
1341       DefaultKernel<KernelFamily::Gemv, DefaultL8R8BitDepthParams>,
1342       std::uint8_t, DefaultL8R8BitDepthParams>>(&context);
1343 
1344   test_gemv<MultiThreadGemmWrapper<
1345       DefaultKernel<KernelFamily::Gemv, DefaultL8R8BitDepthParams>,
1346       std::uint8_t, DefaultL8R8BitDepthParams>>(&context);
1347 
1348   // Test GEMV cases (public interfaces)
1349   test_gemv<PublicGemmWrapper<uint8_t, DefaultL8R8BitDepthParams>>(&context);
1350 
1351   test_gemv<EightBitIntGemmWrapper<uint8_t,
1352                                    eight_bit_int_gemm::BitDepthSetting::A8B8>>(
1353       &context);
1354 
1355   // Test other bit depths
1356   // L7R5
1357   test_gemm<SingleThreadGemmWrapper<
1358       DefaultKernel<KernelFamily::Gemm, DefaultL7R5BitDepthParams>,
1359       std::uint8_t, DefaultL7R5BitDepthParams>>(&context);
1360 
1361   test_gemv<SingleThreadGemmWrapper<
1362       DefaultKernel<KernelFamily::Gemv, DefaultL7R5BitDepthParams>,
1363       std::uint8_t, DefaultL7R5BitDepthParams>>(&context);
1364 
1365   test_gemm<EightBitIntGemmWrapper<std::uint8_t,
1366                                    eight_bit_int_gemm::BitDepthSetting::A5B7>>(
1367       &context);
1368 
1369   // Test specific kernels with various different formats,
1370   // to exercises corner cases especially in the packing code.
1371   test_gemm_kernel<
1372       ReferenceKernel<KernelFormat<KernelSideFormat<CellFormat<1, 1>, 1>,
1373                                    KernelSideFormat<CellFormat<1, 1>, 1>>>>(
1374       &context);
1375 
1376   test_gemm_kernel<
1377       ReferenceKernel<KernelFormat<KernelSideFormat<CellFormat<4, 2>, 1>,
1378                                    KernelSideFormat<CellFormat<4, 2>, 2>>>>(
1379       &context);
1380 
1381   test_gemm_kernel<
1382       ReferenceKernel<KernelFormat<KernelSideFormat<CellFormat<4, 2>, 4>,
1383                                    KernelSideFormat<CellFormat<4, 2>, 5>>>>(
1384       &context);
1385 
1386   test_gemm_kernel<ReferenceKernel<KernelFormat<
1387       KernelSideFormat<CellFormat<3, 4, CellOrder::DepthMajor>, 2>,
1388       KernelSideFormat<CellFormat<5, 4, CellOrder::DepthMajor>, 3>>>>(&context);
1389 
1390   test_gemm_kernel<ReferenceKernel<KernelFormat<
1391       KernelSideFormat<CellFormat<3, 4, CellOrder::WidthMajor>, 2>,
1392       KernelSideFormat<CellFormat<5, 4, CellOrder::WidthMajor>, 3>>>>(&context);
1393 
1394   test_gemm_kernel<ReferenceKernel<KernelFormat<
1395       KernelSideFormat<CellFormat<5, 2, CellOrder::WidthMajor>, 3>,
1396       KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2>>>>(&context);
1397 
1398   test_gemm_kernel<ReferenceKernel<KernelFormat<
1399       KernelSideFormat<CellFormat<5, 2, CellOrder::DepthMajor>, 3>,
1400       KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 2>>>>(&context);
1401 
1402   test_gemm_kernel<ReferenceKernel<KernelFormat<
1403       KernelSideFormat<CellFormat<8, 8, CellOrder::Diagonal>, 2>,
1404       KernelSideFormat<CellFormat<3, 8, CellOrder::WidthMajor>, 1>>>>(&context);
1405 
1406   test_gemm_kernel<ReferenceKernel<KernelFormat<
1407       KernelSideFormat<CellFormat<1, 4, CellOrder::DepthMajor>, 1>,
1408       KernelSideFormat<CellFormat<4, 4, CellOrder::Diagonal>, 1>>>>(&context);
1409 }
1410 #endif  // not GEMMLOWP_SKIP_EXHAUSTIVE_TESTS
1411 
test()1412 void test() {
1413 #ifdef GEMMLOWP_TEST_PROFILE
1414   RegisterCurrentThreadForProfiling();
1415   StartProfiling();
1416 #endif
1417 
1418   // Run a first quick test against hand-calculated data.
1419   TestWithSmallData();
1420 
1421 #ifndef GEMMLOWP_SKIP_EXHAUSTIVE_TESTS
1422   TestExhaustively();
1423 #endif
1424 
1425   // Run against actual data from a network evaluation.
1426   TestWithRealData(eight_bit_int_gemm::BitDepthSetting::A8B8, 0, 0);
1427   TestWithRealData(eight_bit_int_gemm::BitDepthSetting::A5B7, 2, 10);
1428 
1429   // Test non-default output pipelines with various combinations of
1430   // output stages.
1431   TestOutputStages<MapOrder::RowMajor>(63, 10, 127, 5, 17, 14);
1432   TestOutputStages<MapOrder::ColMajor>(63, 10, 127, 5, 17, 14);
1433 
1434   // Test per channel quantization.
1435   TestWithSmallDataPerChannelQuantization();
1436   TestWithLargeDataPerChannelQuantization();
1437 #ifdef GEMMLOWP_TEST_PROFILE
1438   FinishProfiling();
1439 #endif
1440 
1441   std::cerr << "All tests passed." << std::endl;
1442 
1443   // We have been testing the eight_bit_int_gemm, so we should free its
1444   // persistent
1445   // resources now to avoid having leak-checking tools report leaks.
1446   eight_bit_int_gemm::FreePersistentResources();
1447 }
1448 
1449 }  // end namespace gemmlowp
1450 
1451 // For iOS, we need to define our own main(), so skip it here.
1452 #if !(defined(__APPLE__) && (TARGET_OS_IPHONE || TARGET_IPHONE_SIMULATOR))
main()1453 int main() { gemmlowp::test(); }
1454 #endif
1455