// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "test.h"

#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#ifdef __APPLE__
#include <TargetConditionals.h>
#endif

#include "../eight_bit_int_gemm/eight_bit_int_gemm.h"
#include "../internal/kernel_reference.h"
#include "test_data.h"

namespace gemmlowp {

void ReferenceEightBitIntGemm(bool transpose_a, bool transpose_b,
                              bool transpose_c, int m, int n, int k,
                              const std::uint8_t* a, std::int32_t a_offset,
                              int lda, const std::uint8_t* b,
                              std::int32_t b_offset, int ldb, std::uint8_t* c,
                              std::int32_t c_offset, std::int32_t c_mult_int,
                              std::int32_t c_shift, int ldc) {
  ScopedProfilingLabel("ReferenceEightBitIntGemm");
  assert((c_shift >= 0) && (c_shift <= 32));

  assert(a != nullptr);
  assert(b != nullptr);
  assert(c != nullptr);

  int a_i_stride;
  int a_l_stride;
  if (transpose_a) {
    a_i_stride = lda;
    a_l_stride = 1;
  } else {
    a_i_stride = 1;
    a_l_stride = lda;
  }
  int b_j_stride;
  int b_l_stride;
  if (transpose_b) {
    b_j_stride = 1;
    b_l_stride = ldb;
  } else {
    b_j_stride = ldb;
    b_l_stride = 1;
  }
  int c_i_stride;
  int c_j_stride;
  if (transpose_c) {
    c_i_stride = ldc;
    c_j_stride = 1;
  } else {
    c_i_stride = 1;
    c_j_stride = ldc;
  }
  int i, j, l;

  const std::int32_t kRoundingTerm = (c_shift < 1) ? 0 : (1 << (c_shift - 1));

  for (j = 0; j < n; j++) {
    for (i = 0; i < m; i++) {
      std::int32_t total = 0;
      for (l = 0; l < k; l++) {
        const int a_index = i * a_i_stride + l * a_l_stride;
        const std::uint8_t a_as_byte = a[a_index];
        const std::int32_t a_as_int =
            static_cast<std::int32_t>(a_as_byte) + a_offset;
        const int b_index = j * b_j_stride + l * b_l_stride;
        const std::uint8_t b_as_byte = b[b_index];
        const std::int32_t b_as_int =
            static_cast<std::int32_t>(b_as_byte) + b_offset;
        const std::int32_t mult_as_int = a_as_int * b_as_int;
        total += mult_as_int;
      }
      std::int32_t output =
          (((total + c_offset) * c_mult_int) + kRoundingTerm) >> c_shift;
      if (output > 255) {
        output = 255;
      }
      if (output < 0) {
        output = 0;
      }
      const int c_index = i * c_i_stride + j * c_j_stride;
      c[c_index] = static_cast<std::uint8_t>(output);
    }
  }
}
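
// A worked example of the rounding arithmetic above, with hypothetical
// values: take total = 100, c_offset = 10, c_mult_int = 3, c_shift = 5.
// Then kRoundingTerm = 1 << 4 = 16, and
//   output = ((100 + 10) * 3 + 16) >> 5 = 346 >> 5 = 10,
// which already lies within [0, 255], so the clamping leaves it unchanged.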

typedef VectorMap<const std::int32_t, VectorShape::Col> OffsetColMap;
typedef VectorMap<const std::int32_t, VectorShape::Row> OffsetRowMap;
typedef VectorDup<const std::int32_t, VectorShape::Col> OffsetColDup;
typedef VectorDup<const std::int32_t, VectorShape::Row> OffsetRowDup;

// The *GemmWrapper structs below wrap various GEMM functions in a uniform
// interface, so that the same testing code can exercise all of them.
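// Apart from ReferenceEightBitIntGemmWrapper further below, each wrapper
// exposes the same static members: Name(), a Context typedef, and Gemm(),
// which the shared test code invokes as in test_gemm_impl below:
//   GemmWrapper::Gemm(context, lhs.const_map(), rhs.const_map(),
//                     &result->map(), lhs_offset, rhs_offset, result_offset,
//                     result_mult_int, result_shift);
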
template <typename Kernel, typename Scalar, typename tBitDepthParams>
struct SingleThreadGemmWrapper {
  typedef tBitDepthParams BitDepthParams;

  static const char* Name() {
    static char buf[256];
    snprintf(buf, sizeof(buf), "SingleThreadGemm, Kernel: %s", Kernel().Name());
    return buf;
  }

  typedef SingleThreadGemmContext Context;

  template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
  static bool Gemm(Context* context,
                   const MatrixMap<const Scalar, LhsOrder>& lhs,
                   const MatrixMap<const Scalar, RhsOrder>& rhs,
                   MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
                   int rhs_offset, int result_offset, int result_mult_int,
                   int result_shift) {
    ScopedProfilingLabel("SingleThreadGemmWrapper::Gemm");
    const int rows = lhs.rows();
    const int cols = rhs.cols();
    if (rows < cols) {
      // SingleThreadGemm is never called with rows < cols.
      // That case is handled earlier.
      return false;
    }
    const OffsetColDup lhs_offset_vector(lhs_offset, rows);
    const OffsetRowDup rhs_offset_vector(rhs_offset, cols);
    SingleThreadGemm<typename Kernel::Format, Scalar, Scalar, BitDepthParams,
                     LhsOrder, RhsOrder, ResultOrder, OffsetColDup,
                     OffsetRowDup>(
        context, Kernel(), lhs, rhs, result, lhs_offset_vector,
        rhs_offset_vector,
        MakeStandardOutputPipeline(result_offset, result_mult_int,
                                   result_shift));
    return true;
  }
};

template <typename Kernel, typename Scalar, typename tBitDepthParams>
struct MultiThreadGemmWrapper {
  typedef tBitDepthParams BitDepthParams;

  static const char* Name() {
    static char buf[256];
    snprintf(buf, sizeof(buf), "MultiThreadGemm, Kernel: %s", Kernel().Name());
    return buf;
  }

  typedef MultiThreadGemmContext Context;

  template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
  static bool Gemm(Context* context,
                   const MatrixMap<const Scalar, LhsOrder>& lhs,
                   const MatrixMap<const Scalar, RhsOrder>& rhs,
                   MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
                   int rhs_offset, int result_offset, int result_mult_int,
                   int result_shift) {
    ScopedProfilingLabel("MultiThreadGemmWrapper::Gemm");
    // A max_num_threads value of 0 lets the context determine the number of
    // threads automatically.
    context->set_max_num_threads(0);
    const int rows = lhs.rows();
    const int cols = rhs.cols();
    if (rows < cols) {
      // MultiThreadGemm is never called with rows < cols.
      // That case is handled earlier.
      return false;
    }
    const OffsetColDup lhs_offset_vector(lhs_offset, rows);
    const OffsetRowDup rhs_offset_vector(rhs_offset, cols);
    MultiThreadGemm<typename Kernel::Format, Scalar, Scalar, BitDepthParams,
                    LhsOrder, RhsOrder, ResultOrder, OffsetColDup,
                    OffsetRowDup>(
        context, Kernel(), lhs, rhs, result, lhs_offset_vector,
        rhs_offset_vector,
        MakeStandardOutputPipeline(result_offset, result_mult_int,
                                   result_shift));
    return true;
  }
};

template <typename Scalar, typename tBitDepthParams>
struct PublicGemmWrapper {
  typedef tBitDepthParams BitDepthParams;

  static const char* Name() { return "public Gemm"; }

  typedef GemmContext Context;

  template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
  static bool Gemm(Context* context,
                   const MatrixMap<const Scalar, LhsOrder>& lhs,
                   const MatrixMap<const Scalar, RhsOrder>& rhs,
                   MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
                   int rhs_offset, int result_offset, int result_mult_int,
                   int result_shift) {
    ScopedProfilingLabel("PublicGemmWrapper::Gemm");
    gemmlowp::Gemm<std::uint8_t, BitDepthParams, LhsOrder, RhsOrder,
                   ResultOrder>(context, lhs, rhs, result, lhs_offset,
                                rhs_offset, result_offset, result_mult_int,
                                result_shift);
    return true;
  }
};

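// Maps the public eight_bit_int_gemm BitDepthSetting enum values to the
// internal BitDepthParams types used by the templated GEMM entry points.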
template <eight_bit_int_gemm::BitDepthSetting BitDepth>
struct BitDepthParamsForSettings {};

template <>
struct BitDepthParamsForSettings<eight_bit_int_gemm::BitDepthSetting::A8B8>
    : DefaultL8R8BitDepthParams {};

template <>
struct BitDepthParamsForSettings<eight_bit_int_gemm::BitDepthSetting::A5B7>
    : DefaultL7R5BitDepthParams {};

template <typename Scalar, eight_bit_int_gemm::BitDepthSetting BitDepth>
struct EightBitIntGemmWrapper {
  typedef BitDepthParamsForSettings<BitDepth> BitDepthParams;

  static const char* Name() { return "EightBitIntGemm"; }

  typedef void Context;

  template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
  static bool Gemm(Context*, const MatrixMap<const Scalar, LhsOrder>& lhs,
                   const MatrixMap<const Scalar, RhsOrder>& rhs,
                   MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
                   int rhs_offset, int result_offset, int result_mult_int,
                   int result_shift) {
    ScopedProfilingLabel("EightBitIntGemmWrapper::Gemm");
    const bool transpose_c = ResultOrder == MapOrder::RowMajor;
    const bool transpose_a = LhsOrder == MapOrder::RowMajor;
    const bool transpose_b = RhsOrder == MapOrder::RowMajor;
    eight_bit_int_gemm::EightBitIntGemm(
        transpose_a, transpose_b, transpose_c, lhs.rows(), rhs.cols(),
        lhs.cols(), lhs.data(), lhs_offset, lhs.stride(), rhs.data(),
        rhs_offset, rhs.stride(), result->data(), result_offset,
        result_mult_int, result_shift, result->stride(), BitDepth);
    return true;
  }
};

template <typename Scalar>
struct ReferenceEightBitIntGemmWrapper {
  typedef DefaultL8R8BitDepthParams BitDepthParams;

  static const char* Name() { return "ReferenceEightBitIntGemm"; }

  template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
  static bool Gemm(bool transpose_a, bool transpose_b, bool transpose_c,
                   const MatrixMap<const Scalar, LhsOrder>& lhs,
                   const MatrixMap<const Scalar, RhsOrder>& rhs,
                   MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
                   int rhs_offset, int result_offset, int result_mult_int,
                   int result_shift) {
    ScopedProfilingLabel("ReferenceEightBitIntGemmWrapper::Gemm");
    ReferenceEightBitIntGemm(transpose_a, transpose_b, transpose_c, lhs.rows(),
                             rhs.cols(), lhs.cols(), lhs.data(), lhs_offset,
                             lhs.stride(), rhs.data(), rhs_offset, rhs.stride(),
                             result->data(), result_offset, result_mult_int,
                             result_shift, result->stride());
    return true;
  }
};

const char* OrderName(MapOrder order) {
  return order == MapOrder::ColMajor ? "ColMajor" : "RowMajor";
}

struct ResultStats {
  ResultStats()
      : count(0),
        med_val(0),
        mean_signed_diff(0),
        med_signed_diff(0),
        med_unsigned_diff(0),
        max_unsigned_diff(0) {}

  int count;
  int med_val;
  float mean_signed_diff;
  int med_signed_diff;
  int med_unsigned_diff;
  int max_unsigned_diff;

  // Histogram of unsigned diffs: slice 0 counts exact matches, and slice
  // e >= 1 counts diffs in the range [2^(e-1), 2^e).
  std::vector<int> count_diff_by_pot_slice;
};

void GetResultStats(const std::uint8_t* actual, const std::uint8_t* expected,
                    size_t count, ResultStats* stats) {
  ScopedProfilingLabel("GetResultStats");
  std::vector<std::uint8_t> results;
  std::vector<std::int16_t> signed_diffs;
  std::vector<std::uint8_t> unsigned_diffs;
  std::int64_t signed_diffs_sum = 0;
  for (size_t i = 0; i < count; i++) {
    results.push_back(actual[i]);
    std::int16_t signed_diff = actual[i] - expected[i];
    signed_diffs.push_back(signed_diff);
    unsigned_diffs.push_back(std::abs(signed_diff));
    signed_diffs_sum += signed_diff;
  }

  std::sort(results.begin(), results.end());
  std::sort(signed_diffs.begin(), signed_diffs.end());
  std::sort(unsigned_diffs.begin(), unsigned_diffs.end());

  const size_t middle = count / 2;

  stats->count = count;
  stats->med_val = results[middle];
  stats->mean_signed_diff = float(signed_diffs_sum) / count;
  stats->med_signed_diff = signed_diffs[middle];
  stats->med_unsigned_diff = unsigned_diffs[middle];
  stats->max_unsigned_diff = unsigned_diffs.back();

  // Size 9 for 9 different POT values: 2^0, ..., 2^8
  stats->count_diff_by_pot_slice.resize(9);
  auto cur = unsigned_diffs.begin();
  size_t checksum = 0;
  for (int exponent = 0; exponent < 9; exponent++) {
    int pot = 1 << exponent;
    auto next = std::lower_bound(cur, unsigned_diffs.end(), pot);
    checksum += stats->count_diff_by_pot_slice[exponent] = next - cur;
    cur = next;
  }
  assert(checksum == count);
}

struct ResultStatsBounds {
  ResultStatsBounds()
      : mean_signed_diff(0),
        med_signed_diff(0),
        med_unsigned_diff(0),
        max_unsigned_diff(0) {}

  float mean_signed_diff;
  int med_signed_diff;
  int med_unsigned_diff;
  int max_unsigned_diff;
};

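// Note: default-constructed bounds are all zero, so checking against them
// requires the actual results to match the reference results exactly.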
bool CheckResultStatsBounds(const ResultStats& stats,
                            const ResultStatsBounds& bounds) {
  return stats.max_unsigned_diff <= bounds.max_unsigned_diff &&
         stats.med_unsigned_diff <= bounds.med_unsigned_diff &&
         std::abs(stats.med_signed_diff) <= bounds.med_signed_diff &&
         std::abs(stats.mean_signed_diff) <= bounds.mean_signed_diff;
}

void ReportResultStats(const ResultStats& stats,
                       const ResultStatsBounds& bounds) {
  printf("    number of matrix entries: %d\n", stats.count);
  printf("    median value: %d\n", stats.med_val);
  printf("    median unsigned diff: %d (tolerating %d)\n",
         stats.med_unsigned_diff, bounds.med_unsigned_diff);
  printf("    max unsigned diff: %d (tolerating %d)\n", stats.max_unsigned_diff,
         bounds.max_unsigned_diff);
  printf("    median signed diff: %d (tolerating %d)\n", stats.med_signed_diff,
         bounds.med_signed_diff);
  printf("    mean signed diff: %.3g (tolerating %.3g)\n",
         stats.mean_signed_diff, bounds.mean_signed_diff);

  printf("No error: %.2f %% of entries\n",
         100.f * stats.count_diff_by_pot_slice[0] / stats.count);
  for (int exponent = 1; exponent < 9; exponent++) {
    printf("Error in %d..%d range: %.2f %% of entries\n", 1 << (exponent - 1),
           (1 << exponent) - 1,
           100.f * stats.count_diff_by_pot_slice[exponent] / stats.count);
  }
}

// Our approach to choosing result_shift values for testing is bisection.
// This function takes an interval, [result_shift_min .. result_shift_max].
// If too much saturation occurred in either direction, it bisects accordingly,
// recursing until the interval contains only one value.
// The primary reason why we prefer this over computing optimal shift values
// is that we actually want to exercise some saturation, as there is nontrivial
// code handling that in gemmlowp.
// Secondarily, this is faster than computing optimal shifts, since in 90% of
// cases the first-tried shift value 16 turns out to be good enough.
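// For example, tracing the logic below: the top-level call passes [0 .. 32],
// so the first shift tried is (0 + 32) / 2 = 16. If the median result value
// then lands below 32 (too much saturation towards 0), the recursion narrows
// the interval to [0 .. 16]; if it lands above 224 (too much saturation
// towards 255), to [16 .. 32].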
template <typename GemmWrapper, typename LhsType, typename RhsType,
          typename ResultType>
void test_gemm_impl(typename GemmWrapper::Context* context, const LhsType& lhs,
                    const RhsType& rhs, ResultType* result, int lhs_offset,
                    int rhs_offset, int result_offset, int result_mult_int,
                    int result_shift_min, int result_shift_max) {
  const int rows = lhs.rows();
  const int cols = rhs.cols();
  Check(lhs.cols() == rhs.rows());
  const int depth = lhs.cols();

  const int result_shift = (result_shift_min + result_shift_max) / 2;

  if (!GemmWrapper::Gemm(context, lhs.const_map(), rhs.const_map(),
                         &result->map(), lhs_offset, rhs_offset, result_offset,
                         result_mult_int, result_shift)) {
    // Internal GEMM functions are not required to handle all cases
    // (e.g. rows < cols) as these are supposed to have been handled
    // ahead of them. Their test wrappers return false in that case.
    return;
  }

  typedef typename ResultType::Scalar Scalar;
  static const MapOrder kLhsOrder = LhsType::kOrder;
  static const MapOrder kRhsOrder = RhsType::kOrder;
  static const MapOrder kResultOrder = ResultType::kOrder;
  ResultType ref_result(rows, cols);
  const bool transpose_c = kResultOrder == MapOrder::RowMajor;
  const bool transpose_a = kLhsOrder == MapOrder::RowMajor;
  const bool transpose_b = kRhsOrder == MapOrder::RowMajor;
  ReferenceEightBitIntGemmWrapper<Scalar>::Gemm(
      transpose_a, transpose_b, transpose_c, lhs.const_map(), rhs.const_map(),
      &ref_result.map(), lhs_offset, rhs_offset, result_offset, result_mult_int,
      result_shift);

  typedef typename GemmWrapper::BitDepthParams BitDepthParams;

  ResultStats stats;
  GetResultStats(result->data(), ref_result.data(), rows * cols, &stats);

  // Adjust shifts until we get meaningful results
  int new_result_shift_min = result_shift_min;
  int new_result_shift_max = result_shift_max;
  bool retry = false;

  if (stats.med_val < 32) {
    new_result_shift_max = (result_shift_min + result_shift_max) / 2;
    retry = true;
  }

  if (stats.med_val > 224) {
    new_result_shift_min = (result_shift_min + result_shift_max) / 2;
    retry = true;
  }

  if (retry) {
    if (result_shift_min != result_shift_max) {
      test_gemm_impl<GemmWrapper>(context, lhs, rhs, result, lhs_offset,
                                  rhs_offset, result_offset, result_mult_int,
                                  new_result_shift_min, new_result_shift_max);
    }
    return;
  }

  ResultStatsBounds bounds;

  // Check results
  const bool good = CheckResultStatsBounds(stats, bounds);

  printf(
      "%s: %dx%dx%d %s x %s -> %s, %s, offsets %d/%d/%d, mult %d, shift %d\n",
      good ? "PASS" : "FAIL", rows, depth, cols, OrderName(kLhsOrder),
      OrderName(kRhsOrder), OrderName(kResultOrder), GemmWrapper::Name(),
      lhs_offset, rhs_offset, result_offset, result_mult_int, result_shift);

  if (!good) {
    ReportResultStats(stats, bounds);

    int bad_coeffs_printed = 0;
    for (int c = 0; c < result->cols() && bad_coeffs_printed < 200; c++) {
      for (int r = 0; r < result->rows() && bad_coeffs_printed < 200; r++) {
        if (ref_result(r, c) != (*result)(r, c)) {
          printf("bad coeff: at (%d, %d), expected %d, got %d\n", r, c,
                 ref_result(r, c), (*result)(r, c));
          bad_coeffs_printed++;
        }
      }
    }
  }

  Check(good);
}

template <typename GemmWrapper, typename LhsType, typename RhsType,
          typename ResultType>
void test_gemm(typename GemmWrapper::Context* context, const LhsType& lhs,
               const RhsType& rhs, ResultType* result, int lhs_offset,
               int rhs_offset, int result_offset, int result_mult_int) {
  test_gemm_impl<GemmWrapper>(context, lhs, rhs, result, lhs_offset, rhs_offset,
                              result_offset, result_mult_int, 0, 32);
}

enum class WhatParamsToTest {
  All,
  OnlyGenericCase,
};

template <typename GemmWrapper, MapOrder LhsOrder, MapOrder RhsOrder,
          MapOrder ResultOrder>
void test_gemm(typename GemmWrapper::Context* context, int rows, int depth,
               int cols, WhatParamsToTest params_to_test) {
  typedef std::uint8_t Scalar;
  typedef Matrix<Scalar, LhsOrder> LhsType;
  using BitDepthParams = typename GemmWrapper::BitDepthParams;
  LhsType lhs(rows, depth);
  MakeRandom<typename BitDepthParams::LhsRange>(&lhs);
  typedef Matrix<Scalar, RhsOrder> RhsType;
  RhsType rhs(depth, cols);
  MakeRandom<typename BitDepthParams::RhsRange>(&rhs);
  typedef Matrix<Scalar, ResultOrder> ResultType;
  ResultType result(rows, cols);
  MakeZero(&result);

  if (params_to_test == WhatParamsToTest::All) {
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 0, 0, 1);
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 10, 0, 0, 1);
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 10, 0, 1);
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 0, 10, 1);
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 0, 0, 10);
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 10, 10, 10, 10);
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 256, 1, 17, 4);
  }
  test_gemm<GemmWrapper>(context, lhs, rhs, &result, -75, -91, 74980, 123);
}

// OnlyRCC tests only the RowMajor lhs x ColMajor rhs -> ColMajor result
// storage-order combination, which is the one exercised most below.
enum class WhatOrdersToTest { All, OnlyRCC };

template <typename GemmWrapper>
void test_gemm(typename GemmWrapper::Context* context, int rows, int depth,
               int cols, WhatParamsToTest params_to_test,
               WhatOrdersToTest orders_to_test) {
#define GEMMLOWP_ONE_TEST(LhsOrder, RhsOrder, ResultOrder)         \
  do {                                                             \
    test_gemm<GemmWrapper, MapOrder::LhsOrder, MapOrder::RhsOrder, \
              MapOrder::ResultOrder>(context, rows, depth, cols,   \
                                     params_to_test);              \
  } while (false)

  if (orders_to_test == WhatOrdersToTest::All) {
    GEMMLOWP_ONE_TEST(ColMajor, ColMajor, ColMajor);
    GEMMLOWP_ONE_TEST(RowMajor, ColMajor, ColMajor);
    GEMMLOWP_ONE_TEST(ColMajor, RowMajor, ColMajor);
    GEMMLOWP_ONE_TEST(RowMajor, RowMajor, ColMajor);

    GEMMLOWP_ONE_TEST(ColMajor, ColMajor, RowMajor);
    GEMMLOWP_ONE_TEST(RowMajor, ColMajor, RowMajor);
    GEMMLOWP_ONE_TEST(ColMajor, RowMajor, RowMajor);
    GEMMLOWP_ONE_TEST(RowMajor, RowMajor, RowMajor);
  } else {
    GEMMLOWP_ONE_TEST(RowMajor, ColMajor, ColMajor);
  }

#undef GEMMLOWP_ONE_TEST
}

template <typename Kernel>
void test_gemm_kernel(MultiThreadGemmContext* context) {
  typedef MultiThreadGemmWrapper<Kernel, std::uint8_t,
                                 DefaultL8R8BitDepthParams>
      GemmWrapper;
  test_gemm<GemmWrapper>(context, 1, 1, 1, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 2, 2, 2, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 3, 3, 3, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 4, 4, 4, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 5, 5, 5, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 9, 11, 13, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 50, 50, 50, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 200, 200, 200,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::All);
  test_gemm<GemmWrapper>(context, 50, 5000, 50,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
}

template <typename GemmWrapper>
void test_gemm(typename GemmWrapper::Context* context) {
  test_gemm<GemmWrapper>(context, 1, 1, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 2, 1, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1, 2, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1, 1, 2, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 2, 2, 2, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 3, 3, 3, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 4, 4, 4, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 5, 5, 5, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 6, 6, 6, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 3, 5, 7, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 7, 3, 5, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 5, 7, 3, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 8, 8, 8, WhatParamsToTest::All,
                         WhatOrdersToTest::All);
  test_gemm<GemmWrapper>(context, 16, 16, 16, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 32, 32, 32, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 64, 64, 64, WhatParamsToTest::All,
                         WhatOrdersToTest::All);
  test_gemm<GemmWrapper>(context, 128, 128, 128, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);

  test_gemm<GemmWrapper>(context, 16, 17, 16, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 37, 55, 73, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 57, 87, 117, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 93, 83, 73, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 109, 89, 99, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 78, 101, 82, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);

  test_gemm<GemmWrapper>(context, 512, 512, 512,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1024, 1024, 1024,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 567, 2345, 123,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 100, 5000, 100,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1, 1, 1000, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1000, 1, 1, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1, 1000, 1, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1, 1000, 1000,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1000, 1, 1000,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1000, 1000, 1,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 777, 3456, 1,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 4567, 555, 1,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);

  // Test all storage orders
  test_gemm<GemmWrapper>(context, 70, 90, 110, WhatParamsToTest::All,
                         WhatOrdersToTest::All);
  test_gemm<GemmWrapper>(context, 300, 400, 500,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::All);
}

template <typename GemmWrapper>
void test_gemv(typename GemmWrapper::Context* context) {
  test_gemm<GemmWrapper>(context, 2, 2, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 3, 3, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 4, 4, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 5, 5, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 6, 6, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 3, 5, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 7, 3, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 5, 7, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 8, 8, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 32, 32, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 128, 128, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 321, 123, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);

  // Test all storage orders
  test_gemm<GemmWrapper>(context, 70, 90, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::All);
  test_gemm<GemmWrapper>(context, 300, 400, 1,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::All);
}

const char* GetBitDepthName(eight_bit_int_gemm::BitDepthSetting b) {
  switch (b) {
    case eight_bit_int_gemm::BitDepthSetting::A8B8:
      return "Lhs: 8 bit, Rhs: 8 bit";
    case eight_bit_int_gemm::BitDepthSetting::A5B7:
      return "(legacy, no longer requantizing) Lhs: 7 bit, Rhs: 5 bit";
    default:
      abort();
      return nullptr;
  }
}

// Runs a small set of hand-picked per-channel quantized data.
// This test case comes from a set of 2 2x2 convolution filters run over a 3x3
// image.
void TestWithSmallDataPerChannelQuantization() {
  const int m = 2;
  const int n = 9;
  const int k = 12;

  // 12 x 2, columnwise.
  const std::uint8_t a_data[] = {0,  0,   0,   0,   0,  0,   0,   0,
                                 0,  255, 255, 255, 64, 64,  64,  64,
                                 64, 64,  0,   0,   0,  255, 255, 255};
  const int lda = k;
  int a_offset[] = {0, -64};
  MatrixMap<const std::uint8_t, MapOrder::RowMajor> lhs(a_data, m, k, lda);
  const OffsetColMap lhs_offset(a_offset, m);

  // 12 x 9, columnwise.
  const std::uint8_t b_data[] = {
      0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,
      0,   0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,   127,
      127, 127, 0,   0,   0,   127, 127, 127, 0,   0,   0,   255, 255, 255,
      0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   127, 127, 127, 0,   0,   0,   127,
      127, 127, 0,   0,   0,   0,   0,   0,   127, 127, 127, 127, 127, 127,
      0,   0,   0,   0,   0,   0,   127, 127, 127, 127, 127, 127, 0,   0,
      0,   127, 127, 127, 127, 127, 127, 127, 127, 127};
  const int ldb = k;
  int b_offset = -127;
  MatrixMap<const std::uint8_t, MapOrder::ColMajor> rhs(b_data, k, n, ldb);
  const OffsetRowDup rhs_offset(b_offset, rhs.cols());

  // 2 x 9, columnwise.
  const std::uint8_t expected_c_data[] = {255, 255, 0,   0,   127, 159,
                                          0,   64,  0,   64,  127, 159,
                                          127, 127, 127, 127, 127, 127};
  const int ldc = m;
  int c_offset[] = {97155, 97346};
  int c_mult_int[] = {2741, 2741};
  const int c_shift = 21;

  const int c_count = m * n;
  std::unique_ptr<std::uint8_t[]> output_data(new std::uint8_t[c_count]);
  MatrixMap<std::uint8_t, MapOrder::ColMajor> result(output_data.get(), m, n,
                                                     ldc);
  const OffsetColMap result_offset(c_offset, m);
  const OffsetColMap result_mult_int(c_mult_int, m);
  const int result_shift = c_shift;

  GemmContext gemm_context;
  auto output_pipeline = MakeStandardOutputPipeline<VectorShape::Col>(
      result_offset, result_mult_int, result_shift);
  GemmWithOutputPipelinePC<std::uint8_t, std::uint8_t,
                           DefaultL8R8BitDepthParams>(
      &gemm_context, lhs, rhs, &result, lhs_offset, rhs_offset,
      output_pipeline);

  ResultStats stats;
  GetResultStats(output_data.get(), expected_c_data, c_count, &stats);

  ResultStatsBounds bounds;
  const bool good = CheckResultStatsBounds(stats, bounds);
  printf("TestWithSmallDataPerChannelQuantization: %s\n",
         good ? "PASS" : "FAIL");
  ReportResultStats(stats, bounds);
  Check(good);
}

// Runs a larger set of hand-picked per-channel quantized data.
// This test case comes from a set of 22 3x3 convolution filters run over a 5x5
// image.  Right now, I have 7 different filters and 15 copies of the first
// filter to make sure the NEON code path that processes 16 rows at a time is
// covered.
void TestWithLargeDataPerChannelQuantization() {
  const int m = 22;
  const int n = 25;
  const int k = 27;

  // 27 x 22, column-wise.
  const std::uint8_t a_data[] = {
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   127, 127, 127, 255, 255, 255, 127, 127, 127,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   127, 127, 127,
      0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,
      127, 127, 127, 0,   0,   0,   51,  51,  51,  51,  51,  51,  51,  51,  51,
      0,   0,   0,   255, 255, 255, 0,   0,   0,   51,  51,  51,  51,  51,  51,
      51,  51,  51,  51,  51,  51,  0,   0,   0,   51,  51,  51,  51,  51,  51,
      255, 255, 255, 51,  51,  51,  51,  51,  51,  0,   0,   0,   51,  51,  51,
      0,   0,   0,   64,  64,  64,  0,   0,   0,   64,  64,  64,  255, 255, 255,
      64,  64,  64,  0,   0,   0,   64,  64,  64,  0,   0,   0,   36,  36,  36,
      0,   0,   0,   36,  36,  36,  0,   0,   0,   255, 255, 255, 0,   0,   0,
      36,  36,  36,  0,   0,   0,   36,  36,  36,  0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      255, 255, 255, 0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      255, 255, 255, 0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      255, 255, 255, 0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,
  };
  const int lda = k;
  int a_offset[] = {0, 0, 0, -51, -51, 0, -36, 0, 0, 0, 0,
                    0, 0, 0, 0,   0,   0, 0,   0, 0, 0, 0};
  MatrixMap<const std::uint8_t, MapOrder::RowMajor> lhs(a_data, m, k, lda);
  const OffsetColMap lhs_offset(a_offset, m);

  // 27 x 25, column-wise.
  const std::uint8_t b_data[] = {
      127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 119, 119,
      119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127,
      127, 127, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127,
      127, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127,
      127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127, 127,
      127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127,
      119, 119, 119, 119, 119, 119, 127, 127, 127, 127, 127, 127, 119, 119,
      119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127,
      127, 127, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 136, 136, 136, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      136, 136, 136, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 136, 136, 136, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127,
      119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119,
      119, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127,
      127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 136, 136, 136, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      136, 136, 136, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 136, 136, 136, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119,
      119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 127,
      127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119,
      119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 136, 136, 136, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      136, 136, 136, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 136, 136, 136, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119,
      119, 119, 119, 119, 119, 127, 127, 127, 127, 127, 127, 119, 119, 119,
      119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127,
      127, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127,
      127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127, 127,
      127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127, 127, 127,
      127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119,
      119, 119, 119, 119, 119, 127, 127, 127, 127, 127, 127, 127, 127, 127,
      127, 127, 127};
  const int ldb = k;
  int b_offset = -127;
  MatrixMap<const std::uint8_t, MapOrder::ColMajor> rhs(b_data, k, n, ldb);
  const OffsetRowDup rhs_offset(b_offset, rhs.cols());

  // 22 x 25, column-wise.
  const std::uint8_t expected_c_data[] = {
      7,   37,  37,  67,  67,  39,  79,  7,   7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   37,  87,  67,  23,  91,  7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   37,  87,  67,  23,  91,  7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   37,  87,  67,  23,  91,  7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   37,
      37,  67,  67,  39,  79,  7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   37,  7,   67,  87,  23,  91,  7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
      87,  87,  7,   103, 7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   71,  87,  45,  41,  77,  7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   87,
      87,  7,   103, 7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   37,  7,   67,  87,  23,  91,  7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   37,  7,   67,  87,
      23,  91,  7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   71,  7,   45,  87,  41,  77,  7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   255, 135, 135, 255, 255, 143,
      255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
      255, 7,   71,  7,   45,  87,  41,  77,  7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   37,  7,   67,  87,  23,  91,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   37,  7,   67,  87,  23,  91,  7,   7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   87,  87,  7,   103, 7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   71,  87,  45,  41,  77,  7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   87,  87,  7,   103, 7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   37,
      7,   67,  87,  23,  91,  7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   37,  37,  67,  67,  39,  79,  7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   37,
      87,  67,  23,  91,  7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   37,  87,  67,  23,  91,  7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   37,  87,
      67,  23,  91,  7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   37,  37,  67,  67,  39,  79,  7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   99,  99,  99,  99,  99,
      99,  99,  99,  99,  99,  99,  99,  99,  99,  99,  99,  99,  99,  99,  99,
      99,  99,  111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111,
      111, 111, 111, 111, 111, 111, 111, 111, 111,
  };
  const int ldc = m;
  int c_offset[] = {
      6477, 12954, 12954, 7793, 7793, 12954, 9282, 6477, 6477, 6477, 6477,
      6477, 6477,  6477,  6477, 6477, 6477,  6477, 6477, 6477, 6477, 6477,
  };
  int c_mult_int[] = {
      41121, 20560, 20560, 34267, 34267, 21937, 28784, 41121,
      41121, 41121, 41121, 41121, 41121, 41121, 41121, 41121,
      41121, 41121, 41121, 41121, 41121, 41121,
  };
  const int c_shift = 21;

  const int c_count = m * n;
  std::unique_ptr<std::uint8_t[]> output_data(new std::uint8_t[c_count]);
  MatrixMap<std::uint8_t, MapOrder::ColMajor> result(output_data.get(), m, n,
                                                     ldc);
  const OffsetColMap result_offset(c_offset, m);
  const OffsetColMap result_mult_int(c_mult_int, m);
  const int result_shift = c_shift;

  GemmContext gemm_context;
  auto output_pipeline = MakeStandardOutputPipeline<VectorShape::Col>(
      result_offset, result_mult_int, result_shift);
  GemmWithOutputPipelinePC<std::uint8_t, std::uint8_t,
                           DefaultL8R8BitDepthParams>(
      &gemm_context, lhs, rhs, &result, lhs_offset, rhs_offset,
      output_pipeline);

  ResultStats stats;
  GetResultStats(output_data.get(), expected_c_data, c_count, &stats);

  ResultStatsBounds bounds;
  const bool good = CheckResultStatsBounds(stats, bounds);
  printf("TestWithLargeDataPerChannelQuantization: %s\n",
         good ? "PASS" : "FAIL");
  ReportResultStats(stats, bounds);
  Check(good);
}

// Multithreading only activates when the result has more than 16 rows, and
// also (result rows) * (result cols) * depth >= 2 x 65 x 1024.  Size was
// selected to run in 3 threads.
//
// Based on the following floating point data:
//   LHS: all zeros except 10.0, 20.0 at the beginning of first 16 rows;
//     1.0, 2.0 at the beginning of next 16 rows; 0.1, 0.2 in next 16 rows;
//     0.01, 0.02 in last 16 rows.
//   RHS: all zeros except 1.0 in (0, 0) and 2.0 in (1, 0).
//   Varying boundaries were used for each 16 rows block of LHS, to test for
//     correct indexing into offsets.
//   Expected result: all zeros, except 50.0 at the beginning of first 16 rows;
//     5.0 at the beginning of next 16 rows; 0.5 in next 16 rows; 0.05 in last
//     16 rows.
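// (In quantized terms: within each 16-row block, the two leading LHS entries
// encode the 1:2 ratio of the float pair once the per-row offset is added,
// e.g. for the second block (153 - 51) : (255 - 51) = 102 : 204.)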
void TestMultithreadedPerChannelQuantization() {
  const int m = 64;
  const int n = 20;
  const int k = 160;

  // LHS, m x k.
  const std::array<std::int32_t, 4> lhs_offsets_terse{{
      0, -51, -85, -109,
  }};
  assert(lhs_offsets_terse.size() * 16 == m);
  const std::array<std::uint8_t, 4> lhs_first_el{{
      128, 153, 170, 182,
  }};
  assert(lhs_first_el.size() * 16 == m);

  // lhs_first_el at (i, 0) and 255 at (i, 1), other values are all -offset.
  std::vector<std::uint8_t> a_data(m * k, 0);
  for (int i = 0; i < m; ++i) {
    a_data[i * k] = lhs_first_el[i / 16];
    a_data[i * k + 1] = 255;
    for (int j = 2; j < k; ++j) {
      a_data[i * k + j] = std::uint8_t(-lhs_offsets_terse[i / 16]);
    }
  }

  const int lda = k;
  // Given values at [i / 16].
  std::vector<std::int32_t> a_offset(m, 0);
  for (int i = 0; i < m; ++i) {
    a_offset[i] = lhs_offsets_terse[i / 16];
  }

  MatrixMap<const std::uint8_t, MapOrder::RowMajor> lhs(&a_data[0], m, k, lda);
  const OffsetColMap lhs_offset(&a_offset[0], m);

  // RHS, k x n.
  // All zeros, except 128 at (0, 0) and 255 at (1, 0).
  std::vector<std::uint8_t> b_data(k * n, 0);
  b_data[0] = 128;
  b_data[1] = 255;

  const int ldb = k;
  std::int32_t b_offset = 0;
  MatrixMap<const std::uint8_t, MapOrder::ColMajor> rhs(&b_data[0], k, n, ldb);
  const OffsetRowDup rhs_offset(b_offset, rhs.cols());

  // Result, m x n.
  // All zeros, except given values at (i / 16, 0).
  const std::array<std::uint8_t, 4> expected_c_terse{{
      142, 159, 182, 213,
  }};
  assert(expected_c_terse.size() * 16 == m);
  std::vector<std::uint8_t> expected_c_data(m * n, 0);
  for (int i = 0; i < m; ++i) {
    expected_c_data[i] = expected_c_terse[i / 16];
  }

  const int ldc = m;
  // All zeros.
  std::vector<std::int32_t> c_offset(m, 0);
  // Given values at [i / 16].
  const std::array<std::int32_t, 4> c_mult_int_terse{{
      3655, 5140, 7049, 9595,
  }};
  assert(c_mult_int_terse.size() * 16 == m);
  std::vector<std::int32_t> c_mult_int(m);
  for (int i = 0; i < m; ++i) {
    c_mult_int[i] = c_mult_int_terse[i / 16];
  }

  const int c_shift = 21;

  const int c_count = m * n;
  std::unique_ptr<std::uint8_t[]> output_data(new std::uint8_t[c_count]);
  MatrixMap<std::uint8_t, MapOrder::ColMajor> result(output_data.get(), m, n,
                                                     ldc);
  const OffsetColMap result_offset(&c_offset[0], m);
  const OffsetColMap result_mult_int(&c_mult_int[0], m);
  const int result_shift = c_shift;

  GemmContext gemm_context;
  auto output_pipeline = MakeStandardOutputPipeline<VectorShape::Col>(
      result_offset, result_mult_int, result_shift);
  GemmWithOutputPipelinePC<std::uint8_t, std::uint8_t,
                           DefaultL8R8BitDepthParams>(
      &gemm_context, lhs, rhs, &result, lhs_offset, rhs_offset,
      output_pipeline);

  ResultStats stats;
  GetResultStats(output_data.get(), &expected_c_data[0], c_count, &stats);

  ResultStatsBounds bounds;
  const bool good = CheckResultStatsBounds(stats, bounds);
  printf("TestMultithreadedPerChannelQuantization: %s\n",
         good ? "PASS" : "FAIL");
  ReportResultStats(stats, bounds);
  Check(good);
}

1108 // Runs a small set of hand-calculated data through the implementation.
TestWithSmallData()1109 void TestWithSmallData() {
1110   const int m = 4;
1111   const int n = 2;
1112   const int k = 3;
1113   // Matrix A (LHS) is:
1114   // |  7 | 10 | 13 | 16 |
1115   // |  8 | 11 | 14 | 17 |
1116   // |  9 | 12 | 15 | 18 |
1117   const std::uint8_t a_data[] = {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18};
1118   // Matrix B (RHS) is:
1119   // |  1 |  3 |  5 |
1120   // |  2 |  4 |  6 |
1121   const std::uint8_t b_data[] = {1, 2, 3, 4, 5, 6};
1122   // Here are the results we expect, from hand calculations:
1123   // (1 * 7) + (3 * 8) + (5 * 9) = 76
1124   // (2 * 7) + (4 * 8) + (6 * 9) = 100
1125   // (1 * 10) + (3 * 11) + (5 * 12) = 103
1126   // (2 * 10) + (4 * 11) + (6 * 12) = 136
1127   // (1 * 13) + (3 * 14) + (5 * 15) = 130
1128   // (2 * 13) + (4 * 14) + (6 * 15) = 172
1129   // (1 * 16) + (3 * 17) + (5 * 18) = 157
1130   // (2 * 16) + (4 * 17) + (6 * 18) = 208
1131   // That means matrix C should be:
1132   // |  76 | 103 | 130 | 157 |
1133   // | 100 | 136 | 172 | 208 |
1134   const std::uint8_t expected_data[] = {76, 100, 103, 136, 130, 172, 157, 208};
1135 
1136   const int c_count = m * n;
1137   std::unique_ptr<std::uint8_t[]> output_data(new std::uint8_t[c_count]);
1138 
1139   const bool is_a_transposed = true;
1140   const bool is_b_transposed = true;
1141   const bool is_c_transposed = true;
1142   const int lda = k;
1143   const int ldb = n;
1144   const int ldc = n;
1145 
1146   const int a_offset = 0;
1147   const int b_offset = 0;
1148   const int c_offset = 0;
1149   const int c_mult = 1;
1150   const int c_shift = 0;
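  // With all offsets zero, c_mult = 1 and c_shift = 0, the requantization is
  // the identity, so the uint8 outputs should exactly equal the raw integer
  // dot products listed above.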

  gemmlowp::eight_bit_int_gemm::EightBitIntGemm(
      is_a_transposed, is_b_transposed, is_c_transposed, m, n, k, a_data,
      a_offset, lda, b_data, b_offset, ldb, output_data.get(), c_offset, c_mult,
      c_shift, ldc, eight_bit_int_gemm::BitDepthSetting::A8B8);

  ResultStats stats;
  GetResultStats(output_data.get(), expected_data, c_count, &stats);

  ResultStatsBounds bounds;
  const bool good = CheckResultStatsBounds(stats, bounds);
  printf("TestWithSmallData: %s\n", good ? "PASS" : "FAIL");
  ReportResultStats(stats, bounds);
  Check(good);
}

// This is the most realistic test of how we'll be using the low-precision GEMM
// function in applications. It takes in large input matrices that have been
// captured from an actual neural network run.
void TestWithRealData(eight_bit_int_gemm::BitDepthSetting BitDepth,
                      int tolerance_median, int tolerance_max) {
  std::unique_ptr<std::uint8_t[]> output_data(
      new std::uint8_t[test_data::c_count]);
  gemmlowp::eight_bit_int_gemm::EightBitIntGemm(
      test_data::is_a_transposed, test_data::is_b_transposed,
      test_data::is_c_transposed, test_data::m, test_data::n, test_data::k,
      test_data::a_data, test_data::a_offset, test_data::k, test_data::b_data,
      test_data::b_offset, test_data::k, output_data.get(), test_data::c_offset,
      test_data::c_mult_int, test_data::c_shift, test_data::m, BitDepth);

  ResultStats stats;
  GetResultStats(output_data.get(), test_data::expected_c_data,
                 test_data::c_count, &stats);

  ResultStatsBounds bounds;
  if (BitDepth == eight_bit_int_gemm::BitDepthSetting::A5B7) {
    bounds.med_unsigned_diff = tolerance_median;
    bounds.max_unsigned_diff = tolerance_max;
    bounds.med_signed_diff = 0;
    bounds.mean_signed_diff = 0.2f;
  }
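  // For A8B8 the default all-zero bounds are kept, i.e. the output must
  // match the captured expected data exactly; A5B7 requantizes the inputs to
  // fewer bits, so the nonzero tolerances above apply.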

  const bool good = CheckResultStatsBounds(stats, bounds);
  printf("TestWithRealData: %s with %s\n", good ? "PASS" : "FAIL",
         GetBitDepthName(BitDepth));
  ReportResultStats(stats, bounds);
  Check(good);
}

template <typename BitDepthParams, MapOrder ResultOrder>
void TestOutputStages(int rows, int depth, int cols, int result_offset,
                      int result_mult_int, int result_shift) {
  Matrix<std::uint8_t, MapOrder::RowMajor> lhs(rows, depth);
  Matrix<std::uint8_t, MapOrder::ColMajor> rhs(depth, cols);
  Matrix<std::int32_t, ResultOrder> result_raw_int32(rows, cols);
  MakeRandom<typename BitDepthParams::LhsRange>(&lhs);
  MakeRandom<typename BitDepthParams::RhsRange>(&rhs);
  const int lhs_offset = 12;
  const int rhs_offset = -34;

  // Test an empty pipeline, i.e. returning raw int32 accumulators.
  auto empty_pipeline = std::make_tuple();
  GemmContext context;
  GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(), &result_raw_int32, lhs_offset,
      rhs_offset, empty_pipeline);

  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t expected = 0;
      for (int d = 0; d < depth; d++) {
        std::int32_t lhs_val =
            static_cast<std::int32_t>(lhs(r, d)) + lhs_offset;
        std::int32_t rhs_val =
            static_cast<std::int32_t>(rhs(d, c)) + rhs_offset;
        expected += lhs_val * rhs_val;
      }
      Check(expected == result_raw_int32(r, c));
    }
  }

  // Test a pipeline with only the quantize-down stage, still returning
  // unclamped (but scaled) int32 values.
  OutputStageQuantizeDownInt32ToUint8Scale quantize_down_stage;
  quantize_down_stage.result_offset = result_offset;
  quantize_down_stage.result_mult_int = result_mult_int;
  quantize_down_stage.result_shift = result_shift;
  auto quantize_down_pipeline = std::make_tuple(quantize_down_stage);
  Matrix<std::int32_t, ResultOrder> result_quantized_down_int32(rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(), &result_quantized_down_int32,
      lhs_offset, rhs_offset, quantize_down_pipeline);

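  // The quantize-down stage maps each raw accumulator to
  //   (raw + result_offset) * result_mult_int / 2^result_shift
  // rounded to nearest (RoundingDivideByPOT below). E.g. with the parameters
  // used by the TestOutputStages() calls at the bottom of this file
  // (offset 5, multiplier 17, shift 14), raw = 10000 yields
  // (10005 * 17 + 8192) >> 14 = 10.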
  std::int64_t sum = 0;
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t raw = result_raw_int32(r, c);
      std::int32_t expected = RoundingDivideByPOT(
          (raw + result_offset) * result_mult_int, result_shift);
      Check(expected == result_quantized_down_int32(r, c));
      sum += expected;
    }
  }
  std::int64_t avg = sum / (rows * cols);
  // Test that the average quantized-down value falls reasonably in the
  // middle of the [0..255] range. Otherwise, the multiplier / shift need to be
  // adjusted.
  Check(avg >= 64 && avg <= 192);

  // Test the familiar default pipeline consisting of quantize-down and
  // clamp-and-cast-to-uint8.
  OutputStageSaturatingCastToUint8 saturating_cast_stage;
  auto quantize_down_and_saturating_cast_pipeline =
      std::make_tuple(quantize_down_stage, saturating_cast_stage);
  Matrix<std::uint8_t, ResultOrder> result_quantized_down_saturated_uint8(rows,
                                                                          cols);
  GemmWithOutputPipeline<std::uint8_t, std::uint8_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(),
      &result_quantized_down_saturated_uint8, lhs_offset, rhs_offset,
      quantize_down_and_saturating_cast_pipeline);

  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t quantized = result_quantized_down_int32(r, c);
      std::uint8_t expected = std::min(std::max(quantized, 0), 255);
      Check(expected == result_quantized_down_saturated_uint8(r, c));
    }
  }

  // Test a variant of the familiar default pipeline consisting of
  // quantize-down and clamp-and-cast-to-int16.
  OutputStageSaturatingCastToInt16 saturating_cast_int16_stage;
  auto quantize_down_and_saturating_cast_int16_pipeline =
      std::make_tuple(quantize_down_stage, saturating_cast_int16_stage);
  Matrix<std::int16_t, ResultOrder> result_quantized_down_saturated_int16(rows,
                                                                          cols);
  GemmWithOutputPipeline<std::uint8_t, std::int16_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(),
      &result_quantized_down_saturated_int16, lhs_offset, rhs_offset,
      quantize_down_and_saturating_cast_int16_pipeline);

  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t quantized = result_quantized_down_int32(r, c);
      std::int16_t expected = std::min(std::max(quantized, -32768), 32767);
      Check(expected == result_quantized_down_saturated_int16(r, c));
    }
  }

#ifdef GEMMLOWP_MSA
  // Test a pipeline consisting of quantize-down and truncating-cast-to-uint8.
  OutputStageTruncatingCastToUint8 truncating_cast_stage;
  auto quantize_down_and_truncating_cast_pipeline =
      std::make_tuple(quantize_down_stage, truncating_cast_stage);
  Matrix<std::uint8_t, ResultOrder> result_quantized_down_truncated_uint8(
      rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::uint8_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(),
      &result_quantized_down_truncated_uint8, lhs_offset, rhs_offset,
      quantize_down_and_truncating_cast_pipeline);

  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t quantized = result_quantized_down_int32(r, c);
      std::uint8_t expected = quantized & 255;
      Check(expected == result_quantized_down_truncated_uint8(r, c));
    }
  }
#endif

  // Test bias-addition with a row vector.
  std::vector<std::int32_t> row_vector_data(cols);
  std::uniform_int_distribution<std::int32_t> uniform_minus_500_plus_500(-500,
                                                                         500);
  for (int i = 0; i < cols; i++) {
    row_vector_data[i] = uniform_minus_500_plus_500(RandomEngine());
  }
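  // A Row-shaped bias vector has one entry per column: the stage below adds
  // bias[c] to every entry of column c. The Col-shaped variant further down
  // symmetrically adds bias[r] to every entry of row r.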
  typedef VectorMap<std::int32_t, VectorShape::Row> RowVectorMap;
  RowVectorMap row_vector_map(row_vector_data.data(), cols);
  OutputStageBiasAddition<RowVectorMap> row_bias_addition_stage;
  row_bias_addition_stage.bias_vector = row_vector_map;
  auto row_bias_addition_pipeline = std::make_tuple(row_bias_addition_stage);
  Matrix<std::int32_t, ResultOrder> result_of_row_bias_addition(rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(), &result_of_row_bias_addition,
      lhs_offset, rhs_offset, row_bias_addition_pipeline);
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t expected = result_raw_int32(r, c) + row_vector_data[c];
      Check(expected == result_of_row_bias_addition(r, c));
    }
  }

  // Test bias-addition with a column vector.
  std::vector<std::int32_t> col_vector_data(rows);
  for (int i = 0; i < rows; i++) {
    col_vector_data[i] = uniform_minus_500_plus_500(RandomEngine());
  }
  typedef VectorMap<std::int32_t, VectorShape::Col> ColVectorMap;
  ColVectorMap col_vector_map(col_vector_data.data(), rows);
  OutputStageBiasAddition<ColVectorMap> col_bias_addition_stage;
  col_bias_addition_stage.bias_vector = col_vector_map;
  auto col_bias_addition_pipeline = std::make_tuple(col_bias_addition_stage);
  Matrix<std::int32_t, ResultOrder> result_of_col_bias_addition(rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(), &result_of_col_bias_addition,
      lhs_offset, rhs_offset, col_bias_addition_pipeline);
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t expected = result_raw_int32(r, c) + col_vector_data[r];
      Check(expected == result_of_col_bias_addition(r, c));
    }
  }

  // Test a clamp.
  OutputStageClamp clamp_stage;
  // Determine min and max of the raw int32 accumulators.
  std::int32_t raw_min = std::numeric_limits<std::int32_t>::max();
  std::int32_t raw_max = std::numeric_limits<std::int32_t>::min();
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      raw_min = std::min(raw_min, result_raw_int32(r, c));
      raw_max = std::max(raw_max, result_raw_int32(r, c));
    }
  }
  // Pick some interesting clamp min/max bounds inside the raw range.
  clamp_stage.min = static_cast<std::int32_t>(raw_min * 0.7 + raw_max * 0.3);
  clamp_stage.max = static_cast<std::int32_t>(raw_min * 0.3 + raw_max * 0.7);
  assert(raw_min <= clamp_stage.min && clamp_stage.min <= clamp_stage.max &&
         clamp_stage.max <= raw_max);
  auto clamp_pipeline = std::make_tuple(clamp_stage);
  Matrix<std::int32_t, ResultOrder> result_clamped(rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(), &result_clamped, lhs_offset,
      rhs_offset, clamp_pipeline);
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t raw = result_raw_int32(r, c);
      std::int32_t expected =
          std::min(std::max(raw, clamp_stage.min), clamp_stage.max);
      Check(expected == result_clamped(r, c));
    }
  }

  // Test tanh.
  OutputStageTanh tanh_stage;
  const std::int32_t real_zero_as_int32 = (raw_max + raw_min) / 2;
  const std::int32_t real_amplitude_as_int32 = (raw_max - raw_min) / 16;
  tanh_stage.real_zero_as_int32 = real_zero_as_int32;
  tanh_stage.real_amplitude_as_int32 = real_amplitude_as_int32;
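  // OutputStageTanh interprets an int32 value v as the real number
  //   x = (v - real_zero_as_int32) / real_amplitude_as_int32
  // and produces tanh(x) in the same fixed-point encoding; the loop below
  // checks it against std::tanh to within 2e-4 in real-valued units.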
  auto tanh_pipeline = std::make_tuple(tanh_stage);
  Matrix<std::int32_t, ResultOrder> result_tanh(rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(), &result_tanh, lhs_offset,
      rhs_offset, tanh_pipeline);
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t raw = result_raw_int32(r, c);
      double real_input =
          double(raw - real_zero_as_int32) / real_amplitude_as_int32;
      double expected = std::tanh(real_input);
      std::int32_t actual_int32 = result_tanh(r, c);
      double actual =
          double(actual_int32 - real_zero_as_int32) / real_amplitude_as_int32;
      Check(std::abs(expected - actual) < 2e-4);
    }
  }

  // Test a pipeline with bias and clamp.
  auto bias_clamp_pipeline =
      std::make_tuple(col_bias_addition_stage, clamp_stage);
  Matrix<std::int32_t, ResultOrder> result_biased_clamped(rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(), &result_biased_clamped,
      lhs_offset, rhs_offset, bias_clamp_pipeline);
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t raw = result_raw_int32(r, c);
      std::int32_t biased = raw + col_vector_data[r];
      std::int32_t expected =
          std::min(std::max(biased, clamp_stage.min), clamp_stage.max);
      Check(expected == result_biased_clamped(r, c));
    }
  }

  // Test a full pipeline with bias, clamp, and quantization down to an 8-bit
  // result.
  auto bias_clamp_quantize_cast_pipeline =
      std::make_tuple(col_bias_addition_stage, clamp_stage, quantize_down_stage,
                      saturating_cast_stage);
  Matrix<std::uint8_t, ResultOrder> result_biased_clamped_quantized_casted(
      rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::uint8_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(),
      &result_biased_clamped_quantized_casted, lhs_offset, rhs_offset,
      bias_clamp_quantize_cast_pipeline);
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t quantized = RoundingDivideByPOT(
          (result_biased_clamped(r, c) + result_offset) * result_mult_int,
          result_shift);
      std::uint8_t expected = std::min(std::max(quantized, 0), 255);
      Check(expected == result_biased_clamped_quantized_casted(r, c));
    }
  }

  // Test a pipeline with the fixed-point-multiplier variant stage for the
  // quantizing down of 32-bit accumulators.
  //
  // First, figure out appropriate fixed-point multiplier and shift values:
  // normalize the multiplier into [2^30, 2^31) and fold the difference into
  // the shift.
  Check(result_mult_int > 0);
  Check(result_shift > 0);
  std::int32_t result_fixedpoint_multiplier = result_mult_int;
  std::int32_t result_fixedpoint_shift = result_shift - 31;
  while (result_fixedpoint_multiplier < (1 << 30)) {
    result_fixedpoint_multiplier <<= 1;
    result_fixedpoint_shift++;
  }
  Check(result_fixedpoint_shift >= 0);
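  // For example, with result_mult_int = 17 and result_shift = 14 (the values
  // passed by the TestOutputStages() calls at the bottom of this file), the
  // loop shifts 17 left by 26 to get 17 * 2^26 = 1140850688 >= 2^30, and
  // result_fixedpoint_shift = (14 - 31) + 26 = 9, so that, up to rounding,
  //   x * 17 / 2^14 == (x * 1140850688 / 2^31) / 2^9.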
  // Now test OutputStageQuantizeDownInt32ByFixedPoint.
  OutputStageQuantizeDownInt32ByFixedPoint quantize_down_by_fixedpoint_stage;
  quantize_down_by_fixedpoint_stage.result_offset_after_shift =
      static_cast<std::int32_t>(
          round(static_cast<double>(result_offset * result_mult_int) /
                (1 << result_shift)));
  quantize_down_by_fixedpoint_stage.result_fixedpoint_multiplier =
      result_fixedpoint_multiplier;
  quantize_down_by_fixedpoint_stage.result_shift = result_fixedpoint_shift;
  auto quantize_down_by_fixedpoint_pipeline =
      std::make_tuple(quantize_down_by_fixedpoint_stage);
  Matrix<std::int32_t, ResultOrder> result_quantized_down_by_fixedpoint_int32(
      rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(),
      &result_quantized_down_by_fixedpoint_int32, lhs_offset, rhs_offset,
      quantize_down_by_fixedpoint_pipeline);

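  // SaturatingRoundingDoublingHighMul(x, m) returns the high 32 bits of the
  // 64-bit product 2 * x * m, with rounding, i.e. approximately x * m / 2^31.
  // Combined with the rounding right shift below, the reference value is
  //   offset_after_shift + RoundingDivideByPOT(x * m / 2^31, shift),
  // which by construction approximates the integer quantize-down stage above.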
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      const std::int32_t actual =
          result_quantized_down_by_fixedpoint_int32(r, c);
      const std::int32_t raw = result_raw_int32(r, c);
      const std::int32_t expected =
          quantize_down_by_fixedpoint_stage.result_offset_after_shift +
          RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
                                  raw, result_fixedpoint_multiplier),
                              result_fixedpoint_shift);
      Check(actual == expected);
    }
  }

  // Test OutputStageScaleInt32ByFixedPointAndExponent.
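  // This stage scales by result_fixedpoint_multiplier / 2^31 * 2^exponent.
  // A positive exponent becomes a left shift applied before the fixed-point
  // multiplication and a negative one a rounding right shift applied after
  // it, exactly as mirrored by left_shift / right_shift in the check below.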
  for (int exponent = -2; exponent <= 2; exponent++) {
    OutputStageScaleInt32ByFixedPointAndExponent
        scale_by_fixedpoint_and_exponent_stage;
    scale_by_fixedpoint_and_exponent_stage.result_offset_after_shift =
        static_cast<std::int32_t>(round(static_cast<double>(
            result_offset * result_mult_int * std::pow(2.0, exponent))));
    scale_by_fixedpoint_and_exponent_stage.result_fixedpoint_multiplier =
        result_fixedpoint_multiplier;
    scale_by_fixedpoint_and_exponent_stage.result_exponent = exponent;
    auto scale_by_fixedpoint_and_exponent_pipeline =
        std::make_tuple(scale_by_fixedpoint_and_exponent_stage);
    Matrix<std::int32_t, ResultOrder>
        result_scaled_by_fixedpoint_and_exponent_int32(rows, cols);
    GemmWithOutputPipeline<std::uint8_t, std::int32_t,
                           DefaultL8R8BitDepthParams>(
        &context, lhs.const_map(), rhs.const_map(),
        &result_scaled_by_fixedpoint_and_exponent_int32, lhs_offset, rhs_offset,
        scale_by_fixedpoint_and_exponent_pipeline);

    for (int r = 0; r < rows; r++) {
      for (int c = 0; c < cols; c++) {
        const std::int32_t actual =
            result_scaled_by_fixedpoint_and_exponent_int32(r, c);
        const std::int32_t raw = result_raw_int32(r, c);
        int left_shift = std::max(0, exponent);
        int right_shift = std::max(0, -exponent);
        const std::int32_t expected =
            scale_by_fixedpoint_and_exponent_stage.result_offset_after_shift +
            RoundingDivideByPOT(
                SaturatingRoundingDoublingHighMul((1 << left_shift) * raw,
                                                  result_fixedpoint_multiplier),
                right_shift);
        Check(actual == expected);
      }
    }
  }

  // Test the variant of the familiar default pipeline consisting of
  // quantize-down and clamp-and-cast-to-uint8, where we use fixed-point
  // multipliers for the downscaling.
  auto quantize_down_by_fixedpoint_and_saturating_cast_pipeline =
      std::make_tuple(quantize_down_by_fixedpoint_stage, saturating_cast_stage);
  Matrix<std::uint8_t, ResultOrder>
      result_quantized_down_by_fixedpoint_saturated_uint8(rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::uint8_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(),
      &result_quantized_down_by_fixedpoint_saturated_uint8, lhs_offset,
      rhs_offset, quantize_down_by_fixedpoint_and_saturating_cast_pipeline);

  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t quantized = result_quantized_down_by_fixedpoint_int32(r, c);
      std::uint8_t expected = std::min(std::max(quantized, 0), 255);
      Check(expected ==
            result_quantized_down_by_fixedpoint_saturated_uint8(r, c));
    }
  }

  printf("TestOutputStages: PASS with ResultOrder=%s\n",
         OrderName(ResultOrder));
}

#ifndef GEMMLOWP_SKIP_EXHAUSTIVE_TESTS
template <typename BitDepthParams>
void TestExhaustively() {
  GemmContext context;

  // Test the internal GEMM interfaces.
  test_gemm<
      SingleThreadGemmWrapper<DefaultKernel<BitDepthParams>,
                              std::uint8_t, BitDepthParams>>(&context);

  test_gemm<
      MultiThreadGemmWrapper<DefaultKernel<BitDepthParams>,
                             std::uint8_t, BitDepthParams>>(&context);

  // Test the public GEMM interfaces.
  test_gemm<PublicGemmWrapper<std::uint8_t, BitDepthParams>>(&context);

  // Test GEMV cases (internal interfaces).
  test_gemv<
      SingleThreadGemmWrapper<DefaultKernel<BitDepthParams>,
                              std::uint8_t, BitDepthParams>>(&context);

  test_gemv<
      MultiThreadGemmWrapper<DefaultKernel<BitDepthParams>,
                             std::uint8_t, BitDepthParams>>(&context);

  // Test GEMV cases (public interfaces).
  test_gemv<PublicGemmWrapper<std::uint8_t, BitDepthParams>>(&context);
}

template <eight_bit_int_gemm::BitDepthSetting BitDepthSetting>
void TestExhaustivelyEightBitIntGemm() {
  GemmContext context;
  test_gemv<EightBitIntGemmWrapper<std::uint8_t, BitDepthSetting>>(&context);
  test_gemv<EightBitIntGemmWrapper<std::uint8_t, BitDepthSetting>>(&context);
  test_gemm<EightBitIntGemmWrapper<std::uint8_t, BitDepthSetting>>(&context);
}

void TestKernels() {
  GemmContext context;

  // Test specific kernels with various formats, to exercise corner cases
  // especially in the packing code.
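  //
  // A note on the formats: in KernelSideFormat<CellFormat<W, D>, C>, W is
  // the cell width, D the cell depth, and C the number of cells, so each
  // side of the kernel covers W * C rows (LHS) or columns (RHS) at a time.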
  test_gemm_kernel<
      ReferenceKernel<KernelFormat<KernelSideFormat<CellFormat<1, 1>, 1>,
                                   KernelSideFormat<CellFormat<1, 1>, 1>>>>(
      &context);

  test_gemm_kernel<
      ReferenceKernel<KernelFormat<KernelSideFormat<CellFormat<4, 2>, 1>,
                                   KernelSideFormat<CellFormat<4, 2>, 2>>>>(
      &context);

  test_gemm_kernel<
      ReferenceKernel<KernelFormat<KernelSideFormat<CellFormat<4, 2>, 4>,
                                   KernelSideFormat<CellFormat<4, 2>, 5>>>>(
      &context);

  test_gemm_kernel<ReferenceKernel<KernelFormat<
      KernelSideFormat<CellFormat<3, 4, CellOrder::DepthMajor>, 2>,
      KernelSideFormat<CellFormat<5, 4, CellOrder::DepthMajor>, 3>>>>(&context);

  test_gemm_kernel<ReferenceKernel<KernelFormat<
      KernelSideFormat<CellFormat<3, 4, CellOrder::WidthMajor>, 2>,
      KernelSideFormat<CellFormat<5, 4, CellOrder::WidthMajor>, 3>>>>(&context);

  test_gemm_kernel<ReferenceKernel<KernelFormat<
      KernelSideFormat<CellFormat<5, 2, CellOrder::WidthMajor>, 3>,
      KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2>>>>(&context);

  test_gemm_kernel<ReferenceKernel<KernelFormat<
      KernelSideFormat<CellFormat<5, 2, CellOrder::DepthMajor>, 3>,
      KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 2>>>>(&context);

  test_gemm_kernel<ReferenceKernel<KernelFormat<
      KernelSideFormat<CellFormat<8, 8, CellOrder::Diagonal>, 2>,
      KernelSideFormat<CellFormat<3, 8, CellOrder::WidthMajor>, 1>>>>(&context);

  test_gemm_kernel<ReferenceKernel<KernelFormat<
      KernelSideFormat<CellFormat<1, 4, CellOrder::DepthMajor>, 1>,
      KernelSideFormat<CellFormat<4, 4, CellOrder::Diagonal>, 1>>>>(&context);
}

#endif  // not GEMMLOWP_SKIP_EXHAUSTIVE_TESTS

template <typename BitDepthParams>
void TestOutputStages() {
  // Test non-default output pipelines with various combinations of
  // output stages.
  TestOutputStages<BitDepthParams, MapOrder::RowMajor>(63, 10, 127, 5, 17, 14);
  TestOutputStages<BitDepthParams, MapOrder::ColMajor>(63, 10, 127, 5, 17, 14);
  TestOutputStages<BitDepthParams, MapOrder::RowMajor>(630, 10, 1270, 5, 17,
                                                       14);
  TestOutputStages<BitDepthParams, MapOrder::ColMajor>(630, 10, 1270, 5, 17,
                                                       14);
}

void test() {
#ifdef GEMMLOWP_TEST_PROFILE
  RegisterCurrentThreadForProfiling();
  StartProfiling();
#endif

  // Run a first quick test against hand-calculated data.
  TestWithSmallData();

#ifndef GEMMLOWP_SKIP_EXHAUSTIVE_TESTS
  TestExhaustively<DefaultL8R8BitDepthParams>();
  TestExhaustively<L8R8WithLhsNonzeroBitDepthParams>();
  TestExhaustively<DefaultL7R5BitDepthParams>();  // legacy, same as L8R8
  TestExhaustivelyEightBitIntGemm<eight_bit_int_gemm::BitDepthSetting::A8B8>();
  TestExhaustivelyEightBitIntGemm<eight_bit_int_gemm::BitDepthSetting::A5B7>();
  TestKernels();
#endif

  // Run against actual data from a network evaluation.
  TestWithRealData(eight_bit_int_gemm::BitDepthSetting::A8B8, 0, 0);
  TestWithRealData(eight_bit_int_gemm::BitDepthSetting::A5B7, 2, 10);

  // Test non-default output pipelines with various combinations of
  // output stages.
  TestOutputStages<DefaultL8R8BitDepthParams>();
  TestOutputStages<L8R8WithLhsNonzeroBitDepthParams>();

  // Test per-channel quantization.
  TestWithSmallDataPerChannelQuantization();
  TestWithLargeDataPerChannelQuantization();
  TestMultithreadedPerChannelQuantization();
#ifdef GEMMLOWP_TEST_PROFILE
  FinishProfiling();
#endif

  std::cerr << "All tests passed." << std::endl;

  // We have been testing the eight_bit_int_gemm interface, so we should free
  // its persistent resources now to avoid having leak-checking tools report
  // leaks.
  eight_bit_int_gemm::FreePersistentResources();
}

}  // end namespace gemmlowp

// For iOS, we need to define our own main(), so skip it here.
#if !(defined(__APPLE__) && (TARGET_OS_IPHONE || TARGET_IPHONE_SIMULATOR))
int main() { gemmlowp::test(); }
#endif