// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "test.h"

#include <unistd.h>
#include <array>
#include <cstdint>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#ifdef __APPLE__
#include <TargetConditionals.h>
#endif

#include "../eight_bit_int_gemm/eight_bit_int_gemm.h"
#include "../internal/kernel_reference.h"
#include "test_data.h"

namespace gemmlowp {

void ReferenceEightBitIntGemm(bool transpose_a, bool transpose_b,
                              bool transpose_c, int m, int n, int k,
                              const std::uint8_t* a, std::int32_t a_offset,
                              int lda, const std::uint8_t* b,
                              std::int32_t b_offset, int ldb, std::uint8_t* c,
                              std::int32_t c_offset, std::int32_t c_mult_int,
                              std::int32_t c_shift, int ldc) {
  ScopedProfilingLabel("ReferenceEightBitIntGemm");
  assert((c_shift >= 0) && (c_shift <= 32));

  assert(a != nullptr);
  assert(b != nullptr);
  assert(c != nullptr);

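  // Indexing convention: entry (i, l) of the m x k matrix A is read as
  // a[i * a_i_stride + l * a_l_stride], so transposition only swaps
  // which of the two strides is 1 and which is lda. B and C follow the
  // same pattern with their respective strides.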
  int a_i_stride;
  int a_l_stride;
  if (transpose_a) {
    a_i_stride = lda;
    a_l_stride = 1;
  } else {
    a_i_stride = 1;
    a_l_stride = lda;
  }
  int b_j_stride;
  int b_l_stride;
  if (transpose_b) {
    b_j_stride = 1;
    b_l_stride = ldb;
  } else {
    b_j_stride = ldb;
    b_l_stride = 1;
  }
  int c_i_stride;
  int c_j_stride;
  if (transpose_c) {
    c_i_stride = ldc;
    c_j_stride = 1;
  } else {
    c_i_stride = 1;
    c_j_stride = ldc;
  }
  int i, j, l;

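  // Adding 2^(c_shift - 1) before the right shift below implements
  // rounding to nearest rather than plain truncation.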
  const std::int32_t kRoundingTerm = (c_shift < 1) ? 0 : (1 << (c_shift - 1));

  for (j = 0; j < n; j++) {
    for (i = 0; i < m; i++) {
      std::int32_t total = 0;
      for (l = 0; l < k; l++) {
        const int a_index = i * a_i_stride + l * a_l_stride;
        const std::uint8_t a_as_byte = a[a_index];
        const std::int32_t a_as_int =
            static_cast<std::int32_t>(a_as_byte) + a_offset;
        const int b_index = j * b_j_stride + l * b_l_stride;
        const std::uint8_t b_as_byte = b[b_index];
        const std::int32_t b_as_int =
            static_cast<std::int32_t>(b_as_byte) + b_offset;
        const std::int32_t mult_as_int = a_as_int * b_as_int;
        total += mult_as_int;
      }
      std::int32_t output =
          (((total + c_offset) * c_mult_int) + kRoundingTerm) >> c_shift;
      if (output > 255) {
        output = 255;
      }
      if (output < 0) {
        output = 0;
      }
      const int c_index = i * c_i_stride + j * c_j_stride;
      c[c_index] = static_cast<std::uint8_t>(output);
    }
  }
}
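
// Worked example of the requantization formula above, with made-up values
// for illustration: for total = 1000, c_offset = 24, c_mult_int = 32 and
// c_shift = 8, the rounding term is 1 << 7 = 128, so
//   output = ((1000 + 24) * 32 + 128) >> 8 = 32896 >> 8 = 128,
// which already lies in [0, 255] and is stored unclamped.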

typedef VectorMap<const std::int32_t, VectorShape::Col> OffsetColMap;
typedef VectorMap<const std::int32_t, VectorShape::Row> OffsetRowMap;
typedef VectorDup<const std::int32_t, VectorShape::Col> OffsetColDup;
typedef VectorDup<const std::int32_t, VectorShape::Row> OffsetRowDup;

// The *GemmWrapper classes wrap various Gemm functions in a uniform
// interface, so that the same testing code can test all of them.

template <typename Kernel, typename Scalar, typename tBitDepthParams>
struct SingleThreadGemmWrapper {
  typedef tBitDepthParams BitDepthParams;

  static const char* Name() {
    static char buf[256];
    snprintf(buf, sizeof(buf), "SingleThreadGemm, Kernel: %s", Kernel().Name());
    return buf;
  }

  typedef SingleThreadGemmContext Context;

  template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
  static bool Gemm(Context* context,
                   const MatrixMap<const Scalar, LhsOrder>& lhs,
                   const MatrixMap<const Scalar, RhsOrder>& rhs,
                   MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
                   int rhs_offset, int result_offset, int result_mult_int,
                   int result_shift) {
    ScopedProfilingLabel("SingleThreadGemmWrapper::Gemm");
    const int rows = lhs.rows();
    const int cols = rhs.cols();
    if (rows < cols) {
      // SingleThreadGemm is never called with rows < cols.
      // That case is handled earlier.
      return false;
    }
    const OffsetColDup lhs_offset_vector(lhs_offset, rows);
    const OffsetRowDup rhs_offset_vector(rhs_offset, cols);
    SingleThreadGemm<typename Kernel::Format, Scalar, Scalar, BitDepthParams,
                     LhsOrder, RhsOrder, ResultOrder, OffsetColDup,
                     OffsetRowDup>(
        context, Kernel(), lhs, rhs, result, lhs_offset_vector,
        rhs_offset_vector,
        MakeStandardOutputPipeline(result_offset, result_mult_int,
                                   result_shift));
    return true;
  }
};
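
// For instance, the drivers below can exercise a wrapper like this
// (a hypothetical sketch; SomeKernel stands for any kernel type):
//   typedef SingleThreadGemmWrapper<SomeKernel, std::uint8_t,
//                                   DefaultL8R8BitDepthParams> Wrapper;
//   Wrapper::Context context;
//   test_gemm<Wrapper>(&context);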

template <typename Kernel, typename Scalar, typename tBitDepthParams>
struct MultiThreadGemmWrapper {
  typedef tBitDepthParams BitDepthParams;

  static const char* Name() {
    static char buf[256];
    snprintf(buf, sizeof(buf), "MultiThreadGemm, Kernel: %s", Kernel().Name());
    return buf;
  }

  typedef MultiThreadGemmContext Context;

  template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
  static bool Gemm(Context* context,
                   const MatrixMap<const Scalar, LhsOrder>& lhs,
                   const MatrixMap<const Scalar, RhsOrder>& rhs,
                   MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
                   int rhs_offset, int result_offset, int result_mult_int,
                   int result_shift) {
    ScopedProfilingLabel("MultiThreadGemmWrapper::Gemm");
    // 0 means: let the context choose the number of threads automatically.
    context->set_max_num_threads(0);
    const int rows = lhs.rows();
    const int cols = rhs.cols();
    if (rows < cols) {
      // MultiThreadGemm is never called with rows < cols.
      // That case is handled earlier.
      return false;
    }
    const OffsetColDup lhs_offset_vector(lhs_offset, rows);
    const OffsetRowDup rhs_offset_vector(rhs_offset, cols);
    MultiThreadGemm<typename Kernel::Format, Scalar, Scalar, BitDepthParams,
                    LhsOrder, RhsOrder, ResultOrder, OffsetColDup,
                    OffsetRowDup>(
        context, Kernel(), lhs, rhs, result, lhs_offset_vector,
        rhs_offset_vector,
        MakeStandardOutputPipeline(result_offset, result_mult_int,
                                   result_shift));
    return true;
  }
};

template <typename Scalar, typename tBitDepthParams>
struct PublicGemmWrapper {
  typedef tBitDepthParams BitDepthParams;

  static const char* Name() { return "public Gemm"; }

  typedef GemmContext Context;

  template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
  static bool Gemm(Context* context,
                   const MatrixMap<const Scalar, LhsOrder>& lhs,
                   const MatrixMap<const Scalar, RhsOrder>& rhs,
                   MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
                   int rhs_offset, int result_offset, int result_mult_int,
                   int result_shift) {
    ScopedProfilingLabel("PublicGemmWrapper::Gemm");
    gemmlowp::Gemm<std::uint8_t, BitDepthParams, LhsOrder, RhsOrder,
                   ResultOrder>(context, lhs, rhs, result, lhs_offset,
                                rhs_offset, result_offset, result_mult_int,
                                result_shift);
    return true;
  }
};

template <eight_bit_int_gemm::BitDepthSetting BitDepth>
struct BitDepthParamsForSettings {};

template <>
struct BitDepthParamsForSettings<eight_bit_int_gemm::BitDepthSetting::A8B8>
    : DefaultL8R8BitDepthParams {};

template <>
struct BitDepthParamsForSettings<eight_bit_int_gemm::BitDepthSetting::A5B7>
    : DefaultL7R5BitDepthParams {};

template <typename Scalar, eight_bit_int_gemm::BitDepthSetting BitDepth>
struct EightBitIntGemmWrapper {
  typedef BitDepthParamsForSettings<BitDepth> BitDepthParams;

  static const char* Name() { return "EightBitIntGemm"; }

  typedef void Context;

  template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
  static bool Gemm(Context*, const MatrixMap<const Scalar, LhsOrder>& lhs,
                   const MatrixMap<const Scalar, RhsOrder>& rhs,
                   MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
                   int rhs_offset, int result_offset, int result_mult_int,
                   int result_shift) {
    ScopedProfilingLabel("EightBitIntGemmWrapper::Gemm");
    const bool transpose_c = ResultOrder == MapOrder::RowMajor;
    const bool transpose_a = LhsOrder == MapOrder::RowMajor;
    const bool transpose_b = RhsOrder == MapOrder::RowMajor;
    eight_bit_int_gemm::EightBitIntGemm(
        transpose_a, transpose_b, transpose_c, lhs.rows(), rhs.cols(),
        lhs.cols(), lhs.data(), lhs_offset, lhs.stride(), rhs.data(),
        rhs_offset, rhs.stride(), result->data(), result_offset,
        result_mult_int, result_shift, result->stride(), BitDepth);
    return true;
  }
};

template <typename Scalar>
struct ReferenceEightBitIntGemmWrapper {
  typedef DefaultL8R8BitDepthParams BitDepthParams;

  static const char* Name() { return "ReferenceEightBitIntGemm"; }

  template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
  static bool Gemm(bool transpose_a, bool transpose_b, bool transpose_c,
                   const MatrixMap<const Scalar, LhsOrder>& lhs,
                   const MatrixMap<const Scalar, RhsOrder>& rhs,
                   MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
                   int rhs_offset, int result_offset, int result_mult_int,
                   int result_shift) {
    ScopedProfilingLabel("ReferenceEightBitIntGemmWrapper::Gemm");
    ReferenceEightBitIntGemm(transpose_a, transpose_b, transpose_c, lhs.rows(),
                             rhs.cols(), lhs.cols(), lhs.data(), lhs_offset,
                             lhs.stride(), rhs.data(), rhs_offset, rhs.stride(),
                             result->data(), result_offset, result_mult_int,
                             result_shift, result->stride());
    return true;
  }
};

const char* OrderName(MapOrder order) {
  return order == MapOrder::ColMajor ? "ColMajor" : "RowMajor";
}

struct ResultStats {
  ResultStats()
      : count(0),
        med_val(0),
        mean_signed_diff(0),
        med_signed_diff(0),
        med_unsigned_diff(0),
        max_unsigned_diff(0) {}

  int count;
  int med_val;
  float mean_signed_diff;
  int med_signed_diff;
  int med_unsigned_diff;
  int max_unsigned_diff;

  std::vector<int> count_diff_by_pot_slice;
};

void GetResultStats(const std::uint8_t* actual, const std::uint8_t* expected,
                    size_t count, ResultStats* stats) {
  ScopedProfilingLabel("GetResultStats");
  std::vector<std::uint8_t> results;
  std::vector<std::int16_t> signed_diffs;
  std::vector<std::uint8_t> unsigned_diffs;
  std::int64_t signed_diffs_sum = 0;
  for (size_t i = 0; i < count; i++) {
    results.push_back(actual[i]);
    std::int16_t signed_diff = actual[i] - expected[i];
    signed_diffs.push_back(signed_diff);
    unsigned_diffs.push_back(std::abs(signed_diff));
    signed_diffs_sum += signed_diff;
  }

  std::sort(results.begin(), results.end());
  std::sort(signed_diffs.begin(), signed_diffs.end());
  std::sort(unsigned_diffs.begin(), unsigned_diffs.end());

  const size_t middle = count / 2;

  stats->count = count;
  stats->med_val = results[middle];
  stats->mean_signed_diff = float(signed_diffs_sum) / count;
  stats->med_signed_diff = signed_diffs[middle];
  stats->med_unsigned_diff = unsigned_diffs[middle];
  stats->max_unsigned_diff = unsigned_diffs.back();

  // Size 9 for 9 different POT values: 2^0, ..., 2^8
  stats->count_diff_by_pot_slice.resize(9);
  auto cur = unsigned_diffs.begin();
  size_t checksum = 0;
  for (int exponent = 0; exponent < 9; exponent++) {
    int pot = 1 << exponent;
    auto next = std::lower_bound(cur, unsigned_diffs.end(), pot);
    checksum += stats->count_diff_by_pot_slice[exponent] = next - cur;
    cur = next;
  }
  assert(checksum == count);
}
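
// Note on count_diff_by_pot_slice: since unsigned_diffs is sorted, the
// lower_bound scan above puts exact matches (diff < 2^0) in slice 0, and
// diffs in the half-open range [2^(e-1), 2^e) in slice e for e >= 1 --
// the same ranges that ReportResultStats prints below.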

struct ResultStatsBounds {
  ResultStatsBounds()
      : mean_signed_diff(0),
        med_signed_diff(0),
        med_unsigned_diff(0),
        max_unsigned_diff(0) {}

  float mean_signed_diff;
  int med_signed_diff;
  int med_unsigned_diff;
  int max_unsigned_diff;
};

bool CheckResultStatsBounds(const ResultStats& stats,
                            const ResultStatsBounds& bounds) {
  return stats.max_unsigned_diff <= bounds.max_unsigned_diff &&
         stats.med_unsigned_diff <= bounds.med_unsigned_diff &&
         std::abs(stats.med_signed_diff) <= bounds.med_signed_diff &&
         std::abs(stats.mean_signed_diff) <= bounds.mean_signed_diff;
}
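
// With a default-constructed ResultStatsBounds (all bounds zero), this
// check amounts to requiring bit-exact agreement with the reference.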

void ReportResultStats(const ResultStats& stats,
                       const ResultStatsBounds& bounds) {
  printf("    number of matrix entries: %d\n", stats.count);
  printf("    median value: %d\n", stats.med_val);
  printf("    median unsigned diff: %d (tolerating %d)\n",
         stats.med_unsigned_diff, bounds.med_unsigned_diff);
  printf("    max unsigned diff: %d (tolerating %d)\n", stats.max_unsigned_diff,
         bounds.max_unsigned_diff);
  printf("    median signed diff: %d (tolerating %d)\n", stats.med_signed_diff,
         bounds.med_signed_diff);
  printf("    mean signed diff: %.3g (tolerating %.3g)\n",
         stats.mean_signed_diff, bounds.mean_signed_diff);

  printf("No error: %.2f %% of entries\n",
         100.f * stats.count_diff_by_pot_slice[0] / stats.count);
  for (int exponent = 1; exponent < 9; exponent++) {
    printf("Error in %d..%d range: %.2f %% of entries\n", 1 << (exponent - 1),
           (1 << exponent) - 1,
           100.f * stats.count_diff_by_pot_slice[exponent] / stats.count);
  }
}

// Our approach to choosing result_shift values for testing is bisection.
// This function takes an interval, [result_shift_min .. result_shift_max].
// If too much saturation occurred in either direction, it bisects accordingly,
// recursing until the interval contains only one value.
// The primary reason why we prefer this over computing optimal shift values
// is that we actually want to exercise some saturation, as there is
// nontrivial code handling that in gemmlowp.
// Secondarily, this is faster than computing optimal shifts, since in 90% of
// cases the first-tried shift value 16 turns out to be good enough.
template <typename GemmWrapper, typename LhsType, typename RhsType,
          typename ResultType>
void test_gemm_impl(typename GemmWrapper::Context* context, const LhsType& lhs,
                    const RhsType& rhs, ResultType* result, int lhs_offset,
                    int rhs_offset, int result_offset, int result_mult_int,
                    int result_shift_min, int result_shift_max) {
  const int rows = lhs.rows();
  const int cols = rhs.cols();
  Check(lhs.cols() == rhs.rows());
  const int depth = lhs.cols();

  const int result_shift = (result_shift_min + result_shift_max) / 2;

  if (!GemmWrapper::Gemm(context, lhs.const_map(), rhs.const_map(),
                         &result->map(), lhs_offset, rhs_offset, result_offset,
                         result_mult_int, result_shift)) {
    // Internal GEMM functions are not required to handle all cases
    // (e.g. rows < cols) as these are supposed to have been handled
    // ahead of them. Their test wrappers return false in that case.
    return;
  }

  typedef typename ResultType::Scalar Scalar;
  static const MapOrder kLhsOrder = LhsType::kOrder;
  static const MapOrder kRhsOrder = RhsType::kOrder;
  static const MapOrder kResultOrder = ResultType::kOrder;
  ResultType ref_result(rows, cols);
  const bool transpose_c = kResultOrder == MapOrder::RowMajor;
  const bool transpose_a = kLhsOrder == MapOrder::RowMajor;
  const bool transpose_b = kRhsOrder == MapOrder::RowMajor;
  ReferenceEightBitIntGemmWrapper<Scalar>::Gemm(
      transpose_a, transpose_b, transpose_c, lhs.const_map(), rhs.const_map(),
      &ref_result.map(), lhs_offset, rhs_offset, result_offset, result_mult_int,
      result_shift);

  typedef typename GemmWrapper::BitDepthParams BitDepthParams;

  ResultStats stats;
  GetResultStats(result->data(), ref_result.data(), rows * cols, &stats);

  // Adjust shifts until we get meaningful results
  int new_result_shift_min = result_shift_min;
  int new_result_shift_max = result_shift_max;
  bool retry = false;

  if (stats.med_val < 32) {
    new_result_shift_max = (result_shift_min + result_shift_max) / 2;
    retry = true;
  }

  if (stats.med_val > 224) {
    new_result_shift_min = (result_shift_min + result_shift_max) / 2;
    retry = true;
  }

  if (retry) {
    if (result_shift_min != result_shift_max) {
      test_gemm_impl<GemmWrapper>(context, lhs, rhs, result, lhs_offset,
                                  rhs_offset, result_offset, result_mult_int,
                                  new_result_shift_min, new_result_shift_max);
    }
    return;
  }

  ResultStatsBounds bounds;

  // Check results
  const bool good = CheckResultStatsBounds(stats, bounds);

  printf(
      "%s: %dx%dx%d %s x %s -> %s, %s, offsets %d/%d/%d, mult %d, shift %d\n",
      good ? "PASS" : "FAIL", rows, depth, cols, OrderName(kLhsOrder),
      OrderName(kRhsOrder), OrderName(kResultOrder), GemmWrapper::Name(),
      lhs_offset, rhs_offset, result_offset, result_mult_int, result_shift);

  if (!good) {
    ReportResultStats(stats, bounds);

    int bad_coeffs_printed = 0;
    for (int c = 0; c < result->cols() && bad_coeffs_printed < 200; c++) {
      for (int r = 0; r < result->rows() && bad_coeffs_printed < 200; r++) {
        if (ref_result(r, c) != (*result)(r, c)) {
          printf("bad coeff: at (%d, %d), expected %d, got %d\n", r, c,
                 ref_result(r, c), (*result)(r, c));
          bad_coeffs_printed++;
        }
      }
    }
  }

  Check(good);
}
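
// Example of the bisection in test_gemm_impl, with hypothetical numbers:
// starting from [0, 32], the first shift tried is 16. If the median output
// value comes back below 32 (saturated towards 0), the interval shrinks to
// [0, 16] and shift 8 is tried next; if it comes back above 224 (saturated
// towards 255), the interval becomes [16, 32] and shift 24 is tried next.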

template <typename GemmWrapper, typename LhsType, typename RhsType,
          typename ResultType>
void test_gemm(typename GemmWrapper::Context* context, const LhsType& lhs,
               const RhsType& rhs, ResultType* result, int lhs_offset,
               int rhs_offset, int result_offset, int result_mult_int) {
  test_gemm_impl<GemmWrapper>(context, lhs, rhs, result, lhs_offset, rhs_offset,
                              result_offset, result_mult_int, 0, 32);
}

enum class WhatParamsToTest {
  All,
  OnlyGenericCase,
};

template <typename GemmWrapper, MapOrder LhsOrder, MapOrder RhsOrder,
          MapOrder ResultOrder>
void test_gemm(typename GemmWrapper::Context* context, int rows, int depth,
               int cols, WhatParamsToTest params_to_test) {
  typedef std::uint8_t Scalar;
  typedef Matrix<Scalar, LhsOrder> LhsType;
  using BitDepthParams = typename GemmWrapper::BitDepthParams;
  LhsType lhs(rows, depth);
  MakeRandom<typename BitDepthParams::LhsRange>(&lhs);
  typedef Matrix<Scalar, RhsOrder> RhsType;
  RhsType rhs(depth, cols);
  MakeRandom<typename BitDepthParams::RhsRange>(&rhs);
  typedef Matrix<Scalar, ResultOrder> ResultType;
  ResultType result(rows, cols);
  MakeZero(&result);

  if (params_to_test == WhatParamsToTest::All) {
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 0, 0, 1);
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 10, 0, 0, 1);
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 10, 0, 1);
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 0, 10, 1);
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 0, 0, 10);
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 10, 10, 10, 10);
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 256, 1, 17, 4);
  }
  test_gemm<GemmWrapper>(context, lhs, rhs, &result, -75, -91, 74980, 123);
}

enum class WhatOrdersToTest { All, OnlyRCC };

template <typename GemmWrapper>
void test_gemm(typename GemmWrapper::Context* context, int rows, int depth,
               int cols, WhatParamsToTest params_to_test,
               WhatOrdersToTest orders_to_test) {
#define GEMMLOWP_ONE_TEST(LhsOrder, RhsOrder, ResultOrder)         \
  do {                                                             \
    test_gemm<GemmWrapper, MapOrder::LhsOrder, MapOrder::RhsOrder, \
              MapOrder::ResultOrder>(context, rows, depth, cols,   \
                                     params_to_test);              \
  } while (false)

  if (orders_to_test == WhatOrdersToTest::All) {
    GEMMLOWP_ONE_TEST(ColMajor, ColMajor, ColMajor);
    GEMMLOWP_ONE_TEST(RowMajor, ColMajor, ColMajor);
    GEMMLOWP_ONE_TEST(ColMajor, RowMajor, ColMajor);
    GEMMLOWP_ONE_TEST(RowMajor, RowMajor, ColMajor);

    GEMMLOWP_ONE_TEST(ColMajor, ColMajor, RowMajor);
    GEMMLOWP_ONE_TEST(RowMajor, ColMajor, RowMajor);
    GEMMLOWP_ONE_TEST(ColMajor, RowMajor, RowMajor);
    GEMMLOWP_ONE_TEST(RowMajor, RowMajor, RowMajor);
  } else {
    GEMMLOWP_ONE_TEST(RowMajor, ColMajor, ColMajor);
  }

#undef GEMMLOWP_ONE_TEST
}

template <typename Kernel>
void test_gemm_kernel(MultiThreadGemmContext* context) {
  typedef MultiThreadGemmWrapper<Kernel, std::uint8_t,
                                 DefaultL8R8BitDepthParams>
      GemmWrapper;
  test_gemm<GemmWrapper>(context, 1, 1, 1, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 2, 2, 2, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 3, 3, 3, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 4, 4, 4, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 5, 5, 5, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 9, 11, 13, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 50, 50, 50, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 200, 200, 200,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::All);
  test_gemm<GemmWrapper>(context, 50, 5000, 50,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
}

template <typename GemmWrapper>
void test_gemm(typename GemmWrapper::Context* context) {
  test_gemm<GemmWrapper>(context, 1, 1, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 2, 1, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1, 2, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1, 1, 2, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 2, 2, 2, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 3, 3, 3, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 4, 4, 4, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 5, 5, 5, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 6, 6, 6, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 3, 5, 7, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 7, 3, 5, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 5, 7, 3, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 8, 8, 8, WhatParamsToTest::All,
                         WhatOrdersToTest::All);
  test_gemm<GemmWrapper>(context, 16, 16, 16, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 32, 32, 32, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 64, 64, 64, WhatParamsToTest::All,
                         WhatOrdersToTest::All);
  test_gemm<GemmWrapper>(context, 128, 128, 128, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);

  test_gemm<GemmWrapper>(context, 16, 17, 16, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 37, 55, 73, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 57, 87, 117, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 93, 83, 73, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 109, 89, 99, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 78, 101, 82, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);

  test_gemm<GemmWrapper>(context, 512, 512, 512,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1024, 1024, 1024,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 567, 2345, 123,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 100, 5000, 100,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1, 1, 1000, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1000, 1, 1, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1, 1000, 1, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1, 1000, 1000,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1000, 1, 1000,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1000, 1000, 1,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 777, 3456, 1,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 4567, 555, 1,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);

  // Test all storage orders
  test_gemm<GemmWrapper>(context, 70, 90, 110, WhatParamsToTest::All,
                         WhatOrdersToTest::All);
  test_gemm<GemmWrapper>(context, 300, 400, 500,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::All);
}

template <typename GemmWrapper>
void test_gemv(typename GemmWrapper::Context* context) {
  test_gemm<GemmWrapper>(context, 2, 2, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 3, 3, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 4, 4, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 5, 5, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 6, 6, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 3, 5, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 7, 3, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 5, 7, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 8, 8, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 32, 32, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 128, 128, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 321, 123, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);

  // Test all storage orders
  test_gemm<GemmWrapper>(context, 70, 90, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::All);
  test_gemm<GemmWrapper>(context, 300, 400, 1,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::All);
}

const char* GetBitDepthName(eight_bit_int_gemm::BitDepthSetting b) {
  switch (b) {
    case eight_bit_int_gemm::BitDepthSetting::A8B8:
      return "Lhs: 8 bit, Rhs: 8 bit";
    case eight_bit_int_gemm::BitDepthSetting::A5B7:
      return "(legacy, no longer requantizing) Lhs: 7 bit, Rhs: 5 bit";
    default:
      abort();
      return nullptr;
  }
}

// Runs a small set of hand-picked per-channel quantized data.
// This test case comes from a set of 2 2x2 convolution filters run over a 3x3
// image.
void TestWithSmallDataPerChannelQuantization() {
  const int m = 2;
  const int n = 9;
  const int k = 12;
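
  // Consistent with the comment above: m = 2 filters, n = 3 * 3 = 9 output
  // positions, and k = 12 input values per output (presumably 2x2 filters
  // over a 3-channel image patch).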

  // 12 x 2, columnwise.
  const std::uint8_t a_data[] = {0,  0,   0,   0,   0,  0,   0,   0,
                                 0,  255, 255, 255, 64, 64,  64,  64,
                                 64, 64,  0,   0,   0,  255, 255, 255};
  const int lda = k;
  int a_offset[] = {0, -64};
  MatrixMap<const std::uint8_t, MapOrder::RowMajor> lhs(a_data, m, k, lda);
  const OffsetColMap lhs_offset(a_offset, m);

  // 12 x 9, columnwise.
  const std::uint8_t b_data[] = {
      0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,
      0,   0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,   127,
      127, 127, 0,   0,   0,   127, 127, 127, 0,   0,   0,   255, 255, 255,
      0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   127, 127, 127, 0,   0,   0,   127,
      127, 127, 0,   0,   0,   0,   0,   0,   127, 127, 127, 127, 127, 127,
      0,   0,   0,   0,   0,   0,   127, 127, 127, 127, 127, 127, 0,   0,
      0,   127, 127, 127, 127, 127, 127, 127, 127, 127};
  const int ldb = k;
  int b_offset = -127;
  MatrixMap<const std::uint8_t, MapOrder::ColMajor> rhs(b_data, k, n, ldb);
  const OffsetRowDup rhs_offset(b_offset, rhs.cols());

  // 2 x 9, columnwise.
  const std::uint8_t expected_c_data[] = {255, 255, 0,   0,   127, 159,
                                          0,   64,  0,   64,  127, 159,
                                          127, 127, 127, 127, 127, 127};
  const int ldc = m;
  int c_offset[] = {97155, 97346};
  int c_mult_int[] = {2741, 2741};
  const int c_shift = 21;

  const int c_count = m * n;
  std::unique_ptr<std::uint8_t[]> output_data(new std::uint8_t[c_count]);
  MatrixMap<std::uint8_t, MapOrder::ColMajor> result(output_data.get(), m, n,
                                                     ldc);
  const OffsetColMap result_offset(c_offset, m);
  const OffsetColMap result_mult_int(c_mult_int, m);
  const int result_shift = c_shift;

  GemmContext gemm_context;
  auto output_pipeline = MakeStandardOutputPipeline<VectorShape::Col>(
      result_offset, result_mult_int, result_shift);
  GemmWithOutputPipelinePC<std::uint8_t, std::uint8_t,
                           DefaultL8R8BitDepthParams>(
      &gemm_context, lhs, rhs, &result, lhs_offset, rhs_offset,
      output_pipeline);

  ResultStats stats;
  GetResultStats(output_data.get(), expected_c_data, c_count, &stats);

  ResultStatsBounds bounds;
  const bool good = CheckResultStatsBounds(stats, bounds);
  printf("TestWithSmallDataPerChannelQuantization: %s\n",
         good ? "PASS" : "FAIL");
  ReportResultStats(stats, bounds);
  Check(good);
}

// Runs a larger set of hand-picked per-channel quantized data.
// This test case comes from a set of 22 3x3 convolution filters run over a
// 5x5 image. Right now, I have 7 different filters and 15 copies of the first
// filter, to make sure that the NEON code path that processes 16 rows at a
// time is covered.
void TestWithLargeDataPerChannelQuantization() {
  const int m = 22;
  const int n = 25;
  const int k = 27;
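
  // Consistent with the comment above: m = 22 filters, n = 5 * 5 = 25 output
  // positions, and k = 27 = 3 * 3 * 3 input values per output (3x3 filters
  // over a 3-channel image patch).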

  // 27 x 22, column-wise.
  const std::uint8_t a_data[] = {
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   127, 127, 127, 255, 255, 255, 127, 127, 127,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   127, 127, 127,
      0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,
      127, 127, 127, 0,   0,   0,   51,  51,  51,  51,  51,  51,  51,  51,  51,
      0,   0,   0,   255, 255, 255, 0,   0,   0,   51,  51,  51,  51,  51,  51,
      51,  51,  51,  51,  51,  51,  0,   0,   0,   51,  51,  51,  51,  51,  51,
      255, 255, 255, 51,  51,  51,  51,  51,  51,  0,   0,   0,   51,  51,  51,
      0,   0,   0,   64,  64,  64,  0,   0,   0,   64,  64,  64,  255, 255, 255,
      64,  64,  64,  0,   0,   0,   64,  64,  64,  0,   0,   0,   36,  36,  36,
      0,   0,   0,   36,  36,  36,  0,   0,   0,   255, 255, 255, 0,   0,   0,
      36,  36,  36,  0,   0,   0,   36,  36,  36,  0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      255, 255, 255, 0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      255, 255, 255, 0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      255, 255, 255, 0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,
  };
  const int lda = k;
  int a_offset[] = {0, 0, 0, -51, -51, 0, -36, 0, 0, 0, 0,
                    0, 0, 0, 0,   0,   0, 0,   0, 0, 0, 0};
  MatrixMap<const std::uint8_t, MapOrder::RowMajor> lhs(a_data, m, k, lda);
  const OffsetColMap lhs_offset(a_offset, m);

  // 27 x 25, column-wise.
  const std::uint8_t b_data[] = {
      127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 119, 119,
      119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127,
      127, 127, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127,
      127, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127,
      127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127, 127,
      127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127,
      119, 119, 119, 119, 119, 119, 127, 127, 127, 127, 127, 127, 119, 119,
      119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127,
      127, 127, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 136, 136, 136, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      136, 136, 136, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 136, 136, 136, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127,
      119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119,
      119, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127,
      127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 136, 136, 136, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      136, 136, 136, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 136, 136, 136, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119,
      119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 127,
      127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119,
      119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 136, 136, 136, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      136, 136, 136, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 136, 136, 136, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119,
      119, 119, 119, 119, 119, 127, 127, 127, 127, 127, 127, 119, 119, 119,
      119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127,
      127, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127,
      127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127, 127,
      127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127, 127, 127,
      127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119,
      119, 119, 119, 119, 119, 127, 127, 127, 127, 127, 127, 127, 127, 127,
      127, 127, 127};
  const int ldb = k;
  int b_offset = -127;
  MatrixMap<const std::uint8_t, MapOrder::ColMajor> rhs(b_data, k, n, ldb);
  const OffsetRowDup rhs_offset(b_offset, rhs.cols());

  // 22 x 25, column-wise.
  const std::uint8_t expected_c_data[] = {
      7,   37,  37,  67,  67,  39,  79,  7,   7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   37,  87,  67,  23,  91,  7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   37,  87,  67,  23,  91,  7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   37,  87,  67,  23,  91,  7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   37,
      37,  67,  67,  39,  79,  7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   37,  7,   67,  87,  23,  91,  7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
      87,  87,  7,   103, 7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   71,  87,  45,  41,  77,  7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   87,
      87,  7,   103, 7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   37,  7,   67,  87,  23,  91,  7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   37,  7,   67,  87,
      23,  91,  7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   71,  7,   45,  87,  41,  77,  7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   255, 135, 135, 255, 255, 143,
      255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
      255, 7,   71,  7,   45,  87,  41,  77,  7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   37,  7,   67,  87,  23,  91,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   37,  7,   67,  87,  23,  91,  7,   7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   87,  87,  7,   103, 7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   71,  87,  45,  41,  77,  7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   87,  87,  7,   103, 7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   37,
      7,   67,  87,  23,  91,  7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   37,  37,  67,  67,  39,  79,  7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   37,
      87,  67,  23,  91,  7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   37,  87,  67,  23,  91,  7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   37,  87,
      67,  23,  91,  7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   37,  37,  67,  67,  39,  79,  7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   99,  99,  99,  99,  99,
      99,  99,  99,  99,  99,  99,  99,  99,  99,  99,  99,  99,  99,  99,  99,
      99,  99,  111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111,
      111, 111, 111, 111, 111, 111, 111, 111, 111,
  };
  const int ldc = m;
  int c_offset[] = {
      6477, 12954, 12954, 7793, 7793, 12954, 9282, 6477, 6477, 6477, 6477,
      6477, 6477,  6477,  6477, 6477, 6477,  6477, 6477, 6477, 6477, 6477,
  };
  int c_mult_int[] = {
      41121, 20560, 20560, 34267, 34267, 21937, 28784, 41121,
      41121, 41121, 41121, 41121, 41121, 41121, 41121, 41121,
      41121, 41121, 41121, 41121, 41121, 41121,
  };
  const int c_shift = 21;

  const int c_count = m * n;
  std::unique_ptr<std::uint8_t[]> output_data(new std::uint8_t[c_count]);
  MatrixMap<std::uint8_t, MapOrder::ColMajor> result(output_data.get(), m, n,
                                                     ldc);
  const OffsetColMap result_offset(c_offset, m);
  const OffsetColMap result_mult_int(c_mult_int, m);
  const int result_shift = c_shift;

  GemmContext gemm_context;
  auto output_pipeline = MakeStandardOutputPipeline<VectorShape::Col>(
      result_offset, result_mult_int, result_shift);
  GemmWithOutputPipelinePC<std::uint8_t, std::uint8_t,
                           DefaultL8R8BitDepthParams>(
      &gemm_context, lhs, rhs, &result, lhs_offset, rhs_offset,
      output_pipeline);

  ResultStats stats;
  GetResultStats(output_data.get(), expected_c_data, c_count, &stats);

  ResultStatsBounds bounds;
  const bool good = CheckResultStatsBounds(stats, bounds);
  printf("TestWithLargeDataPerChannelQuantization: %s\n",
         good ? "PASS" : "FAIL");
  ReportResultStats(stats, bounds);
  Check(good);
}

// Multithreading only activates when the result has more than 16 rows, and
// also (result rows) * (result cols) * depth >= 2 x 65 x 1024. The size here
// was selected to run in 3 threads.
//
// Based on the following floating point data:
//   LHS: all zeros except 10.0, 20.0 at the beginning of the first 16 rows;
//     1.0, 2.0 at the beginning of the next 16 rows; 0.1, 0.2 in the next 16
//     rows; 0.01, 0.02 in the last 16 rows.
//   RHS: all zeros except 1.0 in (0, 0) and 2.0 in (1, 0).
//   Varying boundaries were used for each 16-row block of the LHS, to test
//     for correct indexing into offsets.
//   Expected result: all zeros, except 50.0 at the beginning of the first 16
//     rows; 5.0 at the beginning of the next 16 rows; 0.5 in the next 16
//     rows; 0.05 in the last 16 rows.
void TestMultithreadedPerChannelQuantization() {
  const int m = 64;
  const int n = 20;
  const int k = 160;
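
  // Sanity check against the threshold in the comment above:
  // 64 * 20 * 160 = 204800 >= 2 x 65 x 1024 = 133120, and m = 64 > 16,
  // so the multithreaded path is exercised.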

  // LHS, m x k.
  const std::array<std::int32_t, 4> lhs_offsets_terse{{
      0, -51, -85, -109,
  }};
  assert(lhs_offsets_terse.size() * 16 == m);
  const std::array<std::uint8_t, 4> lhs_first_el{{
      128, 153, 170, 182,
  }};
  assert(lhs_first_el.size() * 16 == m);

  // lhs_first_el at (i, 0) and 255 at (i, 1), other values are all -offset.
  std::vector<std::uint8_t> a_data(m * k, 0);
  for (int i = 0; i < m; ++i) {
    a_data[i * k] = lhs_first_el[i / 16];
    a_data[i * k + 1] = 255;
    for (int j = 2; j < k; ++j) {
      a_data[i * k + j] = std::uint8_t(-lhs_offsets_terse[i / 16]);
    }
  }

  const int lda = k;
  // Given values at [i / 16].
  std::vector<std::int32_t> a_offset(m, 0);
  for (int i = 0; i < m; ++i) {
    a_offset[i] = lhs_offsets_terse[i / 16];
  }

  MatrixMap<const std::uint8_t, MapOrder::RowMajor> lhs(&a_data[0], m, k, lda);
  const OffsetColMap lhs_offset(&a_offset[0], m);

  // RHS, k x n.
  // All zeros, except 128 at (0, 0) and 255 at (1, 0).
  std::vector<std::uint8_t> b_data(k * n, 0);
  b_data[0] = 128;
  b_data[1] = 255;

  const int ldb = k;
  std::int32_t b_offset = 0;
  MatrixMap<const std::uint8_t, MapOrder::ColMajor> rhs(&b_data[0], k, n, ldb);
  const OffsetRowDup rhs_offset(b_offset, rhs.cols());

  // Result, m x n.
  // All zeros, except given values at (i / 16, 0).
  const std::array<std::uint8_t, 4> expected_c_terse{{
      142, 159, 182, 213,
  }};
  assert(expected_c_terse.size() * 16 == m);
  std::vector<std::uint8_t> expected_c_data(m * n, 0);
  for (int i = 0; i < m; ++i) {
    expected_c_data[i] = expected_c_terse[i / 16];
  }

  const int ldc = m;
  // All zeros.
  std::vector<std::int32_t> c_offset(m, 0);
  // Given values at [i / 16].
  const std::array<std::int32_t, 4> c_mult_int_terse{{
      3655, 5140, 7049, 9595,
  }};
  assert(c_mult_int_terse.size() * 16 == m);
  std::vector<std::int32_t> c_mult_int(m);
  for (int i = 0; i < m; ++i) {
    c_mult_int[i] = c_mult_int_terse[i / 16];
  }

  const int c_shift = 21;

  const int c_count = m * n;
  std::unique_ptr<std::uint8_t[]> output_data(new std::uint8_t[c_count]);
  MatrixMap<std::uint8_t, MapOrder::ColMajor> result(output_data.get(), m, n,
                                                     ldc);
  const OffsetColMap result_offset(&c_offset[0], m);
  const OffsetColMap result_mult_int(&c_mult_int[0], m);
  const int result_shift = c_shift;

  GemmContext gemm_context;
  auto output_pipeline = MakeStandardOutputPipeline<VectorShape::Col>(
      result_offset, result_mult_int, result_shift);
  GemmWithOutputPipelinePC<std::uint8_t, std::uint8_t,
                           DefaultL8R8BitDepthParams>(
      &gemm_context, lhs, rhs, &result, lhs_offset, rhs_offset,
      output_pipeline);

  ResultStats stats;
  GetResultStats(output_data.get(), &expected_c_data[0], c_count, &stats);

  ResultStatsBounds bounds;
  const bool good = CheckResultStatsBounds(stats, bounds);
  printf("TestMultithreadedPerChannelQuantization: %s\n",
         good ? "PASS" : "FAIL");
  ReportResultStats(stats, bounds);
  Check(good);
}

// Runs a small set of hand-calculated data through the implementation.
void TestWithSmallData() {
  const int m = 4;
  const int n = 2;
  const int k = 3;
  // Matrix A (LHS) is:
  // |  7 | 10 | 13 | 16 |
  // |  8 | 11 | 14 | 17 |
  // |  9 | 12 | 15 | 18 |
  const std::uint8_t a_data[] = {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18};
  // Matrix B (RHS) is:
  // |  1 |  3 |  5 |
  // |  2 |  4 |  6 |
  const std::uint8_t b_data[] = {1, 2, 3, 4, 5, 6};
  // Here are the results we expect, from hand calculations:
  // (1 * 7) + (3 * 8) + (5 * 9) = 76
  // (2 * 7) + (4 * 8) + (6 * 9) = 100
  // (1 * 10) + (3 * 11) + (5 * 12) = 103
  // (2 * 10) + (4 * 11) + (6 * 12) = 136
  // (1 * 13) + (3 * 14) + (5 * 15) = 130
  // (2 * 13) + (4 * 14) + (6 * 15) = 172
  // (1 * 16) + (3 * 17) + (5 * 18) = 157
  // (2 * 16) + (4 * 17) + (6 * 18) = 208
  // That means matrix C should be:
  // |  76 | 103 | 130 | 157 |
  // | 100 | 136 | 172 | 208 |
  const std::uint8_t expected_data[] = {76, 100, 103, 136, 130, 172, 157, 208};

  const int c_count = m * n;
  std::unique_ptr<std::uint8_t[]> output_data(new std::uint8_t[c_count]);

  const bool is_a_transposed = true;
  const bool is_b_transposed = true;
  const bool is_c_transposed = true;
  const int lda = k;
  const int ldb = n;
  const int ldc = n;

  const int a_offset = 0;
  const int b_offset = 0;
  const int c_offset = 0;
  const int c_mult = 1;
  const int c_shift = 0;
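  // With zero offsets, a unit multiplier and a zero shift, the requantization
  // in ReferenceEightBitIntGemm reduces to the identity, and all the expected
  // values above fit in [0, 255], so the uint8 outputs should be exactly the
  // raw accumulators.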

  gemmlowp::eight_bit_int_gemm::EightBitIntGemm(
      is_a_transposed, is_b_transposed, is_c_transposed, m, n, k, a_data,
      a_offset, lda, b_data, b_offset, ldb, output_data.get(), c_offset, c_mult,
      c_shift, ldc, eight_bit_int_gemm::BitDepthSetting::A8B8);

  ResultStats stats;
  GetResultStats(output_data.get(), expected_data, c_count, &stats);

  ResultStatsBounds bounds;
  const bool good = CheckResultStatsBounds(stats, bounds);
  printf("TestWithSmallData: %s\n", good ? "PASS" : "FAIL");
  ReportResultStats(stats, bounds);
  Check(good);
}

// This is the most realistic test of how we'll be using the low-precision GEMM
// function in applications. It takes in large input matrices that have been
// captured from an actual neural network run.
void TestWithRealData(eight_bit_int_gemm::BitDepthSetting BitDepth,
                      int tolerance_median, int tolerance_max) {
  std::unique_ptr<std::uint8_t[]> output_data(
      new std::uint8_t[test_data::c_count]);
  gemmlowp::eight_bit_int_gemm::EightBitIntGemm(
      test_data::is_a_transposed, test_data::is_b_transposed,
      test_data::is_c_transposed, test_data::m, test_data::n, test_data::k,
      test_data::a_data, test_data::a_offset, test_data::k, test_data::b_data,
      test_data::b_offset, test_data::k, output_data.get(), test_data::c_offset,
      test_data::c_mult_int, test_data::c_shift, test_data::m, BitDepth);

  ResultStats stats;
  GetResultStats(output_data.get(), test_data::expected_c_data,
                 test_data::c_count, &stats);

  ResultStatsBounds bounds;
  if (BitDepth == eight_bit_int_gemm::BitDepthSetting::A5B7) {
    bounds.med_unsigned_diff = tolerance_median;
    bounds.max_unsigned_diff = tolerance_max;
    bounds.med_signed_diff = 0;
    bounds.mean_signed_diff = 0.2f;
  }

  const bool good = CheckResultStatsBounds(stats, bounds);
  printf("TestWithRealData: %s with %s\n", good ? "PASS" : "FAIL",
         GetBitDepthName(BitDepth));
  ReportResultStats(stats, bounds);
  Check(good);
}

template <typename BitDepthParams, MapOrder ResultOrder>
void TestOutputStages(int rows, int depth, int cols, int result_offset,
                      int result_mult_int, int result_shift) {
  Matrix<std::uint8_t, MapOrder::RowMajor> lhs(rows, depth);
  Matrix<std::uint8_t, MapOrder::ColMajor> rhs(depth, cols);
  Matrix<std::int32_t, ResultOrder> result_raw_int32(rows, cols);
  MakeRandom<typename BitDepthParams::LhsRange>(&lhs);
  MakeRandom<typename BitDepthParams::RhsRange>(&rhs);
  const int lhs_offset = 12;
  const int rhs_offset = -34;

  // Test an empty pipeline, i.e. returning raw int32 accumulators.
  auto empty_pipeline = std::make_tuple();
  GemmContext context;
  GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(), &result_raw_int32, lhs_offset,
      rhs_offset, empty_pipeline);

  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t expected = 0;
      for (int d = 0; d < depth; d++) {
        std::int32_t lhs_val =
            static_cast<std::int32_t>(lhs(r, d)) + lhs_offset;
        std::int32_t rhs_val =
            static_cast<std::int32_t>(rhs(d, c)) + rhs_offset;
        expected += lhs_val * rhs_val;
      }
      Check(expected == result_raw_int32(r, c));
    }
  }

  // Test a pipeline with only the quantize-down stage, still returning
  // unclamped (but scaled) int32's.
  OutputStageQuantizeDownInt32ToUint8Scale quantize_down_stage;
  quantize_down_stage.result_offset = result_offset;
  quantize_down_stage.result_mult_int = result_mult_int;
  quantize_down_stage.result_shift = result_shift;
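  // For each int32 input x, this stage computes
  //   RoundingDivideByPOT((x + result_offset) * result_mult_int, result_shift)
  // which is exactly what the verification loop below recomputes.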
  auto quantize_down_pipeline = std::make_tuple(quantize_down_stage);
  Matrix<std::int32_t, ResultOrder> result_quantized_down_int32(rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(), &result_quantized_down_int32,
      lhs_offset, rhs_offset, quantize_down_pipeline);

  std::int64_t sum = 0;
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t raw = result_raw_int32(r, c);
      std::int32_t expected = RoundingDivideByPOT(
          (raw + result_offset) * result_mult_int, result_shift);
      Check(expected == result_quantized_down_int32(r, c));
      sum += expected;
    }
  }
  std::int64_t avg = sum / (rows * cols);
  // Test that the average quantized-down value falls reasonably in the
  // middle of the [0..255] range. Otherwise, the multiplier / shift need to be
  // adjusted.
  Check(avg >= 64 && avg <= 192);

  // Test the familiar default pipeline consisting of quantize-down and
  // clamp-and-cast-to-uint8.
  OutputStageSaturatingCastToUint8 saturating_cast_stage;
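  // The saturating cast clamps each int32 entry to [0, 255] before narrowing
  // it to uint8, as the verification loop below recomputes.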
  auto quantize_down_and_saturating_cast_pipeline =
      std::make_tuple(quantize_down_stage, saturating_cast_stage);
  Matrix<std::uint8_t, ResultOrder> result_quantized_down_saturated_uint8(rows,
                                                                          cols);
  GemmWithOutputPipeline<std::uint8_t, std::uint8_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(),
      &result_quantized_down_saturated_uint8, lhs_offset, rhs_offset,
      quantize_down_and_saturating_cast_pipeline);

  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t quantized = result_quantized_down_int32(r, c);
      std::uint8_t expected = std::min(std::max(quantized, 0), 255);
      Check(expected == result_quantized_down_saturated_uint8(r, c));
    }
  }

  // Test a bias-addition with row-vector
  std::vector<std::int32_t> row_vector_data(cols);
  std::uniform_int_distribution<std::int32_t> uniform_minus_500_plus_500(-500,
                                                                         500);
  for (int i = 0; i < cols; i++) {
    row_vector_data[i] = uniform_minus_500_plus_500(RandomEngine());
  }
  typedef VectorMap<std::int32_t, VectorShape::Row> RowVectorMap;
  RowVectorMap row_vector_map(row_vector_data.data(), cols);
  OutputStageBiasAddition<RowVectorMap> row_bias_addition_stage;
  row_bias_addition_stage.bias_vector = row_vector_map;
  auto row_bias_addition_pipeline = std::make_tuple(row_bias_addition_stage);
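  // A Row-shaped bias vector has length `cols` and is broadcast across rows:
  // entry (r, c) of the result receives row_vector_data[c].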
  Matrix<std::int32_t, ResultOrder> result_of_row_bias_addition(rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(), &result_of_row_bias_addition,
      lhs_offset, rhs_offset, row_bias_addition_pipeline);
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t expected = result_raw_int32(r, c) + row_vector_data[c];
      Check(expected == result_of_row_bias_addition(r, c));
    }
  }

  // Test a bias-addition with column-vector
  std::vector<std::int32_t> col_vector_data(rows);
  for (int i = 0; i < rows; i++) {
    col_vector_data[i] = uniform_minus_500_plus_500(RandomEngine());
  }
  typedef VectorMap<std::int32_t, VectorShape::Col> ColVectorMap;
  ColVectorMap col_vector_map(col_vector_data.data(), rows);
  OutputStageBiasAddition<ColVectorMap> col_bias_addition_stage;
  col_bias_addition_stage.bias_vector = col_vector_map;
  auto col_bias_addition_pipeline = std::make_tuple(col_bias_addition_stage);
  Matrix<std::int32_t, ResultOrder> result_of_col_bias_addition(rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(), &result_of_col_bias_addition,
      lhs_offset, rhs_offset, col_bias_addition_pipeline);
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t expected = result_raw_int32(r, c) + col_vector_data[r];
      Check(expected == result_of_col_bias_addition(r, c));
    }
  }

  // Test a clamp
  OutputStageClamp clamp_stage;
  // Determine min and max of raw int32 accumulators
  std::int32_t raw_min = std::numeric_limits<std::int32_t>::max();
  std::int32_t raw_max = std::numeric_limits<std::int32_t>::min();
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      raw_min = std::min(raw_min, result_raw_int32(r, c));
      raw_max = std::max(raw_max, result_raw_int32(r, c));
    }
  }
  // Pick some interesting clamp min/max bounds
  clamp_stage.min = static_cast<std::int32_t>(raw_min * 0.7 + raw_max * 0.3);
  clamp_stage.max = static_cast<std::int32_t>(raw_min * 0.3 + raw_max * 0.7);
  assert(raw_min <= clamp_stage.min && clamp_stage.min <= clamp_stage.max &&
         clamp_stage.max <= raw_max);
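  // Interpolating at 30% / 70% of the raw range keeps both clamp bounds
  // strictly inside [raw_min, raw_max] (assuming raw_min < raw_max), so both
  // sides of the clamp actually get exercised below.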
  auto clamp_pipeline = std::make_tuple(clamp_stage);
  Matrix<std::int32_t, ResultOrder> result_clamped(rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(), &result_clamped, lhs_offset,
      rhs_offset, clamp_pipeline);
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t raw = result_raw_int32(r, c);
      std::int32_t expected =
          std::min(std::max(raw, clamp_stage.min), clamp_stage.max);
      Check(expected == result_clamped(r, c));
    }
  }

  // Test tanh
  OutputStageTanh tanh_stage;
  const std::int32_t real_zero_as_int32 = (raw_max + raw_min) / 2;
  const std::int32_t real_amplitude_as_int32 = (raw_max - raw_min) / 16;
  tanh_stage.real_zero_as_int32 = real_zero_as_int32;
  tanh_stage.real_amplitude_as_int32 = real_amplitude_as_int32;
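  // The tanh stage interprets an int32 value v as the real number
  //   (v - real_zero_as_int32) / real_amplitude_as_int32,
  // applies tanh in that real domain, and encodes the result the same way.
  // The loop below checks this against std::tanh to within 2e-4.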
  auto tanh_pipeline = std::make_tuple(tanh_stage);
  Matrix<std::int32_t, ResultOrder> result_tanh(rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(), &result_tanh, lhs_offset,
      rhs_offset, tanh_pipeline);
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t raw = result_raw_int32(r, c);
      double real_input =
          double(raw - real_zero_as_int32) / real_amplitude_as_int32;
      double expected = std::tanh(real_input);
      std::int32_t actual_int32 = result_tanh(r, c);
      double actual =
          double(actual_int32 - real_zero_as_int32) / real_amplitude_as_int32;
      Check(std::abs(expected - actual) < 2e-4);
    }
  }

  // Test a pipeline with bias and clamp
  auto bias_clamp_pipeline =
      std::make_tuple(col_bias_addition_stage, clamp_stage);
  Matrix<std::int32_t, ResultOrder> result_biased_clamped(rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(), &result_biased_clamped,
      lhs_offset, rhs_offset, bias_clamp_pipeline);
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t raw = result_raw_int32(r, c);
      std::int32_t biased = raw + col_vector_data[r];
      std::int32_t expected =
          std::min(std::max(biased, clamp_stage.min), clamp_stage.max);
      Check(expected == result_biased_clamped(r, c));
    }
  }

  // Test a full pipeline with bias, clamp, and quantization down to an
  // 8bit result.
  auto bias_clamp_quantize_cast_pipeline =
      std::make_tuple(col_bias_addition_stage, clamp_stage, quantize_down_stage,
                      saturating_cast_stage);
  Matrix<std::uint8_t, ResultOrder> result_biased_clamped_quantized_casted(
      rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::uint8_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(),
      &result_biased_clamped_quantized_casted, lhs_offset, rhs_offset,
      bias_clamp_quantize_cast_pipeline);
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t quantized = RoundingDivideByPOT(
          (result_biased_clamped(r, c) + result_offset) * result_mult_int,
          result_shift);
      std::uint8_t expected = std::min(std::max(quantized, 0), 255);
      Check(expected == result_biased_clamped_quantized_casted(r, c));
    }
  }

  // Test a pipeline with the fixed-point-multiplier variant stage for the
  // quantizing down of 32bit accumulators.
  //
  // First, figure out appropriate fixedpoint multiplier and shift values.
  Check(result_mult_int > 0);
  Check(result_shift > 0);
  std::int32_t result_fixedpoint_multiplier = result_mult_int;
  std::int32_t result_fixedpoint_shift = result_shift - 31;
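  // Normalize the multiplier into [2^30, 2^31), adjusting the shift to match.
  // Assuming gemmlowp's SaturatingRoundingDoublingHighMul(a, b) computes
  // roughly a * b / 2^31 with rounding, multiplying by the normalized
  // fixedpoint multiplier and then right-shifting by result_fixedpoint_shift
  // reproduces the plain (x * result_mult_int) >> result_shift scaling used
  // above.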
  while (result_fixedpoint_multiplier < (1 << 30)) {
    result_fixedpoint_multiplier <<= 1;
    result_fixedpoint_shift++;
  }
  Check(result_fixedpoint_shift >= 0);
  // Now test OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint.
  OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint
      quantize_down_by_fixedpoint_stage;
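  // In this variant the additive offset is applied after the multiply-shift,
  // so the scalar result_offset has to be pre-scaled by the same factor:
  // round(result_offset * result_mult_int / 2^result_shift).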
  quantize_down_by_fixedpoint_stage.result_offset_after_shift =
      static_cast<std::int32_t>(
          round(static_cast<double>(result_offset * result_mult_int) /
                (1 << result_shift)));
  quantize_down_by_fixedpoint_stage.result_fixedpoint_multiplier =
      result_fixedpoint_multiplier;
  quantize_down_by_fixedpoint_stage.result_shift = result_fixedpoint_shift;
  auto quantize_down_by_fixedpoint_pipeline =
      std::make_tuple(quantize_down_by_fixedpoint_stage);
  Matrix<std::int32_t, ResultOrder> result_quantized_down_by_fixedpoint_int32(
      rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(),
      &result_quantized_down_by_fixedpoint_int32, lhs_offset, rhs_offset,
      quantize_down_by_fixedpoint_pipeline);

  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      const std::int32_t actual =
          result_quantized_down_by_fixedpoint_int32(r, c);
      const std::int32_t expected =
          quantize_down_by_fixedpoint_stage.result_offset_after_shift +
          RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
                                  result_raw_int32(r, c),
                                  result_fixedpoint_multiplier),
                              result_fixedpoint_shift);
      Check(actual == expected);
    }
  }

  // Test the variant of the familiar default pipeline consisting of
  // quantize-down and clamp-and-cast-to-uint8, where we use fixedpoint
  // multipliers for the downscaling.
  auto quantize_down_by_fixedpoint_and_saturating_cast_pipeline =
      std::make_tuple(quantize_down_by_fixedpoint_stage, saturating_cast_stage);
  Matrix<std::uint8_t, ResultOrder>
      result_quantized_down_by_fixedpoint_saturated_uint8(rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::uint8_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(),
      &result_quantized_down_by_fixedpoint_saturated_uint8, lhs_offset,
      rhs_offset, quantize_down_by_fixedpoint_and_saturating_cast_pipeline);

  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t quantized = result_quantized_down_by_fixedpoint_int32(r, c);
      std::uint8_t expected = std::min(std::max(quantized, 0), 255);
      Check(expected ==
            result_quantized_down_by_fixedpoint_saturated_uint8(r, c));
    }
  }

  printf("TestOutputStages: PASS with ResultOrder=%s\n",
         OrderName(ResultOrder));
}

#ifndef GEMMLOWP_SKIP_EXHAUSTIVE_TESTS
template <typename BitDepthParams>
void TestExhaustively() {
  GemmContext context;

  // Test the internal GEMM interfaces
  test_gemm<
      SingleThreadGemmWrapper<DefaultKernel<BitDepthParams>,
                              std::uint8_t, BitDepthParams>>(&context);

  test_gemm<
      MultiThreadGemmWrapper<DefaultKernel<BitDepthParams>,
                             std::uint8_t, BitDepthParams>>(&context);

  // Test the public GEMM interfaces
  test_gemm<PublicGemmWrapper<std::uint8_t, BitDepthParams>>(&context);

  // Test GEMV cases (internal interfaces)
  test_gemv<
      SingleThreadGemmWrapper<DefaultKernel<BitDepthParams>,
                              std::uint8_t, BitDepthParams>>(&context);

  test_gemv<
      MultiThreadGemmWrapper<DefaultKernel<BitDepthParams>,
                             std::uint8_t, BitDepthParams>>(&context);

  // Test GEMV cases (public interfaces)
  test_gemv<PublicGemmWrapper<std::uint8_t, BitDepthParams>>(&context);
}
template <eight_bit_int_gemm::BitDepthSetting BitDepthSetting>
void TestExhaustivelyEightBitIntGemm() {
  GemmContext context;
  test_gemv<EightBitIntGemmWrapper<std::uint8_t, BitDepthSetting>>(&context);
  test_gemm<EightBitIntGemmWrapper<std::uint8_t, BitDepthSetting>>(&context);
}

void TestKernels() {
  GemmContext context;

  // Test specific kernels with various different formats,
  // to exercise corner cases, especially in the packing code.
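  // Roughly, per the definitions in internal/kernel.h:
  // CellFormat<Width, Depth, Order> describes one cell of a packed block, and
  // KernelSideFormat<CellFormat, NumCells> tiles NumCells such cells along the
  // width dimension of that side.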
  test_gemm_kernel<
      ReferenceKernel<KernelFormat<KernelSideFormat<CellFormat<1, 1>, 1>,
                                   KernelSideFormat<CellFormat<1, 1>, 1>>>>(
      &context);

  test_gemm_kernel<
      ReferenceKernel<KernelFormat<KernelSideFormat<CellFormat<4, 2>, 1>,
                                   KernelSideFormat<CellFormat<4, 2>, 2>>>>(
      &context);

  test_gemm_kernel<
      ReferenceKernel<KernelFormat<KernelSideFormat<CellFormat<4, 2>, 4>,
                                   KernelSideFormat<CellFormat<4, 2>, 5>>>>(
      &context);

  test_gemm_kernel<ReferenceKernel<KernelFormat<
      KernelSideFormat<CellFormat<3, 4, CellOrder::DepthMajor>, 2>,
      KernelSideFormat<CellFormat<5, 4, CellOrder::DepthMajor>, 3>>>>(&context);

  test_gemm_kernel<ReferenceKernel<KernelFormat<
      KernelSideFormat<CellFormat<3, 4, CellOrder::WidthMajor>, 2>,
      KernelSideFormat<CellFormat<5, 4, CellOrder::WidthMajor>, 3>>>>(&context);

  test_gemm_kernel<ReferenceKernel<KernelFormat<
      KernelSideFormat<CellFormat<5, 2, CellOrder::WidthMajor>, 3>,
      KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2>>>>(&context);

  test_gemm_kernel<ReferenceKernel<KernelFormat<
      KernelSideFormat<CellFormat<5, 2, CellOrder::DepthMajor>, 3>,
      KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 2>>>>(&context);

  test_gemm_kernel<ReferenceKernel<KernelFormat<
      KernelSideFormat<CellFormat<8, 8, CellOrder::Diagonal>, 2>,
      KernelSideFormat<CellFormat<3, 8, CellOrder::WidthMajor>, 1>>>>(&context);

  test_gemm_kernel<ReferenceKernel<KernelFormat<
      KernelSideFormat<CellFormat<1, 4, CellOrder::DepthMajor>, 1>,
      KernelSideFormat<CellFormat<4, 4, CellOrder::Diagonal>, 1>>>>(&context);
}

#endif  // not GEMMLOWP_SKIP_EXHAUSTIVE_TESTS

template <typename BitDepthParams>
void TestOutputStages() {
  // Test non-default output pipelines with various combinations of
  // output stages.
  TestOutputStages<BitDepthParams, MapOrder::RowMajor>(63, 10, 127, 5, 17, 14);
  TestOutputStages<BitDepthParams, MapOrder::ColMajor>(63, 10, 127, 5, 17, 14);
  TestOutputStages<BitDepthParams, MapOrder::RowMajor>(630, 10, 1270, 5, 17,
                                                       14);
  TestOutputStages<BitDepthParams, MapOrder::ColMajor>(630, 10, 1270, 5, 17,
                                                       14);
}

void test() {
#ifdef GEMMLOWP_TEST_PROFILE
  RegisterCurrentThreadForProfiling();
  StartProfiling();
#endif

  // Run a first quick test against hand-calculated data.
  TestWithSmallData();

#ifndef GEMMLOWP_SKIP_EXHAUSTIVE_TESTS
  TestExhaustively<DefaultL8R8BitDepthParams>();
  TestExhaustively<L8R8WithLhsNonzeroBitDepthParams>();
  TestExhaustively<DefaultL7R5BitDepthParams>();  // legacy, same as L8R8
  TestExhaustivelyEightBitIntGemm<eight_bit_int_gemm::BitDepthSetting::A8B8>();
  TestExhaustivelyEightBitIntGemm<eight_bit_int_gemm::BitDepthSetting::A5B7>();
  TestKernels();
#endif

  // Run against actual data from a network evaluation.
  TestWithRealData(eight_bit_int_gemm::BitDepthSetting::A8B8, 0, 0);
  TestWithRealData(eight_bit_int_gemm::BitDepthSetting::A5B7, 2, 10);

  // Test non-default output pipelines with various combinations of
  // output stages.
  TestOutputStages<DefaultL8R8BitDepthParams>();
  TestOutputStages<L8R8WithLhsNonzeroBitDepthParams>();

  // Test per-channel quantization.
  TestWithSmallDataPerChannelQuantization();
  TestWithLargeDataPerChannelQuantization();
  TestMultithreadedPerChannelQuantization();
#ifdef GEMMLOWP_TEST_PROFILE
  FinishProfiling();
#endif

  std::cerr << "All tests passed." << std::endl;

  // We have been testing the eight_bit_int_gemm interface, so we should free
  // its persistent resources now to avoid having leak-checking tools report
  // leaks.
  eight_bit_int_gemm::FreePersistentResources();
}

}  // end namespace gemmlowp

// On iOS, the test app defines its own main(), so skip defining one here.
#if !(defined(__APPLE__) && (TARGET_OS_IPHONE || TARGET_IPHONE_SIMULATOR))
int main() { gemmlowp::test(); }
#endif