• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "test.h"
16 
17 #include <array>
18 #include <cstdint>
19 #include <cstdlib>
20 #include <ctime>
21 #include <iostream>
22 #include <memory>
23 #include <string>
24 #include <vector>
25 #ifdef __APPLE__
26 #include <TargetConditionals.h>
27 #endif
28 
29 #include "../eight_bit_int_gemm/eight_bit_int_gemm.h"
30 #include "../internal/kernel_reference.h"
31 #include "test_data.h"
32 
33 namespace gemmlowp {
34 
ReferenceEightBitIntGemm(bool transpose_a,bool transpose_b,bool transpose_c,int m,int n,int k,const std::uint8_t * a,std::int32_t a_offset,int lda,const std::uint8_t * b,std::int32_t b_offset,int ldb,std::uint8_t * c,std::int32_t c_offset,std::int32_t c_mult_int,std::int32_t c_shift,int ldc)35 void ReferenceEightBitIntGemm(bool transpose_a, bool transpose_b,
36                               bool transpose_c, int m, int n, int k,
37                               const std::uint8_t* a, std::int32_t a_offset,
38                               int lda, const std::uint8_t* b,
39                               std::int32_t b_offset, int ldb, std::uint8_t* c,
40                               std::int32_t c_offset, std::int32_t c_mult_int,
41                               std::int32_t c_shift, int ldc) {
42   ScopedProfilingLabel("ReferenceEightBitIntGemm");
43   assert((c_shift >= 0) && (c_shift <= 32));
44 
45   assert(a != nullptr);
46   assert(b != nullptr);
47   assert(c != nullptr);
48 
49   int a_i_stride;
50   int a_l_stride;
51   if (transpose_a) {
52     a_i_stride = lda;
53     a_l_stride = 1;
54   } else {
55     a_i_stride = 1;
56     a_l_stride = lda;
57   }
58   int b_j_stride;
59   int b_l_stride;
60   if (transpose_b) {
61     b_j_stride = 1;
62     b_l_stride = ldb;
63   } else {
64     b_j_stride = ldb;
65     b_l_stride = 1;
66   }
67   int c_i_stride;
68   int c_j_stride;
69   if (transpose_c) {
70     c_i_stride = ldc;
71     c_j_stride = 1;
72   } else {
73     c_i_stride = 1;
74     c_j_stride = ldc;
75   }
76   int i, j, l;
77 
78   const std::int32_t kRoundingTerm = (c_shift < 1) ? 0 : (1 << (c_shift - 1));
79 
80   for (j = 0; j < n; j++) {
81     for (i = 0; i < m; i++) {
82       std::int32_t total = 0;
83       for (l = 0; l < k; l++) {
84         const int a_index = i * a_i_stride + l * a_l_stride;
85         const std::uint8_t a_as_byte = a[a_index];
86         const std::int32_t a_as_int =
87             static_cast<std::int32_t>(a_as_byte) + a_offset;
88         const int b_index = j * b_j_stride + l * b_l_stride;
89         const std::uint8_t b_as_byte = b[b_index];
90         const std::int32_t b_as_int =
91             static_cast<std::int32_t>(b_as_byte) + b_offset;
92         const std::int32_t mult_as_int = a_as_int * b_as_int;
93         total += mult_as_int;
94       }
95       std::int32_t output =
96           (((total + c_offset) * c_mult_int) + kRoundingTerm) >> c_shift;
97       if (output > 255) {
98         output = 255;
99       }
100       if (output < 0) {
101         output = 0;
102       }
103       const int c_index = i * c_i_stride + j * c_j_stride;
104       c[c_index] = static_cast<std::uint8_t>(output);
105     }
106   }
107 }
108 
109 typedef VectorMap<const std::int32_t, VectorShape::Col> OffsetColMap;
110 typedef VectorMap<const std::int32_t, VectorShape::Row> OffsetRowMap;
111 typedef VectorDup<const std::int32_t, VectorShape::Col> OffsetColDup;
112 typedef VectorDup<const std::int32_t, VectorShape::Row> OffsetRowDup;
113 
// The *GemmWrapper structs wrap the various Gemm functions in a uniform
// interface, so that the same testing code can test all of them.
116 
117 template <typename Kernel, typename Scalar, typename tBitDepthParams>
118 struct SingleThreadGemmWrapper {
119   typedef tBitDepthParams BitDepthParams;
120 
Namegemmlowp::SingleThreadGemmWrapper121   static const char* Name() {
122     static char buf[256];
123     snprintf(buf, sizeof(buf), "SingleThreadGemm, Kernel: %s", Kernel().Name());
124     return buf;
125   }
126 
127   typedef SingleThreadGemmContext Context;
128 
129   template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
Gemmgemmlowp::SingleThreadGemmWrapper130   static bool Gemm(Context* context,
131                    const MatrixMap<const Scalar, LhsOrder>& lhs,
132                    const MatrixMap<const Scalar, RhsOrder>& rhs,
133                    MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
134                    int rhs_offset, int result_offset, int result_mult_int,
135                    int result_shift) {
136     ScopedProfilingLabel("SingleThreadGemmWrapper::Gemm");
137     const int rows = lhs.rows();
138     const int cols = rhs.cols();
139     if (rows < cols) {
140       // SingleThreadGemm is never called with rows < cols.
141       // That case is handled earlier.
142       return false;
143     }
144     const OffsetColDup lhs_offset_vector(lhs_offset, rows);
145     const OffsetRowDup rhs_offset_vector(rhs_offset, cols);
146     SingleThreadGemm<typename Kernel::Format, Scalar, Scalar, BitDepthParams,
147                      LhsOrder, RhsOrder, ResultOrder, OffsetColDup,
148                      OffsetRowDup>(
149         context, Kernel(), lhs, rhs, result, lhs_offset_vector,
150         rhs_offset_vector,
151         MakeStandardOutputPipeline(result_offset, result_mult_int,
152                                    result_shift));
153     return true;
154   }
155 };
156 
157 template <typename Kernel, typename Scalar, typename tBitDepthParams>
158 struct MultiThreadGemmWrapper {
159   typedef tBitDepthParams BitDepthParams;
160 
Namegemmlowp::MultiThreadGemmWrapper161   static const char* Name() {
162     static char buf[256];
163     snprintf(buf, sizeof(buf), "MultiThreadGemm, Kernel: %s", Kernel().Name());
164     return buf;
165   }
166 
167   typedef MultiThreadGemmContext Context;
168 
169   template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
Gemmgemmlowp::MultiThreadGemmWrapper170   static bool Gemm(Context* context,
171                    const MatrixMap<const Scalar, LhsOrder>& lhs,
172                    const MatrixMap<const Scalar, RhsOrder>& rhs,
173                    MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
174                    int rhs_offset, int result_offset, int result_mult_int,
175                    int result_shift) {
176     ScopedProfilingLabel("MultiThreadGemmWrapper::Gemm");
177     context->set_max_num_threads(0);
178     const int rows = lhs.rows();
179     const int cols = rhs.cols();
180     if (rows < cols) {
181       // SingleThreadGemm is never called with rows < cols.
182       // That case is handled earlier.
183       return false;
184     }
185     const OffsetColDup lhs_offset_vector(lhs_offset, rows);
186     const OffsetRowDup rhs_offset_vector(rhs_offset, cols);
187     MultiThreadGemm<typename Kernel::Format, Scalar, Scalar, BitDepthParams,
188                     LhsOrder, RhsOrder, ResultOrder, OffsetColDup,
189                     OffsetRowDup>(
190         context, Kernel(), lhs, rhs, result, lhs_offset_vector,
191         rhs_offset_vector,
192         MakeStandardOutputPipeline(result_offset, result_mult_int,
193                                    result_shift));
194     return true;
195   }
196 };
197 
198 template <typename Scalar, typename tBitDepthParams>
199 struct PublicGemmWrapper {
200   typedef tBitDepthParams BitDepthParams;
201 
Namegemmlowp::PublicGemmWrapper202   static const char* Name() { return "public Gemm"; }
203 
204   typedef GemmContext Context;
205 
206   template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
Gemmgemmlowp::PublicGemmWrapper207   static bool Gemm(Context* context,
208                    const MatrixMap<const Scalar, LhsOrder>& lhs,
209                    const MatrixMap<const Scalar, RhsOrder>& rhs,
210                    MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
211                    int rhs_offset, int result_offset, int result_mult_int,
212                    int result_shift) {
213     ScopedProfilingLabel("PublicGemmWrapper::Gemm");
214     gemmlowp::Gemm<std::uint8_t, BitDepthParams, LhsOrder, RhsOrder,
215                    ResultOrder>(context, lhs, rhs, result, lhs_offset,
216                                 rhs_offset, result_offset, result_mult_int,
217                                 result_shift);
218     return true;
219   }
220 };
221 
222 template <eight_bit_int_gemm::BitDepthSetting BitDepth>
223 struct BitDepthParamsForSettings {};
224 
225 template <>
226 struct BitDepthParamsForSettings<eight_bit_int_gemm::BitDepthSetting::A8B8>
227     : DefaultL8R8BitDepthParams {};
228 
229 template <>
230 struct BitDepthParamsForSettings<eight_bit_int_gemm::BitDepthSetting::A5B7>
231     : DefaultL7R5BitDepthParams {};
232 
233 template <typename Scalar, eight_bit_int_gemm::BitDepthSetting BitDepth>
234 struct EightBitIntGemmWrapper {
235   typedef BitDepthParamsForSettings<BitDepth> BitDepthParams;
236 
Namegemmlowp::EightBitIntGemmWrapper237   static const char* Name() { return "EightBitIntGemm"; }
238 
239   typedef void Context;
240 
241   template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
Gemmgemmlowp::EightBitIntGemmWrapper242   static bool Gemm(Context*, const MatrixMap<const Scalar, LhsOrder>& lhs,
243                    const MatrixMap<const Scalar, RhsOrder>& rhs,
244                    MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
245                    int rhs_offset, int result_offset, int result_mult_int,
246                    int result_shift) {
247     ScopedProfilingLabel("EightBitIntGemmWrapper::Gemm");
248     const bool transpose_c = ResultOrder == MapOrder::RowMajor;
249     const bool transpose_a = LhsOrder == MapOrder::RowMajor;
250     const bool transpose_b = RhsOrder == MapOrder::RowMajor;
251     eight_bit_int_gemm::EightBitIntGemm(
252         transpose_a, transpose_b, transpose_c, lhs.rows(), rhs.cols(),
253         lhs.cols(), lhs.data(), lhs_offset, lhs.stride(), rhs.data(),
254         rhs_offset, rhs.stride(), result->data(), result_offset,
255         result_mult_int, result_shift, result->stride(), BitDepth);
256     return true;
257   }
258 };
259 
260 template <typename Scalar>
261 struct ReferenceEightBitIntGemmWrapper {
262   typedef DefaultL8R8BitDepthParams BitDepthParams;
263 
Namegemmlowp::ReferenceEightBitIntGemmWrapper264   static const char* Name() { return "ReferenceEightBitIntGemm"; }
265 
266   template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
Gemmgemmlowp::ReferenceEightBitIntGemmWrapper267   static bool Gemm(bool transpose_a, bool transpose_b, bool transpose_c,
268                    const MatrixMap<const Scalar, LhsOrder>& lhs,
269                    const MatrixMap<const Scalar, RhsOrder>& rhs,
270                    MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
271                    int rhs_offset, int result_offset, int result_mult_int,
272                    int result_shift) {
273     ScopedProfilingLabel("ReferenceEightBitIntGemmWrapper::Gemm");
274     ReferenceEightBitIntGemm(transpose_a, transpose_b, transpose_c, lhs.rows(),
275                              rhs.cols(), lhs.cols(), lhs.data(), lhs_offset,
276                              lhs.stride(), rhs.data(), rhs_offset, rhs.stride(),
277                              result->data(), result_offset, result_mult_int,
278                              result_shift, result->stride());
279     return true;
280   }
281 };
282 
OrderName(MapOrder order)283 const char* OrderName(MapOrder order) {
284   return order == MapOrder::ColMajor ? "ColMajor" : "RowMajor";
285 }
286 
// Summary statistics comparing an actual result against the reference result.
// Every scalar field starts at zero and the histogram starts empty.
struct ResultStats {
  ResultStats() = default;

  int count = 0;                // number of entries compared
  int med_val = 0;              // median of the actual values
  float mean_signed_diff = 0;   // mean of (actual - expected)
  int med_signed_diff = 0;      // median of (actual - expected)
  int med_unsigned_diff = 0;    // median of |actual - expected|
  int max_unsigned_diff = 0;    // maximum of |actual - expected|

  // Histogram of |actual - expected| by power-of-two slice: slot 0 counts
  // exact matches, and slot e (e >= 1) counts diffs in [2^(e-1), 2^e).
  std::vector<int> count_diff_by_pot_slice;
};
305 
GetResultStats(const std::uint8_t * actual,const std::uint8_t * expected,size_t count,ResultStats * stats)306 void GetResultStats(const std::uint8_t* actual, const std::uint8_t* expected,
307                     size_t count, ResultStats* stats) {
308   ScopedProfilingLabel("GetResultStats");
309   std::vector<std::uint8_t> results;
310   std::vector<std::int16_t> signed_diffs;
311   std::vector<std::uint8_t> unsigned_diffs;
312   std::int64_t signed_diffs_sum = 0;
313   for (size_t i = 0; i < count; i++) {
314     results.push_back(actual[i]);
315     std::int16_t signed_diff = actual[i] - expected[i];
316     signed_diffs.push_back(signed_diff);
317     unsigned_diffs.push_back(std::abs(signed_diff));
318     signed_diffs_sum += signed_diff;
319   }
320 
321   std::sort(results.begin(), results.end());
322   std::sort(signed_diffs.begin(), signed_diffs.end());
323   std::sort(unsigned_diffs.begin(), unsigned_diffs.end());
324 
325   const size_t middle = count / 2;
326 
327   stats->count = count;
328   stats->med_val = results[middle];
329   stats->mean_signed_diff = float(signed_diffs_sum) / count;
330   stats->med_signed_diff = signed_diffs[middle];
331   stats->med_unsigned_diff = unsigned_diffs[middle];
332   stats->max_unsigned_diff = unsigned_diffs.back();
333 
334   // Size 9 for 9 different POT values: 2^0, ..., 2^8
335   stats->count_diff_by_pot_slice.resize(9);
336   auto cur = unsigned_diffs.begin();
337   size_t checksum = 0;
338   for (int exponent = 0; exponent < 9; exponent++) {
339     int pot = 1 << exponent;
340     auto next = std::lower_bound(cur, unsigned_diffs.end(), pot);
341     checksum += stats->count_diff_by_pot_slice[exponent] = next - cur;
342     cur = next;
343   }
344   assert(checksum == count);
345 }
346 
// Tolerances that CheckResultStatsBounds applies to a ResultStats. Every bound
// defaults to zero, which demands an exact match with the reference.
struct ResultStatsBounds {
  ResultStatsBounds() = default;

  float mean_signed_diff = 0;
  int med_signed_diff = 0;
  int med_unsigned_diff = 0;
  int max_unsigned_diff = 0;
};
359 
CheckResultStatsBounds(const ResultStats & stats,const ResultStatsBounds & bounds)360 bool CheckResultStatsBounds(const ResultStats& stats,
361                             const ResultStatsBounds& bounds) {
362   return stats.max_unsigned_diff <= bounds.max_unsigned_diff &&
363          stats.med_unsigned_diff <= bounds.med_unsigned_diff &&
364          std::abs(stats.med_signed_diff) <= bounds.med_signed_diff &&
365          std::abs(stats.mean_signed_diff) <= bounds.mean_signed_diff;
366 }
367 
ReportResultStats(const ResultStats & stats,const ResultStatsBounds & bounds)368 void ReportResultStats(const ResultStats& stats,
369                        const ResultStatsBounds& bounds) {
370   printf("    number of matrix entries: %d\n", stats.count);
371   printf("    median value: %d\n", stats.med_val);
372   printf("    median unsigned diff: %d (tolerating %d)\n",
373          stats.med_unsigned_diff, bounds.med_unsigned_diff);
374   printf("    max unsigned diff: %d (tolerating %d)\n", stats.max_unsigned_diff,
375          bounds.max_unsigned_diff);
376   printf("    median signed diff: %d (tolerating %d)\n", stats.med_signed_diff,
377          bounds.med_signed_diff);
378   printf("    mean signed diff: %.3g (tolerating %.3g)\n",
379          stats.mean_signed_diff, bounds.mean_signed_diff);
380 
381   printf("No error: %.2f %% of entries\n",
382          100.f * stats.count_diff_by_pot_slice[0] / stats.count);
383   for (int exponent = 1; exponent < 9; exponent++) {
384     printf("Error in %d..%d range: %.2f %% of entries\n", 1 << (exponent - 1),
385            (1 << exponent) - 1,
386            100.f * stats.count_diff_by_pot_slice[exponent] / stats.count);
387   }
388 }
389 
// Our approach to choosing result_shift values for testing is bisection.
391 // This function takes an interval, [result_shift_min .. result_shift_max].
392 // If too much saturation occurred in either direction, it bisects accordingly,
393 // recursing until the interval contains only one value.
394 // The primary reason why we prefer this over computing optimal shift values,
395 // is that we actually want to exercise some saturation, as there is nontrivial
396 // code handling that in gemmlowp.
397 // Secondarily, this is faster than computing optimal shifts, since in 90% of
398 // cases the first-tried shift value 16 turns out to be good enough.
399 template <typename GemmWrapper, typename LhsType, typename RhsType,
400           typename ResultType>
test_gemm_impl(typename GemmWrapper::Context * context,const LhsType & lhs,const RhsType & rhs,ResultType * result,int lhs_offset,int rhs_offset,int result_offset,int result_mult_int,int result_shift_min,int result_shift_max)401 void test_gemm_impl(typename GemmWrapper::Context* context, const LhsType& lhs,
402                     const RhsType& rhs, ResultType* result, int lhs_offset,
403                     int rhs_offset, int result_offset, int result_mult_int,
404                     int result_shift_min, int result_shift_max) {
405   const int rows = lhs.rows();
406   const int cols = rhs.cols();
407   Check(lhs.cols() == rhs.rows());
408   const int depth = lhs.cols();
409 
410   const int result_shift = (result_shift_min + result_shift_max) / 2;
411 
412   if (!GemmWrapper::Gemm(context, lhs.const_map(), rhs.const_map(),
413                          &result->map(), lhs_offset, rhs_offset, result_offset,
414                          result_mult_int, result_shift)) {
415     // Internal GEMM functions are not required to handle all cases
416     // (e.g. rows < cols) as these are supposed to have been handled
417     // ahead of them. Their test wrappers return false in that case.
418     return;
419   }
420 
421   typedef typename ResultType::Scalar Scalar;
422   static const MapOrder kLhsOrder = LhsType::kOrder;
423   static const MapOrder kRhsOrder = RhsType::kOrder;
424   static const MapOrder kResultOrder = ResultType::kOrder;
425   ResultType ref_result(rows, cols);
426   const bool transpose_c = kResultOrder == MapOrder::RowMajor;
427   const bool transpose_a = kLhsOrder == MapOrder::RowMajor;
428   const bool transpose_b = kRhsOrder == MapOrder::RowMajor;
429   ReferenceEightBitIntGemmWrapper<Scalar>::Gemm(
430       transpose_a, transpose_b, transpose_c, lhs.const_map(), rhs.const_map(),
431       &ref_result.map(), lhs_offset, rhs_offset, result_offset, result_mult_int,
432       result_shift);
433 
434   typedef typename GemmWrapper::BitDepthParams BitDepthParams;
435 
436   ResultStats stats;
437   GetResultStats(result->data(), ref_result.data(), rows * cols, &stats);
438 
439   // Adjust shifts until we get meaningful results
440   int new_result_shift_min = result_shift_min;
441   int new_result_shift_max = result_shift_max;
442   bool retry = false;
443 
444   if (stats.med_val < 32) {
445     new_result_shift_max = (result_shift_min + result_shift_max) / 2;
446     retry = true;
447   }
448 
449   if (stats.med_val > 224) {
450     new_result_shift_min = (result_shift_min + result_shift_max) / 2;
451     retry = true;
452   }
453 
454   if (retry) {
455     if (result_shift_min != result_shift_max) {
456       test_gemm_impl<GemmWrapper>(context, lhs, rhs, result, lhs_offset,
457                                   rhs_offset, result_offset, result_mult_int,
458                                   new_result_shift_min, new_result_shift_max);
459     }
460     return;
461   }
462 
463   ResultStatsBounds bounds;
464 
465   // Check results
466   const bool good = CheckResultStatsBounds(stats, bounds);
467 
468   printf(
469       "%s: %dx%dx%d %s x %s -> %s, %s, offsets %d/%d/%d, mult %d, shift %d\n",
470       good ? "PASS" : "FAIL", rows, depth, cols, OrderName(kLhsOrder),
471       OrderName(kRhsOrder), OrderName(kResultOrder), GemmWrapper::Name(),
472       lhs_offset, rhs_offset, result_offset, result_mult_int, result_shift);
473 
474   if (!good) {
475     ReportResultStats(stats, bounds);
476 
477     int bad_coeffs_printed = 0;
478     for (int c = 0; c < result->cols() && bad_coeffs_printed < 200; c++) {
479       for (int r = 0; r < result->rows() && bad_coeffs_printed < 200; r++) {
480         if (ref_result(r, c) != (*result)(r, c)) {
481           printf("bad coeff: at (%d, %d), expected %d, got %d\n", r, c,
482                  ref_result(r, c), (*result)(r, c));
483           bad_coeffs_printed++;
484         }
485       }
486     }
487   }
488 
489   Check(good);
490 }
491 
492 template <typename GemmWrapper, typename LhsType, typename RhsType,
493           typename ResultType>
test_gemm(typename GemmWrapper::Context * context,const LhsType & lhs,const RhsType & rhs,ResultType * result,int lhs_offset,int rhs_offset,int result_offset,int result_mult_int)494 void test_gemm(typename GemmWrapper::Context* context, const LhsType& lhs,
495                const RhsType& rhs, ResultType* result, int lhs_offset,
496                int rhs_offset, int result_offset, int result_mult_int) {
497   test_gemm_impl<GemmWrapper>(context, lhs, rhs, result, lhs_offset, rhs_offset,
498                               result_offset, result_mult_int, 0, 32);
499 }
500 
// Chooses whether test_gemm runs the extra offset/multiplier combinations
// (All) or only the single generic combination (OnlyGenericCase).
enum class WhatParamsToTest { All, OnlyGenericCase };
505 
506 template <typename GemmWrapper, MapOrder LhsOrder, MapOrder RhsOrder,
507           MapOrder ResultOrder>
test_gemm(typename GemmWrapper::Context * context,int rows,int depth,int cols,WhatParamsToTest params_to_test)508 void test_gemm(typename GemmWrapper::Context* context, int rows, int depth,
509                int cols, WhatParamsToTest params_to_test) {
510   typedef std::uint8_t Scalar;
511   typedef Matrix<Scalar, LhsOrder> LhsType;
512   using BitDepthParams = typename GemmWrapper::BitDepthParams;
513   LhsType lhs(rows, depth);
514   MakeRandom<typename BitDepthParams::LhsRange>(&lhs);
515   typedef Matrix<Scalar, RhsOrder> RhsType;
516   RhsType rhs(depth, cols);
517   MakeRandom<typename BitDepthParams::RhsRange>(&rhs);
518   typedef Matrix<Scalar, ResultOrder> ResultType;
519   ResultType result(rows, cols);
520   MakeZero(&result);
521 
522   if (params_to_test == WhatParamsToTest::All) {
523     test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 0, 0, 1);
524     test_gemm<GemmWrapper>(context, lhs, rhs, &result, 10, 0, 0, 1);
525     test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 10, 0, 1);
526     test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 0, 10, 1);
527     test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 0, 0, 10);
528     test_gemm<GemmWrapper>(context, lhs, rhs, &result, 10, 10, 10, 10);
529     test_gemm<GemmWrapper>(context, lhs, rhs, &result, 256, 1, 17, 4);
530   }
531   test_gemm<GemmWrapper>(context, lhs, rhs, &result, -75, -91, 74980, 123);
532 }
533 
// Chooses whether test_gemm runs every lhs/rhs/result storage-order
// combination (All), or only RowMajor lhs x ColMajor rhs -> ColMajor result
// (OnlyRCC).
enum class WhatOrdersToTest {
  All,
  OnlyRCC,
};
535 
536 template <typename GemmWrapper>
test_gemm(typename GemmWrapper::Context * context,int rows,int depth,int cols,WhatParamsToTest params_to_test,WhatOrdersToTest orders_to_test)537 void test_gemm(typename GemmWrapper::Context* context, int rows, int depth,
538                int cols, WhatParamsToTest params_to_test,
539                WhatOrdersToTest orders_to_test) {
540 #define GEMMLOWP_ONE_TEST(LhsOrder, RhsOrder, ResultOrder)         \
541   do {                                                             \
542     test_gemm<GemmWrapper, MapOrder::LhsOrder, MapOrder::RhsOrder, \
543               MapOrder::ResultOrder>(context, rows, depth, cols,   \
544                                      params_to_test);              \
545   } while (false)
546 
547   if (orders_to_test == WhatOrdersToTest::All) {
548     GEMMLOWP_ONE_TEST(ColMajor, ColMajor, ColMajor);
549     GEMMLOWP_ONE_TEST(RowMajor, ColMajor, ColMajor);
550     GEMMLOWP_ONE_TEST(ColMajor, RowMajor, ColMajor);
551     GEMMLOWP_ONE_TEST(RowMajor, RowMajor, ColMajor);
552 
553     GEMMLOWP_ONE_TEST(ColMajor, ColMajor, RowMajor);
554     GEMMLOWP_ONE_TEST(RowMajor, ColMajor, RowMajor);
555     GEMMLOWP_ONE_TEST(ColMajor, RowMajor, RowMajor);
556     GEMMLOWP_ONE_TEST(RowMajor, RowMajor, RowMajor);
557   } else {
558     GEMMLOWP_ONE_TEST(RowMajor, ColMajor, ColMajor);
559   }
560 
561 #undef GEMMLOWP_ONE_TEST
562 }
563 
564 template <typename Kernel>
test_gemm_kernel(MultiThreadGemmContext * context)565 void test_gemm_kernel(MultiThreadGemmContext* context) {
566   typedef MultiThreadGemmWrapper<Kernel, std::uint8_t,
567                                  DefaultL8R8BitDepthParams>
568       GemmWrapper;
569   test_gemm<GemmWrapper>(context, 1, 1, 1, WhatParamsToTest::OnlyGenericCase,
570                          WhatOrdersToTest::OnlyRCC);
571   test_gemm<GemmWrapper>(context, 2, 2, 2, WhatParamsToTest::OnlyGenericCase,
572                          WhatOrdersToTest::OnlyRCC);
573   test_gemm<GemmWrapper>(context, 3, 3, 3, WhatParamsToTest::OnlyGenericCase,
574                          WhatOrdersToTest::OnlyRCC);
575   test_gemm<GemmWrapper>(context, 4, 4, 4, WhatParamsToTest::OnlyGenericCase,
576                          WhatOrdersToTest::OnlyRCC);
577   test_gemm<GemmWrapper>(context, 5, 5, 5, WhatParamsToTest::OnlyGenericCase,
578                          WhatOrdersToTest::OnlyRCC);
579   test_gemm<GemmWrapper>(context, 9, 11, 13, WhatParamsToTest::OnlyGenericCase,
580                          WhatOrdersToTest::OnlyRCC);
581   test_gemm<GemmWrapper>(context, 50, 50, 50, WhatParamsToTest::All,
582                          WhatOrdersToTest::OnlyRCC);
583   test_gemm<GemmWrapper>(context, 200, 200, 200,
584                          WhatParamsToTest::OnlyGenericCase,
585                          WhatOrdersToTest::All);
586   test_gemm<GemmWrapper>(context, 50, 5000, 50,
587                          WhatParamsToTest::OnlyGenericCase,
588                          WhatOrdersToTest::OnlyRCC);
589 }
590 
591 template <typename GemmWrapper>
test_gemm(typename GemmWrapper::Context * context)592 void test_gemm(typename GemmWrapper::Context* context) {
593   test_gemm<GemmWrapper>(context, 1, 1, 1, WhatParamsToTest::All,
594                          WhatOrdersToTest::OnlyRCC);
595   test_gemm<GemmWrapper>(context, 2, 1, 1, WhatParamsToTest::All,
596                          WhatOrdersToTest::OnlyRCC);
597   test_gemm<GemmWrapper>(context, 1, 2, 1, WhatParamsToTest::All,
598                          WhatOrdersToTest::OnlyRCC);
599   test_gemm<GemmWrapper>(context, 1, 1, 2, WhatParamsToTest::All,
600                          WhatOrdersToTest::OnlyRCC);
601   test_gemm<GemmWrapper>(context, 2, 2, 2, WhatParamsToTest::All,
602                          WhatOrdersToTest::OnlyRCC);
603   test_gemm<GemmWrapper>(context, 3, 3, 3, WhatParamsToTest::All,
604                          WhatOrdersToTest::OnlyRCC);
605   test_gemm<GemmWrapper>(context, 4, 4, 4, WhatParamsToTest::All,
606                          WhatOrdersToTest::OnlyRCC);
607   test_gemm<GemmWrapper>(context, 5, 5, 5, WhatParamsToTest::All,
608                          WhatOrdersToTest::OnlyRCC);
609   test_gemm<GemmWrapper>(context, 6, 6, 6, WhatParamsToTest::All,
610                          WhatOrdersToTest::OnlyRCC);
611   test_gemm<GemmWrapper>(context, 3, 5, 7, WhatParamsToTest::All,
612                          WhatOrdersToTest::OnlyRCC);
613   test_gemm<GemmWrapper>(context, 7, 3, 5, WhatParamsToTest::All,
614                          WhatOrdersToTest::OnlyRCC);
615   test_gemm<GemmWrapper>(context, 5, 7, 3, WhatParamsToTest::All,
616                          WhatOrdersToTest::OnlyRCC);
617   test_gemm<GemmWrapper>(context, 8, 8, 8, WhatParamsToTest::All,
618                          WhatOrdersToTest::All);
619   test_gemm<GemmWrapper>(context, 16, 16, 16, WhatParamsToTest::All,
620                          WhatOrdersToTest::OnlyRCC);
621   test_gemm<GemmWrapper>(context, 32, 32, 32, WhatParamsToTest::All,
622                          WhatOrdersToTest::OnlyRCC);
623   test_gemm<GemmWrapper>(context, 64, 64, 64, WhatParamsToTest::All,
624                          WhatOrdersToTest::All);
625   test_gemm<GemmWrapper>(context, 128, 128, 128, WhatParamsToTest::All,
626                          WhatOrdersToTest::OnlyRCC);
627 
628   test_gemm<GemmWrapper>(context, 16, 17, 16, WhatParamsToTest::All,
629                          WhatOrdersToTest::OnlyRCC);
630   test_gemm<GemmWrapper>(context, 37, 55, 73, WhatParamsToTest::All,
631                          WhatOrdersToTest::OnlyRCC);
632   test_gemm<GemmWrapper>(context, 57, 87, 117, WhatParamsToTest::All,
633                          WhatOrdersToTest::OnlyRCC);
634   test_gemm<GemmWrapper>(context, 93, 83, 73, WhatParamsToTest::All,
635                          WhatOrdersToTest::OnlyRCC);
636   test_gemm<GemmWrapper>(context, 109, 89, 99, WhatParamsToTest::All,
637                          WhatOrdersToTest::OnlyRCC);
638   test_gemm<GemmWrapper>(context, 78, 101, 82, WhatParamsToTest::All,
639                          WhatOrdersToTest::OnlyRCC);
640 
641   test_gemm<GemmWrapper>(context, 512, 512, 512,
642                          WhatParamsToTest::OnlyGenericCase,
643                          WhatOrdersToTest::OnlyRCC);
644   test_gemm<GemmWrapper>(context, 1024, 1024, 1024,
645                          WhatParamsToTest::OnlyGenericCase,
646                          WhatOrdersToTest::OnlyRCC);
647   test_gemm<GemmWrapper>(context, 567, 2345, 123,
648                          WhatParamsToTest::OnlyGenericCase,
649                          WhatOrdersToTest::OnlyRCC);
650   test_gemm<GemmWrapper>(context, 100, 5000, 100,
651                          WhatParamsToTest::OnlyGenericCase,
652                          WhatOrdersToTest::OnlyRCC);
653   test_gemm<GemmWrapper>(context, 1, 1, 1000, WhatParamsToTest::OnlyGenericCase,
654                          WhatOrdersToTest::OnlyRCC);
655   test_gemm<GemmWrapper>(context, 1000, 1, 1, WhatParamsToTest::OnlyGenericCase,
656                          WhatOrdersToTest::OnlyRCC);
657   test_gemm<GemmWrapper>(context, 1, 1000, 1, WhatParamsToTest::OnlyGenericCase,
658                          WhatOrdersToTest::OnlyRCC);
659   test_gemm<GemmWrapper>(context, 1, 1000, 1000,
660                          WhatParamsToTest::OnlyGenericCase,
661                          WhatOrdersToTest::OnlyRCC);
662   test_gemm<GemmWrapper>(context, 1000, 1, 1000,
663                          WhatParamsToTest::OnlyGenericCase,
664                          WhatOrdersToTest::OnlyRCC);
665   test_gemm<GemmWrapper>(context, 1000, 1000, 1,
666                          WhatParamsToTest::OnlyGenericCase,
667                          WhatOrdersToTest::OnlyRCC);
668   test_gemm<GemmWrapper>(context, 777, 3456, 1,
669                          WhatParamsToTest::OnlyGenericCase,
670                          WhatOrdersToTest::OnlyRCC);
671   test_gemm<GemmWrapper>(context, 4567, 555, 1,
672                          WhatParamsToTest::OnlyGenericCase,
673                          WhatOrdersToTest::OnlyRCC);
674 
675   // Test all storage orders
676   test_gemm<GemmWrapper>(context, 70, 90, 110, WhatParamsToTest::All,
677                          WhatOrdersToTest::All);
678   test_gemm<GemmWrapper>(context, 300, 400, 500,
679                          WhatParamsToTest::OnlyGenericCase,
680                          WhatOrdersToTest::All);
681 }
682 
683 template <typename GemmWrapper>
test_gemv(typename GemmWrapper::Context * context)684 void test_gemv(typename GemmWrapper::Context* context) {
685   test_gemm<GemmWrapper>(context, 2, 2, 1, WhatParamsToTest::All,
686                          WhatOrdersToTest::OnlyRCC);
687   test_gemm<GemmWrapper>(context, 3, 3, 1, WhatParamsToTest::All,
688                          WhatOrdersToTest::OnlyRCC);
689   test_gemm<GemmWrapper>(context, 4, 4, 1, WhatParamsToTest::All,
690                          WhatOrdersToTest::OnlyRCC);
691   test_gemm<GemmWrapper>(context, 5, 5, 1, WhatParamsToTest::All,
692                          WhatOrdersToTest::OnlyRCC);
693   test_gemm<GemmWrapper>(context, 6, 6, 1, WhatParamsToTest::All,
694                          WhatOrdersToTest::OnlyRCC);
695   test_gemm<GemmWrapper>(context, 3, 5, 1, WhatParamsToTest::All,
696                          WhatOrdersToTest::OnlyRCC);
697   test_gemm<GemmWrapper>(context, 7, 3, 1, WhatParamsToTest::All,
698                          WhatOrdersToTest::OnlyRCC);
699   test_gemm<GemmWrapper>(context, 5, 7, 1, WhatParamsToTest::All,
700                          WhatOrdersToTest::OnlyRCC);
701   test_gemm<GemmWrapper>(context, 8, 8, 1, WhatParamsToTest::All,
702                          WhatOrdersToTest::OnlyRCC);
703   test_gemm<GemmWrapper>(context, 32, 32, 1, WhatParamsToTest::All,
704                          WhatOrdersToTest::OnlyRCC);
705   test_gemm<GemmWrapper>(context, 128, 128, 1, WhatParamsToTest::All,
706                          WhatOrdersToTest::OnlyRCC);
707   test_gemm<GemmWrapper>(context, 321, 123, 1, WhatParamsToTest::All,
708                          WhatOrdersToTest::OnlyRCC);
709 
710   // Test all storage orders
711   test_gemm<GemmWrapper>(context, 70, 90, 1, WhatParamsToTest::All,
712                          WhatOrdersToTest::All);
713   test_gemm<GemmWrapper>(context, 300, 400, 1,
714                          WhatParamsToTest::OnlyGenericCase,
715                          WhatOrdersToTest::All);
716 }
717 
GetBitDepthName(eight_bit_int_gemm::BitDepthSetting b)718 const char* GetBitDepthName(eight_bit_int_gemm::BitDepthSetting b) {
719   switch (b) {
720     case eight_bit_int_gemm::BitDepthSetting::A8B8:
721       return "Lhs: 8 bit, Rhs: 8 bit";
722     case eight_bit_int_gemm::BitDepthSetting::A5B7:
723       return "(legacy, no longer requantizing) Lhs: 7 bit, Rhs: 5 bit";
724     default:
725       abort();
726       return nullptr;
727   }
728 }
729 
730 // Runs a small set of hand-picked data for per-channel quantized data.
731 // This test case comes from a set of 2 2x2 convolution filters run over a 3x3
732 // image.
TestWithSmallDataPerChannelQuantization()733 void TestWithSmallDataPerChannelQuantization() {
734   const int m = 2;
735   const int n = 9;
736   const int k = 12;
737 
738   // 12 x 2, columnwise.
739   const std::uint8_t a_data[] = {0,  0,   0,   0,   0,  0,   0,   0,
740                                  0,  255, 255, 255, 64, 64,  64,  64,
741                                  64, 64,  0,   0,   0,  255, 255, 255};
742   const int lda = k;
743   int a_offset[] = {0, -64};
744   MatrixMap<const std::uint8_t, MapOrder::RowMajor> lhs(a_data, m, k, lda);
745   const OffsetColMap lhs_offset(a_offset, m);
746 
747   // 12 x 9, columnwise.
748   const std::uint8_t b_data[] = {
749       0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,
750       0,   0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,   127,
751       127, 127, 0,   0,   0,   127, 127, 127, 0,   0,   0,   255, 255, 255,
752       0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,
753       0,   0,   0,   0,   0,   0,   0,   127, 127, 127, 0,   0,   0,   127,
754       127, 127, 0,   0,   0,   0,   0,   0,   127, 127, 127, 127, 127, 127,
755       0,   0,   0,   0,   0,   0,   127, 127, 127, 127, 127, 127, 0,   0,
756       0,   127, 127, 127, 127, 127, 127, 127, 127, 127};
757   const int ldb = k;
758   int b_offset = -127;
759   MatrixMap<const std::uint8_t, MapOrder::ColMajor> rhs(b_data, k, n, ldb);
760   const OffsetRowDup rhs_offset(b_offset, rhs.cols());
761 
762   // 2 x 9, columnwise.
763   const std::uint8_t expected_c_data[] = {255, 255, 0,   0,   127, 159,
764                                           0,   64,  0,   64,  127, 159,
765                                           127, 127, 127, 127, 127, 127};
766   const int ldc = m;
767   int c_offset[] = {97155, 97346};
768   int c_mult_int[] = {2741, 2741};
769   const int c_shift = 21;
770 
771   const int c_count = m * n;
772   std::unique_ptr<std::uint8_t[]> output_data(new std::uint8_t[c_count]);
773   MatrixMap<std::uint8_t, MapOrder::ColMajor> result(output_data.get(), m, n,
774                                                      ldc);
775   const OffsetColMap result_offset(c_offset, m);
776   const OffsetColMap result_mult_int(c_mult_int, m);
777   const int result_shift = c_shift;
778 
779   GemmContext gemm_context;
780   auto output_pipeline = MakeStandardOutputPipeline<VectorShape::Col>(
781       result_offset, result_mult_int, result_shift);
782   GemmWithOutputPipelinePC<std::uint8_t, std::uint8_t,
783                            DefaultL8R8BitDepthParams>(
784       &gemm_context, lhs, rhs, &result, lhs_offset, rhs_offset,
785       output_pipeline);
786 
787   ResultStats stats;
788   GetResultStats(output_data.get(), expected_c_data, c_count, &stats);
789 
790   ResultStatsBounds bounds;
791   const bool good = CheckResultStatsBounds(stats, bounds);
792   printf("TestWithSmallDataPerChannelQuantization: %s\n",
793          good ? "PASS" : "FAIL");
794   ReportResultStats(stats, bounds);
795   Check(good);
796 }
797 
798 // Runs a larger set of hand-picked data for per-channel quantized data.
799 // This test case comes from a set of 22 3x3 convolution filters run over a 5x5
800 // image.  Right now, I have 7 different filters and 15 copies of the first
801 // filter to make sure NEON code path that processes 16 rows at a time is
802 // covered.
TestWithLargeDataPerChannelQuantization()803 void TestWithLargeDataPerChannelQuantization() {
804   const int m = 22;
805   const int n = 25;
806   const int k = 27;
807 
808   // 27 x 22, column-wise.
809   const std::uint8_t a_data[] = {
810       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255,
811       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
812       0,   0,   0,   0,   0,   0,   127, 127, 127, 255, 255, 255, 127, 127, 127,
813       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   127, 127, 127,
814       0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,
815       127, 127, 127, 0,   0,   0,   51,  51,  51,  51,  51,  51,  51,  51,  51,
816       0,   0,   0,   255, 255, 255, 0,   0,   0,   51,  51,  51,  51,  51,  51,
817       51,  51,  51,  51,  51,  51,  0,   0,   0,   51,  51,  51,  51,  51,  51,
818       255, 255, 255, 51,  51,  51,  51,  51,  51,  0,   0,   0,   51,  51,  51,
819       0,   0,   0,   64,  64,  64,  0,   0,   0,   64,  64,  64,  255, 255, 255,
820       64,  64,  64,  0,   0,   0,   64,  64,  64,  0,   0,   0,   36,  36,  36,
821       0,   0,   0,   36,  36,  36,  0,   0,   0,   255, 255, 255, 0,   0,   0,
822       36,  36,  36,  0,   0,   0,   36,  36,  36,  0,   0,   0,   0,   0,   0,
823       0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,
824       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
825       0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,   0,   0,   0,
826       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
827       255, 255, 255, 0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
828       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255,
829       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
830       0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,
831       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
832       0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,
833       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
834       0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,   0,   0,   0,
835       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
836       255, 255, 255, 0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
837       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255,
838       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
839       0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,
840       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
841       0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,
842       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
843       0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,   0,   0,   0,
844       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
845       255, 255, 255, 0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
846       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255,
847       0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
848       0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,
849       0,   0,   0,   0,   0,   0,   0,   0,   0,
850   };
851   const int lda = k;
852   int a_offset[] = {0, 0, 0, -51, -51, 0, -36, 0, 0, 0, 0,
853                     0, 0, 0, 0,   0,   0, 0,   0, 0, 0, 0};
854   MatrixMap<const std::uint8_t, MapOrder::RowMajor> lhs(a_data, m, k, lda);
855   const OffsetColMap lhs_offset(a_offset, m);
856 
857   // 27 x 25, column-wise.
858   const std::uint8_t b_data[] = {
859       127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 119, 119,
860       119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127,
861       127, 127, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119,
862       119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127,
863       127, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119,
864       119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127,
865       127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119, 119,
866       119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127, 127,
867       127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127,
868       119, 119, 119, 119, 119, 119, 127, 127, 127, 127, 127, 127, 119, 119,
869       119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127,
870       127, 127, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
871       119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
872       119, 119, 119, 119, 136, 136, 136, 119, 119, 119, 119, 119, 119, 119,
873       119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
874       136, 136, 136, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
875       119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 136, 136, 136, 119,
876       119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127,
877       119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119,
878       119, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127,
879       127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119,
880       119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
881       119, 119, 119, 119, 136, 136, 136, 119, 119, 119, 119, 119, 119, 119,
882       119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
883       136, 136, 136, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
884       119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 136, 136, 136, 119,
885       119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
886       119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119,
887       119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 127,
888       127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119,
889       119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119, 119,
890       119, 119, 119, 119, 136, 136, 136, 119, 119, 119, 119, 119, 119, 119,
891       119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
892       136, 136, 136, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
893       119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 136, 136, 136, 119,
894       119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
895       119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
896       119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119,
897       119, 119, 119, 119, 119, 127, 127, 127, 127, 127, 127, 119, 119, 119,
898       119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127,
899       127, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119,
900       119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127,
901       127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119, 119,
902       119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127, 127,
903       127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119, 119, 119,
904       119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127, 127, 127,
905       127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119,
906       119, 119, 119, 119, 119, 127, 127, 127, 127, 127, 127, 127, 127, 127,
907       127, 127, 127};
908   const int ldb = k;
909   int b_offset = -127;
910   MatrixMap<const std::uint8_t, MapOrder::ColMajor> rhs(b_data, k, n, ldb);
911   const OffsetRowDup rhs_offset(b_offset, rhs.cols());
912 
913   // 22 x 25, column-wise.
914   const std::uint8_t expected_c_data[] = {
915       7,   37,  37,  67,  67,  39,  79,  7,   7,   7,   7,   7,   7,   7,   7,
916       7,   7,   7,   7,   7,   7,   7,   7,   7,   37,  87,  67,  23,  91,  7,
917       7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
918       7,   37,  87,  67,  23,  91,  7,   7,   7,   7,   7,   7,   7,   7,   7,
919       7,   7,   7,   7,   7,   7,   7,   7,   37,  87,  67,  23,  91,  7,   7,
920       7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   37,
921       37,  67,  67,  39,  79,  7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
922       7,   7,   7,   7,   7,   7,   37,  7,   67,  87,  23,  91,  7,   7,   7,
923       7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
924       87,  87,  7,   103, 7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
925       7,   7,   7,   7,   7,   7,   71,  87,  45,  41,  77,  7,   7,   7,   7,
926       7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   87,
927       87,  7,   103, 7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
928       7,   7,   7,   7,   37,  7,   67,  87,  23,  91,  7,   7,   7,   7,   7,
929       7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   37,  7,   67,  87,
930       23,  91,  7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
931       7,   7,   7,   71,  7,   45,  87,  41,  77,  7,   7,   7,   7,   7,   7,
932       7,   7,   7,   7,   7,   7,   7,   7,   7,   255, 135, 135, 255, 255, 143,
933       255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
934       255, 7,   71,  7,   45,  87,  41,  77,  7,   7,   7,   7,   7,   7,   7,
935       7,   7,   7,   7,   7,   7,   7,   7,   7,   37,  7,   67,  87,  23,  91,
936       7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
937       7,   37,  7,   67,  87,  23,  91,  7,   7,   7,   7,   7,   7,   7,   7,
938       7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   87,  87,  7,   103, 7,
939       7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
940       7,   71,  87,  45,  41,  77,  7,   7,   7,   7,   7,   7,   7,   7,   7,
941       7,   7,   7,   7,   7,   7,   7,   7,   7,   87,  87,  7,   103, 7,   7,
942       7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   37,
943       7,   67,  87,  23,  91,  7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
944       7,   7,   7,   7,   7,   7,   37,  37,  67,  67,  39,  79,  7,   7,   7,
945       7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   37,
946       87,  67,  23,  91,  7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
947       7,   7,   7,   7,   7,   7,   37,  87,  67,  23,  91,  7,   7,   7,   7,
948       7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   37,  87,
949       67,  23,  91,  7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
950       7,   7,   7,   7,   37,  37,  67,  67,  39,  79,  7,   7,   7,   7,   7,
951       7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   99,  99,  99,  99,  99,
952       99,  99,  99,  99,  99,  99,  99,  99,  99,  99,  99,  99,  99,  99,  99,
953       99,  99,  111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111,
954       111, 111, 111, 111, 111, 111, 111, 111, 111,
955   };
956   const int ldc = m;
957   int c_offset[] = {
958       6477, 12954, 12954, 7793, 7793, 12954, 9282, 6477, 6477, 6477, 6477,
959       6477, 6477,  6477,  6477, 6477, 6477,  6477, 6477, 6477, 6477, 6477,
960   };
961   int c_mult_int[] = {
962       41121, 20560, 20560, 34267, 34267, 21937, 28784, 41121,
963       41121, 41121, 41121, 41121, 41121, 41121, 41121, 41121,
964       41121, 41121, 41121, 41121, 41121, 41121,
965   };
966   const int c_shift = 21;
967 
968   const int c_count = m * n;
969   std::unique_ptr<std::uint8_t[]> output_data(new std::uint8_t[c_count]);
970   MatrixMap<std::uint8_t, MapOrder::ColMajor> result(output_data.get(), m, n,
971                                                      ldc);
972   const OffsetColMap result_offset(c_offset, m);
973   const OffsetColMap result_mult_int(c_mult_int, m);
974   const int result_shift = c_shift;
975 
976   GemmContext gemm_context;
977   auto output_pipeline = MakeStandardOutputPipeline<VectorShape::Col>(
978       result_offset, result_mult_int, result_shift);
979   GemmWithOutputPipelinePC<std::uint8_t, std::uint8_t,
980                            DefaultL8R8BitDepthParams>(
981       &gemm_context, lhs, rhs, &result, lhs_offset, rhs_offset,
982       output_pipeline);
983 
984   ResultStats stats;
985   GetResultStats(output_data.get(), expected_c_data, c_count, &stats);
986 
987   ResultStatsBounds bounds;
988   const bool good = CheckResultStatsBounds(stats, bounds);
989   printf("TestWithLargeDataPerChannelQuantization: %s\n",
990          good ? "PASS" : "FAIL");
991   ReportResultStats(stats, bounds);
992   Check(good);
993 }
994 
995 // Multithreading only activates when the result has more than 16 rows, and also
996 // (result rows) * (result cols) * depth >= 2 x 65 x 1024.  Size was selected
997 // to run in 3 threads.
998 //
999 // Based on the following floating point data:
1000 //   LHS: all zeros except 10.0, 20.0 at the beginning of first 16 rows;
1001 //     1.0, 2.0 at the beginning of next 16 rows; 0.1, 0.2 in next 16 rows;
1002 //     0.01, 0.02 in last 16 rows.
1003 //   RHS: all zeros except 1.0 in (0, 0) and 2.0 in (1, 0).
1004 //   Varying boundaries were used for each 16 rows block of LHS, to test for
1005 //     correct indexing into offsets.
1006 //   Expected result: all zeros, except 50.0 at the beginning of first 16 rows;
1007 //     5.0 at the beginning of next 16 rows; 0.5 in next 16 rows; 0.05 in last
1008 //     16 rows.
TestMultithreadedPerChannelQuantization()1009 void TestMultithreadedPerChannelQuantization() {
1010   const int m = 64;
1011   const int n = 20;
1012   const int k = 160;
1013 
1014   // LHS, m x k.
1015   const std::array<std::int32_t, 4> lhs_offsets_terse{{
1016       0, -51, -85, -109,
1017   }};
1018   assert(lhs_offsets_terse.size() * 16 == m);
1019   const std::array<std::uint8_t, 4> lhs_first_el{{
1020       128, 153, 170, 182,
1021   }};
1022   assert(lhs_first_el.size() * 16 == m);
1023 
1024   // lhs_first_el at (i, 0) and 255 at (i, 1), other values are all -offset.
1025   std::vector<std::uint8_t> a_data(m * k, 0);
1026   for (int i = 0; i < m; ++i) {
1027     a_data[i * k] = lhs_first_el[i / 16];
1028     a_data[i * k + 1] = 255;
1029     for (int j = 2; j < k; ++j) {
1030       a_data[i * k + j] = std::uint8_t(-lhs_offsets_terse[i / 16]);
1031     }
1032   }
1033 
1034   const int lda = k;
1035   // Given values at [i / 16].
1036   std::vector<std::int32_t> a_offset(m, 0);
1037   for (int i = 0; i < m; ++i) {
1038     a_offset[i] = lhs_offsets_terse[i / 16];
1039   }
1040 
1041   MatrixMap<const std::uint8_t, MapOrder::RowMajor> lhs(&a_data[0], m, k, lda);
1042   const OffsetColMap lhs_offset(&a_offset[0], m);
1043 
1044   // RHS, k x n.
1045   // All zeros, except 128 at (0, 0) and 255 at (1, 0).
1046   std::vector<std::uint8_t> b_data(k * n, 0);
1047   b_data[0] = 128;
1048   b_data[1] = 255;
1049 
1050   const int ldb = k;
1051   std::int32_t b_offset = 0;
1052   MatrixMap<const std::uint8_t, MapOrder::ColMajor> rhs(&b_data[0], k, n, ldb);
1053   const OffsetRowDup rhs_offset(b_offset, rhs.cols());
1054 
1055   // Result, m x n.
1056   // All zeros, except given values at (i / 16, 0).
1057   const std::array<std::uint8_t, 4> expected_c_terse{{
1058       142, 159, 182, 213,
1059   }};
1060   assert(expected_c_terse.size() * 16 == m);
1061   std::vector<std::uint8_t> expected_c_data(m * n, 0);
1062   for (int i = 0; i < m; ++i) {
1063     expected_c_data[i] = expected_c_terse[i / 16];
1064   }
1065 
1066   const int ldc = m;
1067   // All zeros.
1068   std::vector<std::int32_t> c_offset(m, 0);
1069   // Given values at [i / 16].
1070   const std::array<std::int32_t, 4> c_mult_int_terse{{
1071       3655, 5140, 7049, 9595,
1072   }};
1073   assert(c_mult_int_terse.size() * 16 == m);
1074   std::vector<std::int32_t> c_mult_int(m);
1075   for (int i = 0; i < m; ++i) {
1076     c_mult_int[i] = c_mult_int_terse[i / 16];
1077   }
1078 
1079   const int c_shift = 21;
1080 
1081   const int c_count = m * n;
1082   std::unique_ptr<std::uint8_t[]> output_data(new std::uint8_t[c_count]);
1083   MatrixMap<std::uint8_t, MapOrder::ColMajor> result(output_data.get(), m, n,
1084                                                      ldc);
1085   const OffsetColMap result_offset(&c_offset[0], m);
1086   const OffsetColMap result_mult_int(&c_mult_int[0], m);
1087   const int result_shift = c_shift;
1088 
1089   GemmContext gemm_context;
1090   auto output_pipeline = MakeStandardOutputPipeline<VectorShape::Col>(
1091       result_offset, result_mult_int, result_shift);
1092   GemmWithOutputPipelinePC<std::uint8_t, std::uint8_t,
1093                            DefaultL8R8BitDepthParams>(
1094       &gemm_context, lhs, rhs, &result, lhs_offset, rhs_offset,
1095       output_pipeline);
1096 
1097   ResultStats stats;
1098   GetResultStats(output_data.get(), &expected_c_data[0], c_count, &stats);
1099 
1100   ResultStatsBounds bounds;
1101   const bool good = CheckResultStatsBounds(stats, bounds);
1102   printf("TestMultithreadedPerChannelQuantization: %s\n",
1103          good ? "PASS" : "FAIL");
1104   ReportResultStats(stats, bounds);
1105   Check(good);
1106 }
1107 
1108 // Runs a small set of hand-calculated data through the implementation.
TestWithSmallData()1109 void TestWithSmallData() {
1110   const int m = 4;
1111   const int n = 2;
1112   const int k = 3;
1113   // Matrix A (LHS) is:
1114   // |  7 | 10 | 13 | 16 |
1115   // |  8 | 11 | 14 | 17 |
1116   // |  9 | 12 | 15 | 18 |
1117   const std::uint8_t a_data[] = {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18};
1118   // Matrix B (RHS) is:
1119   // |  1 |  3 |  5 |
1120   // |  2 |  4 |  6 |
1121   const std::uint8_t b_data[] = {1, 2, 3, 4, 5, 6};
1122   // Here are the results we expect, from hand calculations:
1123   // (1 * 7) + (3 * 8) + (5 * 9) = 76
1124   // (2 * 7) + (4 * 8) + (6 * 9) = 100
1125   // (1 * 10) + (3 * 11) + (5 * 12) = 103
1126   // (2 * 10) + (4 * 11) + (6 * 12) = 136
1127   // (1 * 13) + (3 * 14) + (5 * 15) = 130
1128   // (2 * 13) + (4 * 14) + (6 * 15) = 172
1129   // (1 * 16) + (3 * 17) + (5 * 18) = 157
1130   // (2 * 16) + (4 * 17) + (6 * 18) = 208
1131   // That means matrix C should be:
1132   // |  76 | 103 | 130 | 157 |
1133   // | 100 | 136 | 172 | 208 |
1134   const std::uint8_t expected_data[] = {76, 100, 103, 136, 130, 172, 157, 208};
1135 
1136   const int c_count = m * n;
1137   std::unique_ptr<std::uint8_t[]> output_data(new std::uint8_t[c_count]);
1138 
1139   const bool is_a_transposed = true;
1140   const bool is_b_transposed = true;
1141   const bool is_c_transposed = true;
1142   const int lda = k;
1143   const int ldb = n;
1144   const int ldc = n;
1145 
1146   const int a_offset = 0;
1147   const int b_offset = 0;
1148   const int c_offset = 0;
1149   const int c_mult = 1;
1150   const int c_shift = 0;
1151 
1152   gemmlowp::eight_bit_int_gemm::EightBitIntGemm(
1153       is_a_transposed, is_b_transposed, is_c_transposed, m, n, k, a_data,
1154       a_offset, lda, b_data, b_offset, ldb, output_data.get(), c_offset, c_mult,
1155       c_shift, ldc, eight_bit_int_gemm::BitDepthSetting::A8B8);
1156 
1157   ResultStats stats;
1158   GetResultStats(output_data.get(), expected_data, c_count, &stats);
1159 
1160   ResultStatsBounds bounds;
1161   const bool good = CheckResultStatsBounds(stats, bounds);
1162   printf("TestWithSmallData: %s\n", good ? "PASS" : "FAIL");
1163   ReportResultStats(stats, bounds);
1164   Check(good);
1165 }
1166 
1167 // This is the most realistic test of how we'll be using the low-precision GEMM
1168 // function in applications. It takes in large input matrices that have been
1169 // captured from an actual neural network run.
TestWithRealData(eight_bit_int_gemm::BitDepthSetting BitDepth,int tolerance_median,int tolerance_max)1170 void TestWithRealData(eight_bit_int_gemm::BitDepthSetting BitDepth,
1171                       int tolerance_median, int tolerance_max) {
1172   std::unique_ptr<std::uint8_t[]> output_data(
1173       new std::uint8_t[test_data::c_count]);
1174   gemmlowp::eight_bit_int_gemm::EightBitIntGemm(
1175       test_data::is_a_transposed, test_data::is_b_transposed,
1176       test_data::is_c_transposed, test_data::m, test_data::n, test_data::k,
1177       test_data::a_data, test_data::a_offset, test_data::k, test_data::b_data,
1178       test_data::b_offset, test_data::k, output_data.get(), test_data::c_offset,
1179       test_data::c_mult_int, test_data::c_shift, test_data::m, BitDepth);
1180 
1181   ResultStats stats;
1182   GetResultStats(output_data.get(), test_data::expected_c_data,
1183                  test_data::c_count, &stats);
1184 
1185   ResultStatsBounds bounds;
1186   if (BitDepth == eight_bit_int_gemm::BitDepthSetting::A5B7) {
1187     bounds.med_unsigned_diff = tolerance_median;
1188     bounds.max_unsigned_diff = tolerance_max;
1189     bounds.med_signed_diff = 0;
1190     bounds.mean_signed_diff = 0.2f;
1191   }
1192 
1193   const bool good = CheckResultStatsBounds(stats, bounds);
1194   printf("TestWithRealData: %s with %s\n", good ? "PASS" : "FAIL",
1195          GetBitDepthName(BitDepth));
1196   ReportResultStats(stats, bounds);
1197   Check(good);
1198 }
1199 
1200 template <typename BitDepthParams, MapOrder ResultOrder>
TestOutputStages(int rows,int depth,int cols,int result_offset,int result_mult_int,int result_shift)1201 void TestOutputStages(int rows, int depth, int cols, int result_offset,
1202                       int result_mult_int, int result_shift) {
1203   Matrix<std::uint8_t, MapOrder::RowMajor> lhs(rows, depth);
1204   Matrix<std::uint8_t, MapOrder::ColMajor> rhs(depth, cols);
1205   Matrix<std::int32_t, ResultOrder> result_raw_int32(rows, cols);
1206   MakeRandom<typename BitDepthParams::LhsRange>(&lhs);
1207   MakeRandom<typename BitDepthParams::RhsRange>(&rhs);
1208   const int lhs_offset = 12;
1209   const int rhs_offset = -34;
1210 
1211   // Test an empty pipeline, i.e. returning raw int32 accumulators.
1212   auto empty_pipeline = std::make_tuple();
1213   GemmContext context;
1214   GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1215       &context, lhs.const_map(), rhs.const_map(), &result_raw_int32, lhs_offset,
1216       rhs_offset, empty_pipeline);
1217 
1218   for (int r = 0; r < rows; r++) {
1219     for (int c = 0; c < cols; c++) {
1220       std::int32_t expected = 0;
1221       for (int d = 0; d < depth; d++) {
1222         std::int32_t lhs_val =
1223             static_cast<std::int32_t>(lhs(r, d)) + lhs_offset;
1224         std::int32_t rhs_val =
1225             static_cast<std::int32_t>(rhs(d, c)) + rhs_offset;
1226         expected += lhs_val * rhs_val;
1227       }
1228       Check(expected == result_raw_int32(r, c));
1229     }
1230   }
1231 
1232   // Test a pipeline with only the quantize-down stage, still returning
1233   // unclamped (but scaled) int32's
1234   OutputStageQuantizeDownInt32ToUint8Scale quantize_down_stage;
1235   quantize_down_stage.result_offset = result_offset;
1236   quantize_down_stage.result_mult_int = result_mult_int;
1237   quantize_down_stage.result_shift = result_shift;
1238   auto quantize_down_pipeline = std::make_tuple(quantize_down_stage);
1239   Matrix<std::int32_t, ResultOrder> result_quantized_down_int32(rows, cols);
1240   GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1241       &context, lhs.const_map(), rhs.const_map(), &result_quantized_down_int32,
1242       lhs_offset, rhs_offset, quantize_down_pipeline);
1243 
1244   std::int64_t sum = 0;
1245   for (int r = 0; r < rows; r++) {
1246     for (int c = 0; c < cols; c++) {
1247       std::int32_t raw = result_raw_int32(r, c);
1248       std::int32_t expected = RoundingDivideByPOT(
1249           (raw + result_offset) * result_mult_int, result_shift);
1250       Check(expected == result_quantized_down_int32(r, c));
1251       sum += expected;
1252     }
1253   }
1254   std::int64_t avg = sum / (rows * cols);
1255   // Test that the average quantized-down value falls reasonably in the
1256   // middle of the [0..255] range. Otherwise, the multiplier / shift need to be
1257   // adjusted.
1258   Check(avg >= 64 && avg <= 192);
1259 
1260   // Test the familiar default pipeline consisting of quantize-down and
1261   // clamp-and-cast-to-uint8.
1262   OutputStageSaturatingCastToUint8 saturating_cast_stage;
1263   auto quantize_down_and_saturating_cast_pipeline =
1264       std::make_tuple(quantize_down_stage, saturating_cast_stage);
1265   Matrix<std::uint8_t, ResultOrder> result_quantized_down_saturated_uint8(rows,
1266                                                                           cols);
1267   GemmWithOutputPipeline<std::uint8_t, std::uint8_t, DefaultL8R8BitDepthParams>(
1268       &context, lhs.const_map(), rhs.const_map(),
1269       &result_quantized_down_saturated_uint8, lhs_offset, rhs_offset,
1270       quantize_down_and_saturating_cast_pipeline);
1271 
1272   for (int r = 0; r < rows; r++) {
1273     for (int c = 0; c < cols; c++) {
1274       std::int32_t quantized = result_quantized_down_int32(r, c);
1275       std::uint8_t expected = std::min(std::max(quantized, 0), 255);
1276       Check(expected == result_quantized_down_saturated_uint8(r, c));
1277     }
1278   }
1279 
1280   // Test a bias-addition with row-vector
1281   std::vector<std::int32_t> row_vector_data(cols);
1282   std::uniform_int_distribution<std::int32_t> uniform_minus_500_plus_500(-500,
1283                                                                          500);
1284   for (int i = 0; i < cols; i++) {
1285     row_vector_data[i] = uniform_minus_500_plus_500(RandomEngine());
1286   }
1287   typedef VectorMap<std::int32_t, VectorShape::Row> RowVectorMap;
1288   RowVectorMap row_vector_map(row_vector_data.data(), cols);
1289   OutputStageBiasAddition<RowVectorMap> row_bias_addition_stage;
1290   row_bias_addition_stage.bias_vector = row_vector_map;
1291   auto row_bias_addition_pipeline = std::make_tuple(row_bias_addition_stage);
1292   Matrix<std::int32_t, ResultOrder> result_of_row_bias_addition(rows, cols);
1293   GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1294       &context, lhs.const_map(), rhs.const_map(), &result_of_row_bias_addition,
1295       lhs_offset, rhs_offset, row_bias_addition_pipeline);
1296   for (int r = 0; r < rows; r++) {
1297     for (int c = 0; c < cols; c++) {
1298       std::int32_t expected = result_raw_int32(r, c) + row_vector_data[c];
1299       Check(expected == result_of_row_bias_addition(r, c));
1300     }
1301   }
1302 
1303   // Test a bias-addition with column-vector
1304   std::vector<std::int32_t> col_vector_data(rows);
1305   for (int i = 0; i < rows; i++) {
1306     col_vector_data[i] = uniform_minus_500_plus_500(RandomEngine());
1307   }
1308   typedef VectorMap<std::int32_t, VectorShape::Col> ColVectorMap;
1309   ColVectorMap col_vector_map(col_vector_data.data(), rows);
1310   OutputStageBiasAddition<ColVectorMap> col_bias_addition_stage;
1311   col_bias_addition_stage.bias_vector = col_vector_map;
1312   auto col_bias_addition_pipeline = std::make_tuple(col_bias_addition_stage);
1313   Matrix<std::int32_t, ResultOrder> result_of_col_bias_addition(rows, cols);
1314   GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1315       &context, lhs.const_map(), rhs.const_map(), &result_of_col_bias_addition,
1316       lhs_offset, rhs_offset, col_bias_addition_pipeline);
1317   for (int r = 0; r < rows; r++) {
1318     for (int c = 0; c < cols; c++) {
1319       std::int32_t expected = result_raw_int32(r, c) + col_vector_data[r];
1320       Check(expected == result_of_col_bias_addition(r, c));
1321     }
1322   }
1323 
1324   // Test a clamp
1325   OutputStageClamp clamp_stage;
1326   // Determine min and max of raw int32 accumulators
1327   std::int32_t raw_min = std::numeric_limits<std::int32_t>::max();
1328   std::int32_t raw_max = std::numeric_limits<std::int32_t>::min();
1329   for (int r = 0; r < rows; r++) {
1330     for (int c = 0; c < cols; c++) {
1331       raw_min = std::min(raw_min, result_raw_int32(r, c));
1332       raw_max = std::max(raw_max, result_raw_int32(r, c));
1333     }
1334   }
1335   // Pick some interesting clamp min/max bounds
1336   clamp_stage.min = static_cast<std::int32_t>(raw_min * 0.7 + raw_max * 0.3);
1337   clamp_stage.max = static_cast<std::int32_t>(raw_min * 0.3 + raw_max * 0.7);
1338   assert(raw_min <= clamp_stage.min && clamp_stage.min <= clamp_stage.max &&
1339          clamp_stage.max <= raw_max);
1340   auto clamp_pipeline = std::make_tuple(clamp_stage);
1341   Matrix<std::int32_t, ResultOrder> result_clamped(rows, cols);
1342   GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1343       &context, lhs.const_map(), rhs.const_map(), &result_clamped, lhs_offset,
1344       rhs_offset, clamp_pipeline);
1345   for (int r = 0; r < rows; r++) {
1346     for (int c = 0; c < cols; c++) {
1347       std::int32_t raw = result_raw_int32(r, c);
1348       std::int32_t expected =
1349           std::min(std::max(raw, clamp_stage.min), clamp_stage.max);
1350       Check(expected == result_clamped(r, c));
1351     }
1352   }
1353 
1354   // Test tanh
1355   OutputStageTanh tanh_stage;
1356   const std::int32_t real_zero_as_int32 = (raw_max + raw_min) / 2;
1357   const std::int32_t real_amplitude_as_int32 = (raw_max - raw_min) / 16;
1358   tanh_stage.real_zero_as_int32 = real_zero_as_int32;
1359   tanh_stage.real_amplitude_as_int32 = real_amplitude_as_int32;
1360   auto tanh_pipeline = std::make_tuple(tanh_stage);
1361   Matrix<std::int32_t, ResultOrder> result_tanh(rows, cols);
1362   GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1363       &context, lhs.const_map(), rhs.const_map(), &result_tanh, lhs_offset,
1364       rhs_offset, tanh_pipeline);
1365   for (int r = 0; r < rows; r++) {
1366     for (int c = 0; c < cols; c++) {
1367       std::int32_t raw = result_raw_int32(r, c);
1368       double real_input =
1369           double(raw - real_zero_as_int32) / real_amplitude_as_int32;
1370       double expected = std::tanh(real_input);
1371       std::int32_t actual_int32 = result_tanh(r, c);
1372       double actual =
1373           double(actual_int32 - real_zero_as_int32) / real_amplitude_as_int32;
1374       Check(std::abs(expected - actual) < 2e-4);
1375     }
1376   }
1377 
1378   // Test a pipeline with bias and clamp
1379   auto bias_clamp_pipeline =
1380       std::make_tuple(col_bias_addition_stage, clamp_stage);
1381   Matrix<std::int32_t, ResultOrder> result_biased_clamped(rows, cols);
1382   GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1383       &context, lhs.const_map(), rhs.const_map(), &result_biased_clamped,
1384       lhs_offset, rhs_offset, bias_clamp_pipeline);
1385   for (int r = 0; r < rows; r++) {
1386     for (int c = 0; c < cols; c++) {
1387       std::int32_t raw = result_raw_int32(r, c);
1388       std::int32_t biased = raw + col_vector_data[r];
1389       std::int32_t expected =
1390           std::min(std::max(biased, clamp_stage.min), clamp_stage.max);
1391       Check(expected == result_biased_clamped(r, c));
1392     }
1393   }
1394 
1395   // Test a full pipeline with bias and clamp and quantization down to 8bit
1396   // result
1397   auto bias_clamp_quantize_cast_pipeline =
1398       std::make_tuple(col_bias_addition_stage, clamp_stage, quantize_down_stage,
1399                       saturating_cast_stage);
1400   Matrix<std::uint8_t, ResultOrder> result_biased_clamped_quantized_casted(
1401       rows, cols);
1402   GemmWithOutputPipeline<std::uint8_t, std::uint8_t, DefaultL8R8BitDepthParams>(
1403       &context, lhs.const_map(), rhs.const_map(),
1404       &result_biased_clamped_quantized_casted, lhs_offset, rhs_offset,
1405       bias_clamp_quantize_cast_pipeline);
1406   for (int r = 0; r < rows; r++) {
1407     for (int c = 0; c < cols; c++) {
1408       std::int32_t quantized = RoundingDivideByPOT(
1409           (result_biased_clamped(r, c) + result_offset) * result_mult_int,
1410           result_shift);
1411       std::uint8_t expected = std::min(std::max(quantized, 0), 255);
1412       Check(expected == result_biased_clamped_quantized_casted(r, c));
1413     }
1414   }
1415 
1416   // Test a pipeline with the fixed-point-multiplier variant stage for the
1417   // quantizing down of 32bit accumulators.
1418   //
1419   // First, figure appropriate fixedpoint multiplier and shift values.
1420   std::int32_t result_fixedpoint_multiplier = result_mult_int;
1421   std::int32_t result_fixedpoint_shift = result_shift;
1422   Check(result_mult_int > 0);
1423   Check(result_shift > 0);
1424   result_fixedpoint_multiplier = result_mult_int;
1425   result_fixedpoint_shift = result_shift - 31;
1426   while (result_fixedpoint_multiplier < (1 << 30)) {
1427     result_fixedpoint_multiplier <<= 1;
1428     result_fixedpoint_shift++;
1429   }
1430   Check(result_fixedpoint_shift >= 0);
1431   // Now test OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint
1432   OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint
1433       quantize_down_by_fixedpoint_stage;
1434   quantize_down_by_fixedpoint_stage.result_offset_after_shift =
1435       static_cast<std::int32_t>(
1436           round(static_cast<double>(result_offset * result_mult_int) /
1437                 (1 << result_shift)));
1438   quantize_down_by_fixedpoint_stage.result_fixedpoint_multiplier =
1439       result_fixedpoint_multiplier;
1440   quantize_down_by_fixedpoint_stage.result_shift = result_fixedpoint_shift;
1441   auto quantize_down_by_fixedpoint_pipeline =
1442       std::make_tuple(quantize_down_by_fixedpoint_stage);
1443   Matrix<std::int32_t, ResultOrder> result_quantized_down_by_fixedpoint_int32(
1444       rows, cols);
1445   GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1446       &context, lhs.const_map(), rhs.const_map(),
1447       &result_quantized_down_by_fixedpoint_int32, lhs_offset, rhs_offset,
1448       quantize_down_by_fixedpoint_pipeline);
1449 
1450   std::vector<std::int32_t> diffs_caused_by_fixedpoint;
1451   for (int r = 0; r < rows; r++) {
1452     for (int c = 0; c < cols; c++) {
1453       const std::int32_t actual =
1454           result_quantized_down_by_fixedpoint_int32(r, c);
1455       const std::int32_t raw = result_raw_int32(r, c);
1456       const std::int32_t expected =
1457           quantize_down_by_fixedpoint_stage.result_offset_after_shift +
1458           RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
1459                                   raw, result_fixedpoint_multiplier),
1460                               result_fixedpoint_shift);
1461       Check(actual == expected);
1462     }
1463   }
1464 
1465   // Test the variant of the familiar default pipeline consisting of
1466   // quantize-down and
1467   // clamp-and-cast-to-uint8, where we used fixedpoint multipliers for the
1468   // downscaling.
1469   auto quantize_down_by_fixedpoint_and_saturating_cast_pipeline =
1470       std::make_tuple(quantize_down_by_fixedpoint_stage, saturating_cast_stage);
1471   Matrix<std::uint8_t, ResultOrder>
1472       result_quantized_down_by_fixedpoint_saturated_uint8(rows, cols);
1473   GemmWithOutputPipeline<std::uint8_t, std::uint8_t, DefaultL8R8BitDepthParams>(
1474       &context, lhs.const_map(), rhs.const_map(),
1475       &result_quantized_down_by_fixedpoint_saturated_uint8, lhs_offset,
1476       rhs_offset, quantize_down_by_fixedpoint_and_saturating_cast_pipeline);
1477 
1478   for (int r = 0; r < rows; r++) {
1479     for (int c = 0; c < cols; c++) {
1480       std::int32_t quantized = result_quantized_down_by_fixedpoint_int32(r, c);
1481       std::uint8_t expected = std::min(std::max(quantized, 0), 255);
1482       Check(expected ==
1483             result_quantized_down_by_fixedpoint_saturated_uint8(r, c));
1484     }
1485   }
1486 
1487   printf("TestOutputStages: PASS with ResultOrder=%s\n",
1488          OrderName(ResultOrder));
1489 }
1490 
1491 #ifndef GEMMLOWP_SKIP_EXHAUSTIVE_TESTS
1492 template <typename BitDepthParams>
TestExhaustively()1493 void TestExhaustively() {
1494   GemmContext context;
1495 
1496   // Test the internal GEMM interfaces
1497   test_gemm<
1498       SingleThreadGemmWrapper<DefaultKernel<BitDepthParams>,
1499                               std::uint8_t, BitDepthParams>>(&context);
1500 
1501   test_gemm<
1502       MultiThreadGemmWrapper<DefaultKernel<BitDepthParams>,
1503                              std::uint8_t, BitDepthParams>>(&context);
1504 
1505   // Test the public GEMM interfaces
1506   test_gemm<PublicGemmWrapper<std::uint8_t, BitDepthParams>>(&context);
1507 
1508   // Test GEMV cases (internal interfaces)
1509   test_gemv<
1510       SingleThreadGemmWrapper<DefaultKernel<BitDepthParams>,
1511                               std::uint8_t, BitDepthParams>>(&context);
1512 
1513   test_gemv<
1514       MultiThreadGemmWrapper<DefaultKernel<BitDepthParams>,
1515                              std::uint8_t, BitDepthParams>>(&context);
1516 
1517   // Test GEMV cases (public interfaces)
1518   test_gemv<PublicGemmWrapper<std::uint8_t, BitDepthParams>>(&context);
1519 }
1520 
1521 template <eight_bit_int_gemm::BitDepthSetting BitDepthSetting>
TestExhaustivelyEightBitIntGemm()1522 void TestExhaustivelyEightBitIntGemm() {
1523   GemmContext context;
1524   test_gemv<EightBitIntGemmWrapper<std::uint8_t, BitDepthSetting>>(&context);
1525   test_gemv<EightBitIntGemmWrapper<std::uint8_t, BitDepthSetting>>(&context);
1526   test_gemm<EightBitIntGemmWrapper<std::uint8_t, BitDepthSetting>>(&context);
1527 }
1528 
TestKernels()1529 void TestKernels() {
1530   GemmContext context;
1531 
1532   // Test specific kernels with various different formats,
1533   // to exercises corner cases especially in the packing code.
1534   test_gemm_kernel<
1535       ReferenceKernel<KernelFormat<KernelSideFormat<CellFormat<1, 1>, 1>,
1536                                    KernelSideFormat<CellFormat<1, 1>, 1>>>>(
1537       &context);
1538 
1539   test_gemm_kernel<
1540       ReferenceKernel<KernelFormat<KernelSideFormat<CellFormat<4, 2>, 1>,
1541                                    KernelSideFormat<CellFormat<4, 2>, 2>>>>(
1542       &context);
1543 
1544   test_gemm_kernel<
1545       ReferenceKernel<KernelFormat<KernelSideFormat<CellFormat<4, 2>, 4>,
1546                                    KernelSideFormat<CellFormat<4, 2>, 5>>>>(
1547       &context);
1548 
1549   test_gemm_kernel<ReferenceKernel<KernelFormat<
1550       KernelSideFormat<CellFormat<3, 4, CellOrder::DepthMajor>, 2>,
1551       KernelSideFormat<CellFormat<5, 4, CellOrder::DepthMajor>, 3>>>>(&context);
1552 
1553   test_gemm_kernel<ReferenceKernel<KernelFormat<
1554       KernelSideFormat<CellFormat<3, 4, CellOrder::WidthMajor>, 2>,
1555       KernelSideFormat<CellFormat<5, 4, CellOrder::WidthMajor>, 3>>>>(&context);
1556 
1557   test_gemm_kernel<ReferenceKernel<KernelFormat<
1558       KernelSideFormat<CellFormat<5, 2, CellOrder::WidthMajor>, 3>,
1559       KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2>>>>(&context);
1560 
1561   test_gemm_kernel<ReferenceKernel<KernelFormat<
1562       KernelSideFormat<CellFormat<5, 2, CellOrder::DepthMajor>, 3>,
1563       KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 2>>>>(&context);
1564 
1565   test_gemm_kernel<ReferenceKernel<KernelFormat<
1566       KernelSideFormat<CellFormat<8, 8, CellOrder::Diagonal>, 2>,
1567       KernelSideFormat<CellFormat<3, 8, CellOrder::WidthMajor>, 1>>>>(&context);
1568 
1569   test_gemm_kernel<ReferenceKernel<KernelFormat<
1570       KernelSideFormat<CellFormat<1, 4, CellOrder::DepthMajor>, 1>,
1571       KernelSideFormat<CellFormat<4, 4, CellOrder::Diagonal>, 1>>>>(&context);
1572 }
1573 
1574 #endif  // not GEMMLOWP_SKIP_EXHAUSTIVE_TESTS
1575 
1576 template <typename BitDepthParams>
TestOutputStages()1577 void TestOutputStages() {
1578   // Test non-default output pipelines with various combinations of
1579   // output stages.
1580   TestOutputStages<BitDepthParams, MapOrder::RowMajor>(63, 10, 127, 5, 17, 14);
1581   TestOutputStages<BitDepthParams, MapOrder::ColMajor>(63, 10, 127, 5, 17, 14);
1582   TestOutputStages<BitDepthParams, MapOrder::RowMajor>(630, 10, 1270, 5, 17,
1583                                                        14);
1584   TestOutputStages<BitDepthParams, MapOrder::ColMajor>(630, 10, 1270, 5, 17,
1585                                                        14);
1586 }
1587 
test()1588 void test() {
1589 #ifdef GEMMLOWP_TEST_PROFILE
1590   RegisterCurrentThreadForProfiling();
1591   StartProfiling();
1592 #endif
1593 
1594   // Run a first quick test against hand-calculated data.
1595   TestWithSmallData();
1596 
1597 #ifndef GEMMLOWP_SKIP_EXHAUSTIVE_TESTS
1598   TestExhaustively<DefaultL8R8BitDepthParams>();
1599   TestExhaustively<L8R8WithLhsNonzeroBitDepthParams>();
1600   TestExhaustively<DefaultL7R5BitDepthParams>();  // legacy, same as L8R8
1601   TestExhaustivelyEightBitIntGemm<eight_bit_int_gemm::BitDepthSetting::A8B8>();
1602   TestExhaustivelyEightBitIntGemm<eight_bit_int_gemm::BitDepthSetting::A5B7>();
1603   TestKernels();
1604 #endif
1605 
1606   // Run against actual data from a network evaluation.
1607   TestWithRealData(eight_bit_int_gemm::BitDepthSetting::A8B8, 0, 0);
1608   TestWithRealData(eight_bit_int_gemm::BitDepthSetting::A5B7, 2, 10);
1609 
1610   // Test non-default output pipelines with various combinations of
1611   // output stages.
1612   TestOutputStages<DefaultL8R8BitDepthParams>();
1613   TestOutputStages<L8R8WithLhsNonzeroBitDepthParams>();
1614 
1615   // Test per channel quantization.
1616   TestWithSmallDataPerChannelQuantization();
1617   TestWithLargeDataPerChannelQuantization();
1618   TestMultithreadedPerChannelQuantization();
1619 #ifdef GEMMLOWP_TEST_PROFILE
1620   FinishProfiling();
1621 #endif
1622 
1623   std::cerr << "All tests passed." << std::endl;
1624 
1625   // We have been testing the eight_bit_int_gemm, so we should free its
1626   // persistent
1627   // resources now to avoid having leak-checking tools report leaks.
1628   eight_bit_int_gemm::FreePersistentResources();
1629 }
1630 
1631 }  // end namespace gemmlowp
1632 
1633 // For iOS, we need to define our own main(), so skip it here.
1634 #if !(defined(__APPLE__) && (TARGET_OS_IPHONE || TARGET_IPHONE_SIMULATOR))
main()1635 int main() { gemmlowp::test(); }
1636 #endif
1637