// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "test.h"

#include <unistd.h>
#include <algorithm>
#include <array>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <limits>
#include <memory>
#include <random>
#include <string>
#include <tuple>
#include <vector>
#ifdef __APPLE__
#include <TargetConditionals.h>
#endif

#include "../eight_bit_int_gemm/eight_bit_int_gemm.h"
#include "../internal/kernel_reference.h"
#include "test_data.h"

namespace gemmlowp {

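// Reference implementation of the quantized GEMM specified by the
// EightBitIntGemm interface:
//
//   C[i,j] = clamp_to_uint8(
//       ((Sum_l (A[i,l] + a_offset) * (B[l,j] + b_offset)) + c_offset)
//       * c_mult_int >> c_shift)
//
// where the final shift rounds to nearest, and matrices are stored
// column-major unless the corresponding transpose flag is set.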
void ReferenceEightBitIntGemm(bool transpose_a, bool transpose_b,
                              bool transpose_c, int m, int n, int k,
                              const std::uint8_t* a, std::int32_t a_offset,
                              int lda, const std::uint8_t* b,
                              std::int32_t b_offset, int ldb, std::uint8_t* c,
                              std::int32_t c_offset, std::int32_t c_mult_int,
                              std::int32_t c_shift, int ldc) {
  ScopedProfilingLabel label("ReferenceEightBitIntGemm");
  assert((c_shift >= 0) && (c_shift <= 32));

  assert(a != nullptr);
  assert(b != nullptr);
  assert(c != nullptr);

  int a_i_stride;
  int a_l_stride;
  if (transpose_a) {
    a_i_stride = lda;
    a_l_stride = 1;
  } else {
    a_i_stride = 1;
    a_l_stride = lda;
  }
  int b_j_stride;
  int b_l_stride;
  if (transpose_b) {
    b_j_stride = 1;
    b_l_stride = ldb;
  } else {
    b_j_stride = ldb;
    b_l_stride = 1;
  }
  int c_i_stride;
  int c_j_stride;
  if (transpose_c) {
    c_i_stride = ldc;
    c_j_stride = 1;
  } else {
    c_i_stride = 1;
    c_j_stride = ldc;
  }
  int i, j, l;

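  // Adding 2^(c_shift - 1) before the arithmetic right shift implements
  // round-to-nearest rather than truncation toward zero: for example, with
  // c_shift == 8, a pre-shift value of 384 rounds to 2 instead of 1.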
  const std::int32_t kRoundingTerm = (c_shift < 1) ? 0 : (1 << (c_shift - 1));

  for (j = 0; j < n; j++) {
    for (i = 0; i < m; i++) {
      std::int32_t total = 0;
      for (l = 0; l < k; l++) {
        const int a_index = i * a_i_stride + l * a_l_stride;
        const std::uint8_t a_as_byte = a[a_index];
        const std::int32_t a_as_int =
            static_cast<std::int32_t>(a_as_byte) + a_offset;
        const int b_index = j * b_j_stride + l * b_l_stride;
        const std::uint8_t b_as_byte = b[b_index];
        const std::int32_t b_as_int =
            static_cast<std::int32_t>(b_as_byte) + b_offset;
        const std::int32_t mult_as_int = a_as_int * b_as_int;
        total += mult_as_int;
      }
      std::int32_t output =
          (((total + c_offset) * c_mult_int) + kRoundingTerm) >> c_shift;
      if (output > 255) {
        output = 255;
      }
      if (output < 0) {
        output = 0;
      }
      const int c_index = i * c_i_stride + j * c_j_stride;
      c[c_index] = static_cast<std::uint8_t>(output);
    }
  }
}

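// VectorMap wraps caller-provided per-channel data, while VectorDup
// broadcasts a single constant value as a whole vector. Both shapes are used
// below to pass offsets and multipliers, either per-channel or scalar.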
typedef VectorMap<const std::int32_t, VectorShape::Col> OffsetColMap;
typedef VectorMap<const std::int32_t, VectorShape::Row> OffsetRowMap;
typedef VectorDup<const std::int32_t, VectorShape::Col> OffsetColDup;
typedef VectorDup<const std::int32_t, VectorShape::Row> OffsetRowDup;

// The *GemmWrapper classes wrap various GEMM functions in a uniform
// interface, so we can use the same testing code to test all of them.
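// Each wrapper exposes a static Name() describing the wrapped function, and a
// static Gemm() that returns whether the wrapped function handled the given
// shape (the single- and multi-threaded wrappers decline rows < cols).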

template <typename Kernel, typename Scalar, typename tBitDepthParams>
struct SingleThreadGemmWrapper {
  typedef tBitDepthParams BitDepthParams;

  static const char* Name() {
    static char buf[256];
    snprintf(buf, sizeof(buf), "SingleThreadGemm, Kernel: %s", Kernel().Name());
    return buf;
  }

  typedef SingleThreadGemmContext Context;

  template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
  static bool Gemm(Context* context,
                   const MatrixMap<const Scalar, LhsOrder>& lhs,
                   const MatrixMap<const Scalar, RhsOrder>& rhs,
                   MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
                   int rhs_offset, int result_offset, int result_mult_int,
                   int result_shift) {
    ScopedProfilingLabel label("SingleThreadGemmWrapper::Gemm");
    const int rows = lhs.rows();
    const int cols = rhs.cols();
    if (rows < cols) {
      // SingleThreadGemm is never called with rows < cols.
      // That case is handled earlier.
      return false;
    }
    const OffsetColDup lhs_offset_vector(lhs_offset, rows);
    const OffsetRowDup rhs_offset_vector(rhs_offset, cols);
    SingleThreadGemm<typename Kernel::Format, Scalar, Scalar, BitDepthParams,
                     LhsOrder, RhsOrder, ResultOrder, OffsetColDup,
                     OffsetRowDup>(
        context, Kernel(), lhs, rhs, result, lhs_offset_vector,
        rhs_offset_vector,
        MakeStandardOutputPipeline(result_offset, result_mult_int,
                                   result_shift));
    return true;
  }
};

template <typename Kernel, typename Scalar, typename tBitDepthParams>
struct MultiThreadGemmWrapper {
  typedef tBitDepthParams BitDepthParams;

  static const char* Name() {
    static char buf[256];
    snprintf(buf, sizeof(buf), "MultiThreadGemm, Kernel: %s", Kernel().Name());
    return buf;
  }

  typedef MultiThreadGemmContext Context;

  template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
  static bool Gemm(Context* context,
                   const MatrixMap<const Scalar, LhsOrder>& lhs,
                   const MatrixMap<const Scalar, RhsOrder>& rhs,
                   MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
                   int rhs_offset, int result_offset, int result_mult_int,
                   int result_shift) {
    ScopedProfilingLabel label("MultiThreadGemmWrapper::Gemm");
    context->set_max_num_threads(0);
    const int rows = lhs.rows();
    const int cols = rhs.cols();
    if (rows < cols) {
      // MultiThreadGemm is never called with rows < cols.
      // That case is handled earlier.
      return false;
    }
    const OffsetColDup lhs_offset_vector(lhs_offset, rows);
    const OffsetRowDup rhs_offset_vector(rhs_offset, cols);
    MultiThreadGemm<typename Kernel::Format, Scalar, Scalar, BitDepthParams,
                    LhsOrder, RhsOrder, ResultOrder, OffsetColDup,
                    OffsetRowDup>(
        context, Kernel(), lhs, rhs, result, lhs_offset_vector,
        rhs_offset_vector,
        MakeStandardOutputPipeline(result_offset, result_mult_int,
                                   result_shift));
    return true;
  }
};

template <typename Scalar, typename tBitDepthParams>
struct PublicGemmWrapper {
  typedef tBitDepthParams BitDepthParams;

  static const char* Name() { return "public Gemm"; }

  typedef GemmContext Context;

  template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
  static bool Gemm(Context* context,
                   const MatrixMap<const Scalar, LhsOrder>& lhs,
                   const MatrixMap<const Scalar, RhsOrder>& rhs,
                   MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
                   int rhs_offset, int result_offset, int result_mult_int,
                   int result_shift) {
    ScopedProfilingLabel label("PublicGemmWrapper::Gemm");
    gemmlowp::Gemm<std::uint8_t, BitDepthParams, LhsOrder, RhsOrder,
                   ResultOrder>(context, lhs, rhs, result, lhs_offset,
                                rhs_offset, result_offset, result_mult_int,
                                result_shift);
    return true;
  }
};

template <eight_bit_int_gemm::BitDepthSetting BitDepth>
struct BitDepthParamsForSettings {};

template <>
struct BitDepthParamsForSettings<eight_bit_int_gemm::BitDepthSetting::A8B8>
    : DefaultL8R8BitDepthParams {};

template <>
struct BitDepthParamsForSettings<eight_bit_int_gemm::BitDepthSetting::A5B7>
    : DefaultL7R5BitDepthParams {};

template <typename Scalar, eight_bit_int_gemm::BitDepthSetting BitDepth>
struct EightBitIntGemmWrapper {
  typedef BitDepthParamsForSettings<BitDepth> BitDepthParams;

  static const char* Name() { return "EightBitIntGemm"; }

  typedef void Context;

  template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
  static bool Gemm(Context*, const MatrixMap<const Scalar, LhsOrder>& lhs,
                   const MatrixMap<const Scalar, RhsOrder>& rhs,
                   MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
                   int rhs_offset, int result_offset, int result_mult_int,
                   int result_shift) {
    ScopedProfilingLabel label("EightBitIntGemmWrapper::Gemm");
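    // The EightBitIntGemm entry point assumes column-major storage, so a
    // row-major MatrixMap is handed to it as the transpose of a column-major
    // matrix.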
    const bool transpose_c = ResultOrder == MapOrder::RowMajor;
    const bool transpose_a = LhsOrder == MapOrder::RowMajor;
    const bool transpose_b = RhsOrder == MapOrder::RowMajor;
    eight_bit_int_gemm::EightBitIntGemm(
        transpose_a, transpose_b, transpose_c, lhs.rows(), rhs.cols(),
        lhs.cols(), lhs.data(), lhs_offset, lhs.stride(), rhs.data(),
        rhs_offset, rhs.stride(), result->data(), result_offset,
        result_mult_int, result_shift, result->stride(), BitDepth);
    return true;
  }
};

template <typename Scalar>
struct ReferenceEightBitIntGemmWrapper {
  typedef DefaultL8R8BitDepthParams BitDepthParams;

  static const char* Name() { return "ReferenceEightBitIntGemm"; }

  template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
  static bool Gemm(bool transpose_a, bool transpose_b, bool transpose_c,
                   const MatrixMap<const Scalar, LhsOrder>& lhs,
                   const MatrixMap<const Scalar, RhsOrder>& rhs,
                   MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
                   int rhs_offset, int result_offset, int result_mult_int,
                   int result_shift) {
    ScopedProfilingLabel label("ReferenceEightBitIntGemmWrapper::Gemm");
    ReferenceEightBitIntGemm(transpose_a, transpose_b, transpose_c, lhs.rows(),
                             rhs.cols(), lhs.cols(), lhs.data(), lhs_offset,
                             lhs.stride(), rhs.data(), rhs_offset, rhs.stride(),
                             result->data(), result_offset, result_mult_int,
                             result_shift, result->stride());
    return true;
  }
};

const char* OrderName(MapOrder order) {
  return order == MapOrder::ColMajor ? "ColMajor" : "RowMajor";
}

struct ResultStats {
  ResultStats()
      : count(0),
        med_val(0),
        mean_signed_diff(0),
        med_signed_diff(0),
        med_unsigned_diff(0),
        max_unsigned_diff(0) {}

  int count;
  int med_val;
  float mean_signed_diff;
  int med_signed_diff;
  int med_unsigned_diff;
  int max_unsigned_diff;

  std::vector<int> count_diff_by_pot_slice;
};

void GetResultStats(const std::uint8_t* actual, const std::uint8_t* expected,
                    size_t count, ResultStats* stats) {
  ScopedProfilingLabel label("GetResultStats");
  std::vector<std::uint8_t> results;
  std::vector<std::int16_t> signed_diffs;
  std::vector<std::uint8_t> unsigned_diffs;
  std::int64_t signed_diffs_sum = 0;
  for (size_t i = 0; i < count; i++) {
    results.push_back(actual[i]);
    std::int16_t signed_diff = actual[i] - expected[i];
    signed_diffs.push_back(signed_diff);
    unsigned_diffs.push_back(std::abs(signed_diff));
    signed_diffs_sum += signed_diff;
  }

  std::sort(results.begin(), results.end());
  std::sort(signed_diffs.begin(), signed_diffs.end());
  std::sort(unsigned_diffs.begin(), unsigned_diffs.end());

  const size_t middle = count / 2;

  stats->count = count;
  stats->med_val = results[middle];
  stats->mean_signed_diff = float(signed_diffs_sum) / count;
  stats->med_signed_diff = signed_diffs[middle];
  stats->med_unsigned_diff = unsigned_diffs[middle];
  stats->max_unsigned_diff = unsigned_diffs.back();

  // Size 9 for 9 different POT values: 2^0, ..., 2^8
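  // Slice 0 counts exact matches (diff < 1); slice e, for e >= 1, counts
  // diffs in the half-open range [2^(e-1), 2^e).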
  stats->count_diff_by_pot_slice.resize(9);
  auto cur = unsigned_diffs.begin();
  size_t checksum = 0;
  for (int exponent = 0; exponent < 9; exponent++) {
    int pot = 1 << exponent;
    auto next = std::lower_bound(cur, unsigned_diffs.end(), pot);
    checksum += stats->count_diff_by_pot_slice[exponent] = next - cur;
    cur = next;
  }
  assert(checksum == count);
}

struct ResultStatsBounds {
  ResultStatsBounds()
      : mean_signed_diff(0),
        med_signed_diff(0),
        med_unsigned_diff(0),
        max_unsigned_diff(0) {}

  float mean_signed_diff;
  int med_signed_diff;
  int med_unsigned_diff;
  int max_unsigned_diff;
};

bool CheckResultStatsBounds(const ResultStats& stats,
                            const ResultStatsBounds& bounds) {
  return stats.max_unsigned_diff <= bounds.max_unsigned_diff &&
         stats.med_unsigned_diff <= bounds.med_unsigned_diff &&
         std::abs(stats.med_signed_diff) <= bounds.med_signed_diff &&
         std::abs(stats.mean_signed_diff) <= bounds.mean_signed_diff;
}

void ReportResultStats(const ResultStats& stats,
                       const ResultStatsBounds& bounds) {
  printf("  number of matrix entries: %d\n", stats.count);
  printf("  median value: %d\n", stats.med_val);
  printf("  median unsigned diff: %d (tolerating %d)\n",
         stats.med_unsigned_diff, bounds.med_unsigned_diff);
  printf("  max unsigned diff: %d (tolerating %d)\n", stats.max_unsigned_diff,
         bounds.max_unsigned_diff);
  printf("  median signed diff: %d (tolerating %d)\n", stats.med_signed_diff,
         bounds.med_signed_diff);
  printf("  mean signed diff: %.3g (tolerating %.3g)\n",
         stats.mean_signed_diff, bounds.mean_signed_diff);

  printf("No error: %.2f %% of entries\n",
         100.f * stats.count_diff_by_pot_slice[0] / stats.count);
  for (int exponent = 1; exponent < 9; exponent++) {
    printf("Error in %d..%d range: %.2f %% of entries\n", 1 << (exponent - 1),
           (1 << exponent) - 1,
           100.f * stats.count_diff_by_pot_slice[exponent] / stats.count);
  }
}

// Our approach to choosing result_shift values for testing is bisection.
// This function takes an interval, [result_shift_min .. result_shift_max].
// If too much saturation occurred in either direction, it bisects accordingly,
// recursing until the interval contains only one value.
// The primary reason why we prefer this over computing optimal shift values
// is that we actually want to exercise some saturation, as there is nontrivial
// code handling that in gemmlowp.
// Secondarily, this is faster than computing optimal shifts, since in 90% of
// cases the first-tried shift value 16 turns out to be good enough.
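// For example, starting from the initial interval [0 .. 32], the first shift
// tried is 16. If the median output value then sits near 0, the shift was too
// large and the search continues in [0 .. 16]; if it sits near 255, the shift
// was too small and the search continues in [16 .. 32].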
template <typename GemmWrapper, typename LhsType, typename RhsType,
          typename ResultType>
void test_gemm_impl(typename GemmWrapper::Context* context, const LhsType& lhs,
                    const RhsType& rhs, ResultType* result, int lhs_offset,
                    int rhs_offset, int result_offset, int result_mult_int,
                    int result_shift_min, int result_shift_max) {
  const int rows = lhs.rows();
  const int cols = rhs.cols();
  Check(lhs.cols() == rhs.rows());
  const int depth = lhs.cols();

  const int result_shift = (result_shift_min + result_shift_max) / 2;

  if (!GemmWrapper::Gemm(context, lhs.const_map(), rhs.const_map(),
                         &result->map(), lhs_offset, rhs_offset, result_offset,
                         result_mult_int, result_shift)) {
    // Internal GEMM functions are not required to handle all cases
    // (e.g. rows < cols) as these are supposed to have been handled
    // ahead of them. Their test wrappers return false in that case.
    return;
  }

  typedef typename ResultType::Scalar Scalar;
  static const MapOrder kLhsOrder = LhsType::kOrder;
  static const MapOrder kRhsOrder = RhsType::kOrder;
  static const MapOrder kResultOrder = ResultType::kOrder;
  ResultType ref_result(rows, cols);
  const bool transpose_c = kResultOrder == MapOrder::RowMajor;
  const bool transpose_a = kLhsOrder == MapOrder::RowMajor;
  const bool transpose_b = kRhsOrder == MapOrder::RowMajor;
  ReferenceEightBitIntGemmWrapper<Scalar>::Gemm(
      transpose_a, transpose_b, transpose_c, lhs.const_map(), rhs.const_map(),
      &ref_result.map(), lhs_offset, rhs_offset, result_offset, result_mult_int,
      result_shift);

  typedef typename GemmWrapper::BitDepthParams BitDepthParams;

  ResultStats stats;
  GetResultStats(result->data(), ref_result.data(), rows * cols, &stats);

  // Adjust shifts until we get meaningful results
  int new_result_shift_min = result_shift_min;
  int new_result_shift_max = result_shift_max;
  bool retry = false;

  if (stats.med_val < 32) {
    new_result_shift_max = (result_shift_min + result_shift_max) / 2;
    retry = true;
  }

  if (stats.med_val > 224) {
    new_result_shift_min = (result_shift_min + result_shift_max) / 2;
    retry = true;
  }

  if (retry) {
    if (result_shift_min != result_shift_max) {
      test_gemm_impl<GemmWrapper>(context, lhs, rhs, result, lhs_offset,
                                  rhs_offset, result_offset, result_mult_int,
                                  new_result_shift_min, new_result_shift_max);
    }
    return;
  }

  ResultStatsBounds bounds;

  // Check results
  const bool good = CheckResultStatsBounds(stats, bounds);

  printf(
      "%s: %dx%dx%d %s x %s -> %s, %s, offsets %d/%d/%d, mult %d, shift %d\n",
      good ? "PASS" : "FAIL", rows, depth, cols, OrderName(kLhsOrder),
      OrderName(kRhsOrder), OrderName(kResultOrder), GemmWrapper::Name(),
      lhs_offset, rhs_offset, result_offset, result_mult_int, result_shift);

  if (!good) {
    ReportResultStats(stats, bounds);

    int bad_coeffs_printed = 0;
    for (int c = 0; c < result->cols() && bad_coeffs_printed < 200; c++) {
      for (int r = 0; r < result->rows() && bad_coeffs_printed < 200; r++) {
        if (ref_result(r, c) != (*result)(r, c)) {
          printf("bad coeff: at (%d, %d), expected %d, got %d\n", r, c,
                 ref_result(r, c), (*result)(r, c));
          bad_coeffs_printed++;
        }
      }
    }
  }

  Check(good);
}

template <typename GemmWrapper, typename LhsType, typename RhsType,
          typename ResultType>
void test_gemm(typename GemmWrapper::Context* context, const LhsType& lhs,
               const RhsType& rhs, ResultType* result, int lhs_offset,
               int rhs_offset, int result_offset, int result_mult_int) {
  test_gemm_impl<GemmWrapper>(context, lhs, rhs, result, lhs_offset, rhs_offset,
                              result_offset, result_mult_int, 0, 32);
}

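// WhatParamsToTest::All exercises every offset/multiplier/shift combination
// enumerated in test_gemm below; OnlyGenericCase runs only the final
// combination, whose parameters are all generic nonzero values.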
enum class WhatParamsToTest {
  All,
  OnlyGenericCase,
};

template <typename GemmWrapper, MapOrder LhsOrder, MapOrder RhsOrder,
          MapOrder ResultOrder>
void test_gemm(typename GemmWrapper::Context* context, int rows, int depth,
               int cols, WhatParamsToTest params_to_test) {
  typedef std::uint8_t Scalar;
  typedef Matrix<Scalar, LhsOrder> LhsType;
  using BitDepthParams = typename GemmWrapper::BitDepthParams;
  LhsType lhs(rows, depth);
  MakeRandom<typename BitDepthParams::LhsRange>(&lhs);
  typedef Matrix<Scalar, RhsOrder> RhsType;
  RhsType rhs(depth, cols);
  MakeRandom<typename BitDepthParams::RhsRange>(&rhs);
  typedef Matrix<Scalar, ResultOrder> ResultType;
  ResultType result(rows, cols);
  MakeZero(&result);

  if (params_to_test == WhatParamsToTest::All) {
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 0, 0, 1);
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 10, 0, 0, 1);
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 10, 0, 1);
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 0, 10, 1);
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 0, 0, 10);
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 10, 10, 10, 10);
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 256, 1, 17, 4);
  }
  test_gemm<GemmWrapper>(context, lhs, rhs, &result, -75, -91, 74980, 123);
}

enum class WhatOrdersToTest { All, OnlyRCC };

template <typename GemmWrapper>
void test_gemm(typename GemmWrapper::Context* context, int rows, int depth,
               int cols, WhatParamsToTest params_to_test,
               WhatOrdersToTest orders_to_test) {
#define GEMMLOWP_ONE_TEST(LhsOrder, RhsOrder, ResultOrder)         \
  do {                                                             \
    test_gemm<GemmWrapper, MapOrder::LhsOrder, MapOrder::RhsOrder, \
              MapOrder::ResultOrder>(context, rows, depth, cols,   \
                                     params_to_test);              \
  } while (false)
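// The do { } while (false) wrapper makes the macro expand to a single
// statement, so an invocation followed by a semicolon nests safely inside
// if/else blocks.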

  if (orders_to_test == WhatOrdersToTest::All) {
    GEMMLOWP_ONE_TEST(ColMajor, ColMajor, ColMajor);
    GEMMLOWP_ONE_TEST(RowMajor, ColMajor, ColMajor);
    GEMMLOWP_ONE_TEST(ColMajor, RowMajor, ColMajor);
    GEMMLOWP_ONE_TEST(RowMajor, RowMajor, ColMajor);

    GEMMLOWP_ONE_TEST(ColMajor, ColMajor, RowMajor);
    GEMMLOWP_ONE_TEST(RowMajor, ColMajor, RowMajor);
    GEMMLOWP_ONE_TEST(ColMajor, RowMajor, RowMajor);
    GEMMLOWP_ONE_TEST(RowMajor, RowMajor, RowMajor);
  } else {
    GEMMLOWP_ONE_TEST(RowMajor, ColMajor, ColMajor);
  }

#undef GEMMLOWP_ONE_TEST
}

template <typename Kernel>
void test_gemm_kernel(MultiThreadGemmContext* context) {
  typedef MultiThreadGemmWrapper<Kernel, std::uint8_t,
                                 DefaultL8R8BitDepthParams>
      GemmWrapper;
  test_gemm<GemmWrapper>(context, 1, 1, 1, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 2, 2, 2, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 3, 3, 3, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 4, 4, 4, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 5, 5, 5, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 9, 11, 13, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 50, 50, 50, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 200, 200, 200,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::All);
  test_gemm<GemmWrapper>(context, 50, 5000, 50,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
}

template <typename GemmWrapper>
void test_gemm(typename GemmWrapper::Context* context) {
  test_gemm<GemmWrapper>(context, 1, 1, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 2, 1, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1, 2, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1, 1, 2, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 2, 2, 2, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 3, 3, 3, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 4, 4, 4, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 5, 5, 5, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 6, 6, 6, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 3, 5, 7, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 7, 3, 5, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 5, 7, 3, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 8, 8, 8, WhatParamsToTest::All,
                         WhatOrdersToTest::All);
  test_gemm<GemmWrapper>(context, 16, 16, 16, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 32, 32, 32, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 64, 64, 64, WhatParamsToTest::All,
                         WhatOrdersToTest::All);
  test_gemm<GemmWrapper>(context, 128, 128, 128, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);

  test_gemm<GemmWrapper>(context, 16, 17, 16, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 37, 55, 73, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 57, 87, 117, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 93, 83, 73, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 109, 89, 99, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 78, 101, 82, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);

  test_gemm<GemmWrapper>(context, 512, 512, 512,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1024, 1024, 1024,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 567, 2345, 123,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 100, 5000, 100,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1, 1, 1000, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1000, 1, 1, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1, 1000, 1, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1, 1000, 1000,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1000, 1, 1000,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1000, 1000, 1,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 777, 3456, 1,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 4567, 555, 1,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);

  // Test all storage orders
  test_gemm<GemmWrapper>(context, 70, 90, 110, WhatParamsToTest::All,
                         WhatOrdersToTest::All);
  test_gemm<GemmWrapper>(context, 300, 400, 500,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::All);
}

template <typename GemmWrapper>
void test_gemv(typename GemmWrapper::Context* context) {
  test_gemm<GemmWrapper>(context, 2, 2, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 3, 3, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 4, 4, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 5, 5, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 6, 6, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 3, 5, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 7, 3, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 5, 7, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 8, 8, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 32, 32, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 128, 128, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 321, 123, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);

  // Test all storage orders
  test_gemm<GemmWrapper>(context, 70, 90, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::All);
  test_gemm<GemmWrapper>(context, 300, 400, 1,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::All);
}

const char* GetBitDepthName(eight_bit_int_gemm::BitDepthSetting b) {
  switch (b) {
    case eight_bit_int_gemm::BitDepthSetting::A8B8:
      return "Lhs: 8 bit, Rhs: 8 bit";
    case eight_bit_int_gemm::BitDepthSetting::A5B7:
      return "(legacy, no longer requantizing) Lhs: 7 bit, Rhs: 5 bit";
    default:
      abort();
      return nullptr;
  }
}

// Runs a small set of hand-picked per-channel quantized data.
// This test case comes from a set of 2 2x2 convolution filters run over a 3x3
// image.
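// "Per-channel" means that the lhs offsets and the result offsets and
// multipliers vary per output row (one per filter): they are passed below as
// OffsetColMap vectors instead of single scalar values.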
void TestWithSmallDataPerChannelQuantization() {
  const int m = 2;
  const int n = 9;
  const int k = 12;

  // 12 x 2, columnwise.
  const std::uint8_t a_data[] = {0, 0, 0, 0, 0, 0, 0, 0,
                                 0, 255, 255, 255, 64, 64, 64, 64,
                                 64, 64, 0, 0, 0, 255, 255, 255};
  const int lda = k;
  int a_offset[] = {0, -64};
  MatrixMap<const std::uint8_t, MapOrder::RowMajor> lhs(a_data, m, k, lda);
  const OffsetColMap lhs_offset(a_offset, m);

  // 12 x 9, columnwise.
  const std::uint8_t b_data[] = {
      0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 0, 0,
      0, 0, 0, 0, 255, 255, 255, 0, 0, 0, 0, 0, 0, 127,
      127, 127, 0, 0, 0, 127, 127, 127, 0, 0, 0, 255, 255, 255,
      0, 0, 0, 0, 0, 0, 255, 255, 255, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 127, 127, 127, 0, 0, 0, 127,
      127, 127, 0, 0, 0, 0, 0, 0, 127, 127, 127, 127, 127, 127,
      0, 0, 0, 0, 0, 0, 127, 127, 127, 127, 127, 127, 0, 0,
      0, 127, 127, 127, 127, 127, 127, 127, 127, 127};
  const int ldb = k;
  int b_offset = -127;
  MatrixMap<const std::uint8_t, MapOrder::ColMajor> rhs(b_data, k, n, ldb);
  const OffsetRowDup rhs_offset(b_offset, rhs.cols());

  // 2 x 9, columnwise.
  const std::uint8_t expected_c_data[] = {255, 255, 0, 0, 127, 159,
                                          0, 64, 0, 64, 127, 159,
                                          127, 127, 127, 127, 127, 127};
  const int ldc = m;
  int c_offset[] = {97155, 97346};
  int c_mult_int[] = {2741, 2741};
  const int c_shift = 21;

  const int c_count = m * n;
  std::unique_ptr<std::uint8_t[]> output_data(new std::uint8_t[c_count]);
  MatrixMap<std::uint8_t, MapOrder::ColMajor> result(output_data.get(), m, n,
                                                     ldc);
  const OffsetColMap result_offset(c_offset, m);
  const OffsetColMap result_mult_int(c_mult_int, m);
  const int result_shift = c_shift;

  GemmContext gemm_context;
  auto output_pipeline = MakeStandardOutputPipeline<VectorShape::Col>(
      result_offset, result_mult_int, result_shift);
  GemmWithOutputPipelinePC<std::uint8_t, std::uint8_t,
                           DefaultL8R8BitDepthParams>(
      &gemm_context, lhs, rhs, &result, lhs_offset, rhs_offset,
      output_pipeline);

  ResultStats stats;
  GetResultStats(output_data.get(), expected_c_data, c_count, &stats);

  ResultStatsBounds bounds;
  const bool good = CheckResultStatsBounds(stats, bounds);
  printf("TestWithSmallDataPerChannelQuantization: %s\n",
         good ? "PASS" : "FAIL");
  ReportResultStats(stats, bounds);
  Check(good);
}

// Runs a larger set of hand-picked per-channel quantized data.
// This test case comes from a set of 22 3x3 convolution filters run over a 5x5
// image. Right now, I have 7 different filters and 15 copies of the first
// filter to make sure the NEON code path that processes 16 rows at a time is
// covered.
void TestWithLargeDataPerChannelQuantization() {
  const int m = 22;
  const int n = 25;
  const int k = 27;

  // 27 x 22, column-wise.
  const std::uint8_t a_data[] = {
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 127, 127, 127, 255, 255, 255, 127, 127, 127,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 127, 127, 127,
      0, 0, 0, 0, 0, 0, 255, 255, 255, 0, 0, 0, 0, 0, 0,
      127, 127, 127, 0, 0, 0, 51, 51, 51, 51, 51, 51, 51, 51, 51,
      0, 0, 0, 255, 255, 255, 0, 0, 0, 51, 51, 51, 51, 51, 51,
      51, 51, 51, 51, 51, 51, 0, 0, 0, 51, 51, 51, 51, 51, 51,
      255, 255, 255, 51, 51, 51, 51, 51, 51, 0, 0, 0, 51, 51, 51,
      0, 0, 0, 64, 64, 64, 0, 0, 0, 64, 64, 64, 255, 255, 255,
      64, 64, 64, 0, 0, 0, 64, 64, 64, 0, 0, 0, 36, 36, 36,
      0, 0, 0, 36, 36, 36, 0, 0, 0, 255, 255, 255, 0, 0, 0,
      36, 36, 36, 0, 0, 0, 36, 36, 36, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 255, 255, 255, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 255, 255, 255, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 255, 255, 255, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0,
  };
  const int lda = k;
  int a_offset[] = {0, 0, 0, -51, -51, 0, -36, 0, 0, 0, 0,
                    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
  MatrixMap<const std::uint8_t, MapOrder::RowMajor> lhs(a_data, m, k, lda);
  const OffsetColMap lhs_offset(a_offset, m);

  // 27 x 25, column-wise.
  const std::uint8_t b_data[] = {
      127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 119, 119,
      119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127,
      127, 127, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127,
      127, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127,
      127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127, 127,
      127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127,
      119, 119, 119, 119, 119, 119, 127, 127, 127, 127, 127, 127, 119, 119,
      119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127,
      127, 127, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 136, 136, 136, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      136, 136, 136, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 136, 136, 136, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127,
      119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119,
      119, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127,
      127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 136, 136, 136, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      136, 136, 136, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 136, 136, 136, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119,
      119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 127,
      127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119,
      119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 136, 136, 136, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      136, 136, 136, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 136, 136, 136, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119,
      119, 119, 119, 119, 119, 127, 127, 127, 127, 127, 127, 119, 119, 119,
      119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127,
      127, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127,
      127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127, 127,
      127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127, 127, 127,
      127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119,
      119, 119, 119, 119, 119, 127, 127, 127, 127, 127, 127, 127, 127, 127,
      127, 127, 127};
  const int ldb = k;
  int b_offset = -127;
  MatrixMap<const std::uint8_t, MapOrder::ColMajor> rhs(b_data, k, n, ldb);
  const OffsetRowDup rhs_offset(b_offset, rhs.cols());

  // 22 x 25, column-wise.
  const std::uint8_t expected_c_data[] = {
      7, 37, 37, 67, 67, 39, 79, 7, 7, 7, 7, 7, 7, 7, 7,
      7, 7, 7, 7, 7, 7, 7, 7, 7, 37, 87, 67, 23, 91, 7,
      7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
      7, 37, 87, 67, 23, 91, 7, 7, 7, 7, 7, 7, 7, 7, 7,
      7, 7, 7, 7, 7, 7, 7, 7, 37, 87, 67, 23, 91, 7, 7,
      7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 37,
      37, 67, 67, 39, 79, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
      7, 7, 7, 7, 7, 7, 37, 7, 67, 87, 23, 91, 7, 7, 7,
      7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
      87, 87, 7, 103, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
      7, 7, 7, 7, 7, 7, 71, 87, 45, 41, 77, 7, 7, 7, 7,
      7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 87,
      87, 7, 103, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
      7, 7, 7, 7, 37, 7, 67, 87, 23, 91, 7, 7, 7, 7, 7,
      7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 37, 7, 67, 87,
      23, 91, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
      7, 7, 7, 71, 7, 45, 87, 41, 77, 7, 7, 7, 7, 7, 7,
      7, 7, 7, 7, 7, 7, 7, 7, 7, 255, 135, 135, 255, 255, 143,
      255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
      255, 7, 71, 7, 45, 87, 41, 77, 7, 7, 7, 7, 7, 7, 7,
      7, 7, 7, 7, 7, 7, 7, 7, 7, 37, 7, 67, 87, 23, 91,
      7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
      7, 37, 7, 67, 87, 23, 91, 7, 7, 7, 7, 7, 7, 7, 7,
      7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 87, 87, 7, 103, 7,
      7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
      7, 71, 87, 45, 41, 77, 7, 7, 7, 7, 7, 7, 7, 7, 7,
      7, 7, 7, 7, 7, 7, 7, 7, 7, 87, 87, 7, 103, 7, 7,
      7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 37,
      7, 67, 87, 23, 91, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
      7, 7, 7, 7, 7, 7, 37, 37, 67, 67, 39, 79, 7, 7, 7,
      7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 37,
      87, 67, 23, 91, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
      7, 7, 7, 7, 7, 7, 37, 87, 67, 23, 91, 7, 7, 7, 7,
      7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 37, 87,
      67, 23, 91, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
      7, 7, 7, 7, 37, 37, 67, 67, 39, 79, 7, 7, 7, 7, 7,
      7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 99, 99, 99, 99, 99,
      99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99,
      99, 99, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111,
      111, 111, 111, 111, 111, 111, 111, 111, 111,
  };
  const int ldc = m;
  int c_offset[] = {
      6477, 12954, 12954, 7793, 7793, 12954, 9282, 6477, 6477, 6477, 6477,
      6477, 6477, 6477, 6477, 6477, 6477, 6477, 6477, 6477, 6477, 6477,
  };
  int c_mult_int[] = {
      41121, 20560, 20560, 34267, 34267, 21937, 28784, 41121,
      41121, 41121, 41121, 41121, 41121, 41121, 41121, 41121,
      41121, 41121, 41121, 41121, 41121, 41121,
  };
  const int c_shift = 21;

  const int c_count = m * n;
  std::unique_ptr<std::uint8_t[]> output_data(new std::uint8_t[c_count]);
  MatrixMap<std::uint8_t, MapOrder::ColMajor> result(output_data.get(), m, n,
                                                     ldc);
  const OffsetColMap result_offset(c_offset, m);
  const OffsetColMap result_mult_int(c_mult_int, m);
  const int result_shift = c_shift;

  GemmContext gemm_context;
  auto output_pipeline = MakeStandardOutputPipeline<VectorShape::Col>(
      result_offset, result_mult_int, result_shift);
  GemmWithOutputPipelinePC<std::uint8_t, std::uint8_t,
                           DefaultL8R8BitDepthParams>(
      &gemm_context, lhs, rhs, &result, lhs_offset, rhs_offset,
      output_pipeline);

  ResultStats stats;
  GetResultStats(output_data.get(), expected_c_data, c_count, &stats);

  ResultStatsBounds bounds;
  const bool good = CheckResultStatsBounds(stats, bounds);
  printf("TestWithLargeDataPerChannelQuantization: %s\n",
         good ? "PASS" : "FAIL");
  ReportResultStats(stats, bounds);
  Check(good);
}

// Multithreading only activates when the result has more than 16 rows, and
// also (result rows) * (result cols) * depth >= 2 x 65 x 1024. The size was
// selected to run in 3 threads.
//
// Based on the following floating point data:
//   LHS: all zeros except 10.0, 20.0 at the beginning of first 16 rows;
//     1.0, 2.0 at the beginning of next 16 rows; 0.1, 0.2 in next 16 rows;
//     0.01, 0.02 in last 16 rows.
//   RHS: all zeros except 1.0 in (0, 0) and 2.0 in (1, 0).
// Varying boundaries were used for each 16 rows block of LHS, to test for
// correct indexing into offsets.
// Expected result: all zeros, except 50.0 at the beginning of first 16 rows;
//   5.0 at the beginning of next 16 rows; 0.5 in next 16 rows; 0.05 in last
//   16 rows.
void TestMultithreadedPerChannelQuantization() {
  const int m = 64;
  const int n = 20;
  const int k = 160;

  // LHS, m x k.
  const std::array<std::int32_t, 4> lhs_offsets_terse{{
      0, -51, -85, -109,
  }};
  assert(lhs_offsets_terse.size() * 16 == m);
  const std::array<std::uint8_t, 4> lhs_first_el{{
      128, 153, 170, 182,
  }};
  assert(lhs_first_el.size() * 16 == m);

  // lhs_first_el at (i, 0) and 255 at (i, 1), other values are all -offset.
  std::vector<std::uint8_t> a_data(m * k, 0);
  for (int i = 0; i < m; ++i) {
    a_data[i * k] = lhs_first_el[i / 16];
    a_data[i * k + 1] = 255;
    for (int j = 2; j < k; ++j) {
      a_data[i * k + j] = std::uint8_t(-lhs_offsets_terse[i / 16]);
    }
  }

  const int lda = k;
  // Given values at [i / 16].
  std::vector<std::int32_t> a_offset(m, 0);
  for (int i = 0; i < m; ++i) {
    a_offset[i] = lhs_offsets_terse[i / 16];
  }

  MatrixMap<const std::uint8_t, MapOrder::RowMajor> lhs(&a_data[0], m, k, lda);
  const OffsetColMap lhs_offset(&a_offset[0], m);

  // RHS, k x n.
  // All zeros, except 128 at (0, 0) and 255 at (1, 0).
  std::vector<std::uint8_t> b_data(k * n, 0);
  b_data[0] = 128;
  b_data[1] = 255;

  const int ldb = k;
  std::int32_t b_offset = 0;
  MatrixMap<const std::uint8_t, MapOrder::ColMajor> rhs(&b_data[0], k, n, ldb);
  const OffsetRowDup rhs_offset(b_offset, rhs.cols());

  // Result, m x n.
  // All zeros, except given values at (i / 16, 0).
  const std::array<std::uint8_t, 4> expected_c_terse{{
      142, 159, 182, 213,
  }};
  assert(expected_c_terse.size() * 16 == m);
  std::vector<std::uint8_t> expected_c_data(m * n, 0);
  for (int i = 0; i < m; ++i) {
    expected_c_data[i] = expected_c_terse[i / 16];
  }

  const int ldc = m;
  // All zeros.
  std::vector<std::int32_t> c_offset(m, 0);
  // Given values at [i / 16].
  const std::array<std::int32_t, 4> c_mult_int_terse{{
      3655, 5140, 7049, 9595,
  }};
  assert(c_mult_int_terse.size() * 16 == m);
  std::vector<std::int32_t> c_mult_int(m);
  for (int i = 0; i < m; ++i) {
    c_mult_int[i] = c_mult_int_terse[i / 16];
  }

  const int c_shift = 21;

  const int c_count = m * n;
  std::unique_ptr<std::uint8_t[]> output_data(new std::uint8_t[c_count]);
  MatrixMap<std::uint8_t, MapOrder::ColMajor> result(output_data.get(), m, n,
                                                     ldc);
  const OffsetColMap result_offset(&c_offset[0], m);
  const OffsetColMap result_mult_int(&c_mult_int[0], m);
  const int result_shift = c_shift;

  GemmContext gemm_context;
  auto output_pipeline = MakeStandardOutputPipeline<VectorShape::Col>(
      result_offset, result_mult_int, result_shift);
  GemmWithOutputPipelinePC<std::uint8_t, std::uint8_t,
                           DefaultL8R8BitDepthParams>(
      &gemm_context, lhs, rhs, &result, lhs_offset, rhs_offset,
      output_pipeline);

  ResultStats stats;
  GetResultStats(output_data.get(), &expected_c_data[0], c_count, &stats);

  ResultStatsBounds bounds;
  const bool good = CheckResultStatsBounds(stats, bounds);
  printf("TestMultithreadedPerChannelQuantization: %s\n",
         good ? "PASS" : "FAIL");
  ReportResultStats(stats, bounds);
  Check(good);
}

// Runs a small set of hand-calculated data through the implementation.
void TestWithSmallData() {
  const int m = 4;
  const int n = 2;
  const int k = 3;
  // Matrix A (LHS) is:
  // |  7 | 10 | 13 | 16 |
  // |  8 | 11 | 14 | 17 |
  // |  9 | 12 | 15 | 18 |
  const std::uint8_t a_data[] = {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18};
  // Matrix B (RHS) is:
  // | 1 | 3 | 5 |
  // | 2 | 4 | 6 |
  const std::uint8_t b_data[] = {1, 2, 3, 4, 5, 6};
  // Here are the results we expect, from hand calculations:
  // (1 * 7) + (3 * 8) + (5 * 9) = 76
  // (2 * 7) + (4 * 8) + (6 * 9) = 100
  // (1 * 10) + (3 * 11) + (5 * 12) = 103
  // (2 * 10) + (4 * 11) + (6 * 12) = 136
  // (1 * 13) + (3 * 14) + (5 * 15) = 130
  // (2 * 13) + (4 * 14) + (6 * 15) = 172
  // (1 * 16) + (3 * 17) + (5 * 18) = 157
  // (2 * 16) + (4 * 17) + (6 * 18) = 208
  // That means matrix C should be:
  // |  76 | 103 | 130 | 157 |
  // | 100 | 136 | 172 | 208 |
  const std::uint8_t expected_data[] = {76, 100, 103, 136, 130, 172, 157, 208};

  const int c_count = m * n;
  std::unique_ptr<std::uint8_t[]> output_data(new std::uint8_t[c_count]);

  const bool is_a_transposed = true;
  const bool is_b_transposed = true;
  const bool is_c_transposed = true;
  const int lda = k;
  const int ldb = n;
  const int ldc = n;

  const int a_offset = 0;
  const int b_offset = 0;
  const int c_offset = 0;
  const int c_mult = 1;
  const int c_shift = 0;
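  // With zero offsets, a unit multiplier and a zero shift, the quantized GEMM
  // reduces to a plain integer matrix multiplication (and all expected values
  // fit in [0, 255], so the final clamp has no effect), which is what makes
  // the hand calculation above possible.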

  gemmlowp::eight_bit_int_gemm::EightBitIntGemm(
      is_a_transposed, is_b_transposed, is_c_transposed, m, n, k, a_data,
      a_offset, lda, b_data, b_offset, ldb, output_data.get(), c_offset, c_mult,
      c_shift, ldc, eight_bit_int_gemm::BitDepthSetting::A8B8);

  ResultStats stats;
  GetResultStats(output_data.get(), expected_data, c_count, &stats);

  ResultStatsBounds bounds;
  const bool good = CheckResultStatsBounds(stats, bounds);
  printf("TestWithSmallData: %s\n", good ? "PASS" : "FAIL");
  ReportResultStats(stats, bounds);
  Check(good);
}

// This is the most realistic test of how we'll be using the low-precision GEMM
// function in applications. It takes in large input matrices that have been
// captured from an actual neural network run.
void TestWithRealData(eight_bit_int_gemm::BitDepthSetting BitDepth,
                      int tolerance_median, int tolerance_max) {
  std::unique_ptr<std::uint8_t[]> output_data(
      new std::uint8_t[test_data::c_count]);
  gemmlowp::eight_bit_int_gemm::EightBitIntGemm(
      test_data::is_a_transposed, test_data::is_b_transposed,
      test_data::is_c_transposed, test_data::m, test_data::n, test_data::k,
      test_data::a_data, test_data::a_offset, test_data::k, test_data::b_data,
      test_data::b_offset, test_data::k, output_data.get(), test_data::c_offset,
      test_data::c_mult_int, test_data::c_shift, test_data::m, BitDepth);

  ResultStats stats;
  GetResultStats(output_data.get(), test_data::expected_c_data,
                 test_data::c_count, &stats);

  ResultStatsBounds bounds;
  if (BitDepth == eight_bit_int_gemm::BitDepthSetting::A5B7) {
    bounds.med_unsigned_diff = tolerance_median;
    bounds.max_unsigned_diff = tolerance_max;
    bounds.med_signed_diff = 0;
    bounds.mean_signed_diff = 0.2f;
  }

  const bool good = CheckResultStatsBounds(stats, bounds);
  printf("TestWithRealData: %s with %s\n", good ? "PASS" : "FAIL",
         GetBitDepthName(BitDepth));
  ReportResultStats(stats, bounds);
  Check(good);
}

template <typename BitDepthParams, MapOrder ResultOrder>
void TestOutputStages(int rows, int depth, int cols, int result_offset,
                      int result_mult_int, int result_shift) {
  Matrix<std::uint8_t, MapOrder::RowMajor> lhs(rows, depth);
  Matrix<std::uint8_t, MapOrder::ColMajor> rhs(depth, cols);
  Matrix<std::int32_t, ResultOrder> result_raw_int32(rows, cols);
  MakeRandom<typename BitDepthParams::LhsRange>(&lhs);
  MakeRandom<typename BitDepthParams::RhsRange>(&rhs);
  const int lhs_offset = 12;
  const int rhs_offset = -34;

  // Test an empty pipeline, i.e. returning raw int32 accumulators.
  auto empty_pipeline = std::make_tuple();
  GemmContext context;
  GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(), &result_raw_int32, lhs_offset,
      rhs_offset, empty_pipeline);

  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t expected = 0;
      for (int d = 0; d < depth; d++) {
        std::int32_t lhs_val =
            static_cast<std::int32_t>(lhs(r, d)) + lhs_offset;
        std::int32_t rhs_val =
            static_cast<std::int32_t>(rhs(d, c)) + rhs_offset;
        expected += lhs_val * rhs_val;
      }
      Check(expected == result_raw_int32(r, c));
    }
  }

  // Test a pipeline with only the quantize-down stage, still returning
  // unclamped (but scaled) int32's
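  // The stage computes
  //   RoundingDivideByPOT((raw + result_offset) * result_mult_int,
  //                       result_shift),
  // i.e. the same scale-then-rounding-shift as the reference GEMM above, but
  // without clamping to the [0, 255] uint8 range.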
1235 OutputStageQuantizeDownInt32ToUint8Scale quantize_down_stage;
1236 quantize_down_stage.result_offset = result_offset;
1237 quantize_down_stage.result_mult_int = result_mult_int;
1238 quantize_down_stage.result_shift = result_shift;
1239 auto quantize_down_pipeline = std::make_tuple(quantize_down_stage);
1240 Matrix<std::int32_t, ResultOrder> result_quantized_down_int32(rows, cols);
1241 GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1242 &context, lhs.const_map(), rhs.const_map(), &result_quantized_down_int32,
1243 lhs_offset, rhs_offset, quantize_down_pipeline);
1244
1245 std::int64_t sum = 0;
1246 for (int r = 0; r < rows; r++) {
1247 for (int c = 0; c < cols; c++) {
1248 std::int32_t raw = result_raw_int32(r, c);
1249 std::int32_t expected = RoundingDivideByPOT(
1250 (raw + result_offset) * result_mult_int, result_shift);
1251 Check(expected == result_quantized_down_int32(r, c));
1252 sum += expected;
1253 }
1254 }
1255 std::int64_t avg = sum / (rows * cols);
1256 // Test that the average quantized-down value falls reasonably in the
1257 // middle of the [0..255] range. Otherwise, the multiplier / shift need to be
1258 // adjusted.
1259 Check(avg >= 64 && avg <= 192);
1260
1261 // Test the familiar default pipeline consisting of quantize-down and
1262 // clamp-and-cast-to-uint8.
1263 OutputStageSaturatingCastToUint8 saturating_cast_stage;
1264 auto quantize_down_and_saturating_cast_pipeline =
1265 std::make_tuple(quantize_down_stage, saturating_cast_stage);
1266 Matrix<std::uint8_t, ResultOrder> result_quantized_down_saturated_uint8(rows,
1267 cols);
1268 GemmWithOutputPipeline<std::uint8_t, std::uint8_t, DefaultL8R8BitDepthParams>(
1269 &context, lhs.const_map(), rhs.const_map(),
1270 &result_quantized_down_saturated_uint8, lhs_offset, rhs_offset,
1271 quantize_down_and_saturating_cast_pipeline);
1272
1273 for (int r = 0; r < rows; r++) {
1274 for (int c = 0; c < cols; c++) {
1275 std::int32_t quantized = result_quantized_down_int32(r, c);
1276 std::uint8_t expected = std::min(std::max(quantized, 0), 255);
1277 Check(expected == result_quantized_down_saturated_uint8(r, c));
1278 }
1279 }
1280
1281 // Test a bias-addition with row-vector
1282 std::vector<std::int32_t> row_vector_data(cols);
1283 std::uniform_int_distribution<std::int32_t> uniform_minus_500_plus_500(-500,
1284 500);
1285 for (int i = 0; i < cols; i++) {
1286 row_vector_data[i] = uniform_minus_500_plus_500(RandomEngine());
1287 }
1288 typedef VectorMap<std::int32_t, VectorShape::Row> RowVectorMap;
1289 RowVectorMap row_vector_map(row_vector_data.data(), cols);
1290 OutputStageBiasAddition<RowVectorMap> row_bias_addition_stage;
1291 row_bias_addition_stage.bias_vector = row_vector_map;
1292 auto row_bias_addition_pipeline = std::make_tuple(row_bias_addition_stage);
1293 Matrix<std::int32_t, ResultOrder> result_of_row_bias_addition(rows, cols);
1294 GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1295 &context, lhs.const_map(), rhs.const_map(), &result_of_row_bias_addition,
1296 lhs_offset, rhs_offset, row_bias_addition_pipeline);
1297 for (int r = 0; r < rows; r++) {
1298 for (int c = 0; c < cols; c++) {
1299 std::int32_t expected = result_raw_int32(r, c) + row_vector_data[c];
1300 Check(expected == result_of_row_bias_addition(r, c));
1301 }
1302 }
1303
  // Test bias-addition with a column vector: the rows-length bias vector is
  // added to every column of the result.
  std::vector<std::int32_t> col_vector_data(rows);
  for (int i = 0; i < rows; i++) {
    col_vector_data[i] = uniform_minus_500_plus_500(RandomEngine());
  }
  typedef VectorMap<std::int32_t, VectorShape::Col> ColVectorMap;
  ColVectorMap col_vector_map(col_vector_data.data(), rows);
  OutputStageBiasAddition<ColVectorMap> col_bias_addition_stage;
  col_bias_addition_stage.bias_vector = col_vector_map;
  auto col_bias_addition_pipeline = std::make_tuple(col_bias_addition_stage);
  Matrix<std::int32_t, ResultOrder> result_of_col_bias_addition(rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(), &result_of_col_bias_addition,
      lhs_offset, rhs_offset, col_bias_addition_pipeline);
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t expected = result_raw_int32(r, c) + col_vector_data[r];
      Check(expected == result_of_col_bias_addition(r, c));
    }
  }

  // Test a clamp
  OutputStageClamp clamp_stage;
  // Determine min and max of the raw int32 accumulators
  std::int32_t raw_min = std::numeric_limits<std::int32_t>::max();
  std::int32_t raw_max = std::numeric_limits<std::int32_t>::min();
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      raw_min = std::min(raw_min, result_raw_int32(r, c));
      raw_max = std::max(raw_max, result_raw_int32(r, c));
    }
  }
  // Pick some interesting clamp min/max bounds: 30% and 70% of the way from
  // raw_min to raw_max, so that both bounds lie inside the raw range and the
  // clamp actually takes effect at both ends.
  clamp_stage.min = static_cast<std::int32_t>(raw_min * 0.7 + raw_max * 0.3);
  clamp_stage.max = static_cast<std::int32_t>(raw_min * 0.3 + raw_max * 0.7);
  assert(raw_min <= clamp_stage.min && clamp_stage.min <= clamp_stage.max &&
         clamp_stage.max <= raw_max);
  auto clamp_pipeline = std::make_tuple(clamp_stage);
  Matrix<std::int32_t, ResultOrder> result_clamped(rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(), &result_clamped, lhs_offset,
      rhs_offset, clamp_pipeline);
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t raw = result_raw_int32(r, c);
      std::int32_t expected =
          std::min(std::max(raw, clamp_stage.min), clamp_stage.max);
      Check(expected == result_clamped(r, c));
    }
  }

  // Test tanh
  OutputStageTanh tanh_stage;
  const std::int32_t real_zero_as_int32 = (raw_max + raw_min) / 2;
  const std::int32_t real_amplitude_as_int32 = (raw_max - raw_min) / 16;
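  // With this zero-point and amplitude, the real values corresponding to
  // [raw_min, raw_max] span approximately [-8, +8], covering the full
  // saturation range of tanh.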
  tanh_stage.real_zero_as_int32 = real_zero_as_int32;
  tanh_stage.real_amplitude_as_int32 = real_amplitude_as_int32;
  auto tanh_pipeline = std::make_tuple(tanh_stage);
  Matrix<std::int32_t, ResultOrder> result_tanh(rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(), &result_tanh, lhs_offset,
      rhs_offset, tanh_pipeline);
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t raw = result_raw_int32(r, c);
      double real_input =
          double(raw - real_zero_as_int32) / real_amplitude_as_int32;
      double expected = std::tanh(real_input);
      std::int32_t actual_int32 = result_tanh(r, c);
      double actual =
          double(actual_int32 - real_zero_as_int32) / real_amplitude_as_int32;
      Check(std::abs(expected - actual) < 2e-4);
    }
  }

  // Test a pipeline with bias and clamp
  auto bias_clamp_pipeline =
      std::make_tuple(col_bias_addition_stage, clamp_stage);
  Matrix<std::int32_t, ResultOrder> result_biased_clamped(rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(), &result_biased_clamped,
      lhs_offset, rhs_offset, bias_clamp_pipeline);
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t raw = result_raw_int32(r, c);
      std::int32_t biased = raw + col_vector_data[r];
      std::int32_t expected =
          std::min(std::max(biased, clamp_stage.min), clamp_stage.max);
      Check(expected == result_biased_clamped(r, c));
    }
  }

  // Test a full pipeline with bias, clamp and quantization down to an 8-bit
  // result. The stages are applied in the order in which they appear in the
  // tuple.
  auto bias_clamp_quantize_cast_pipeline =
      std::make_tuple(col_bias_addition_stage, clamp_stage, quantize_down_stage,
                      saturating_cast_stage);
  Matrix<std::uint8_t, ResultOrder> result_biased_clamped_quantized_casted(
      rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::uint8_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(),
      &result_biased_clamped_quantized_casted, lhs_offset, rhs_offset,
      bias_clamp_quantize_cast_pipeline);
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t quantized = RoundingDivideByPOT(
          (result_biased_clamped(r, c) + result_offset) * result_mult_int,
          result_shift);
      std::uint8_t expected = std::min(std::max(quantized, 0), 255);
      Check(expected == result_biased_clamped_quantized_casted(r, c));
    }
  }

  // Test a pipeline with the fixed-point-multiplier variant stage for
  // quantizing down the 32-bit accumulators.
  //
  // First, figure out appropriate fixed-point multiplier and shift values:
  // normalize the integer multiplier into the [2^30, 2^31) range, adjusting
  // the shift so that multiplier / 2^31 / 2^shift remains equal to
  // result_mult_int / 2^result_shift.
  Check(result_mult_int > 0);
  Check(result_shift > 0);
  std::int32_t result_fixedpoint_multiplier = result_mult_int;
  std::int32_t result_fixedpoint_shift = result_shift - 31;
  while (result_fixedpoint_multiplier < (1 << 30)) {
    result_fixedpoint_multiplier <<= 1;
    result_fixedpoint_shift++;
  }
  Check(result_fixedpoint_shift >= 0);
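  // A worked example, assuming the parameters passed by the callers below
  // (result_mult_int = 17, result_shift = 14): the loop shifts 17 left by
  // 26 bits, giving 17 * 2^26 = 1140850688 >= 2^30, while the shift becomes
  // (14 - 31) + 26 = 9. The overall scale is then
  //   (17 * 2^26) / 2^31 / 2^9 = 17 / 2^14,
  // exactly the original result_mult_int / 2^result_shift, but with the
  // multiplier normalized into [2^30, 2^31) for maximum precision.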
  // Now test OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint
  OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint
      quantize_down_by_fixedpoint_stage;
  quantize_down_by_fixedpoint_stage.result_offset_after_shift =
      static_cast<std::int32_t>(
          round(static_cast<double>(result_offset * result_mult_int) /
                (1 << result_shift)));
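  // In this variant the offset is applied after the scaling, so the
  // equivalent of the earlier (raw + result_offset) * result_mult_int
  // >> result_shift is obtained by pre-scaling the offset itself:
  // result_offset * result_mult_int / 2^result_shift, rounded to nearest.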
  quantize_down_by_fixedpoint_stage.result_fixedpoint_multiplier =
      result_fixedpoint_multiplier;
  quantize_down_by_fixedpoint_stage.result_shift = result_fixedpoint_shift;
  auto quantize_down_by_fixedpoint_pipeline =
      std::make_tuple(quantize_down_by_fixedpoint_stage);
  Matrix<std::int32_t, ResultOrder> result_quantized_down_by_fixedpoint_int32(
      rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(),
      &result_quantized_down_by_fixedpoint_int32, lhs_offset, rhs_offset,
      quantize_down_by_fixedpoint_pipeline);

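  // Note: SaturatingRoundingDoublingHighMul(a, b) returns the high 32 bits
  // of the 64-bit product 2 * a * b, with rounding, i.e. approximately
  // round(a * b / 2^31); it saturates only in the corner case
  // a == b == INT32_MIN. Combined with RoundingDivideByPOT below, this
  // applies the normalized fixed-point multiplier computed above.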
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      const std::int32_t actual =
          result_quantized_down_by_fixedpoint_int32(r, c);
      const std::int32_t raw = result_raw_int32(r, c);
      const std::int32_t expected =
          quantize_down_by_fixedpoint_stage.result_offset_after_shift +
          RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
                                  raw, result_fixedpoint_multiplier),
                              result_fixedpoint_shift);
      Check(actual == expected);
    }
  }

  // Test the variant of the familiar default pipeline consisting of
  // quantize-down and clamp-and-cast-to-uint8, where we use fixed-point
  // multipliers for the downscaling.
  auto quantize_down_by_fixedpoint_and_saturating_cast_pipeline =
      std::make_tuple(quantize_down_by_fixedpoint_stage, saturating_cast_stage);
  Matrix<std::uint8_t, ResultOrder>
      result_quantized_down_by_fixedpoint_saturated_uint8(rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::uint8_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(),
      &result_quantized_down_by_fixedpoint_saturated_uint8, lhs_offset,
      rhs_offset, quantize_down_by_fixedpoint_and_saturating_cast_pipeline);

  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t quantized = result_quantized_down_by_fixedpoint_int32(r, c);
      std::uint8_t expected = std::min(std::max(quantized, 0), 255);
      Check(expected ==
            result_quantized_down_by_fixedpoint_saturated_uint8(r, c));
    }
  }

  printf("TestOutputStages: PASS with ResultOrder=%s\n",
         OrderName(ResultOrder));
}

#ifndef GEMMLOWP_SKIP_EXHAUSTIVE_TESTS
template <typename BitDepthParams>
void TestExhaustively() {
  GemmContext context;

  // Test the internal GEMM interfaces
  test_gemm<SingleThreadGemmWrapper<DefaultKernel<BitDepthParams>,
                                    std::uint8_t, BitDepthParams>>(&context);

  test_gemm<MultiThreadGemmWrapper<DefaultKernel<BitDepthParams>,
                                   std::uint8_t, BitDepthParams>>(&context);

  // Test the public GEMM interfaces
  test_gemm<PublicGemmWrapper<std::uint8_t, BitDepthParams>>(&context);

  // Test GEMV cases (internal interfaces)
  test_gemv<SingleThreadGemmWrapper<DefaultKernel<BitDepthParams>,
                                    std::uint8_t, BitDepthParams>>(&context);

  test_gemv<MultiThreadGemmWrapper<DefaultKernel<BitDepthParams>,
                                   std::uint8_t, BitDepthParams>>(&context);

  // Test GEMV cases (public interfaces)
  test_gemv<PublicGemmWrapper<std::uint8_t, BitDepthParams>>(&context);
}

template <eight_bit_int_gemm::BitDepthSetting BitDepthSetting>
void TestExhaustivelyEightBitIntGemm() {
  GemmContext context;
  test_gemv<EightBitIntGemmWrapper<std::uint8_t, BitDepthSetting>>(&context);
  test_gemm<EightBitIntGemmWrapper<std::uint8_t, BitDepthSetting>>(&context);
}

void TestKernels() {
  GemmContext context;

  // Test specific kernels with various different formats,
  // to exercise corner cases especially in the packing code.
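  // A quick reminder of the conventions in the KernelFormat expressions
  // below: CellFormat<Width, Depth, Order> describes a cell of the given
  // width and depth (Order defaults to DepthMajor), and
  // KernelSideFormat<Cell, N> places N such cells side by side along the
  // width dimension.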
  test_gemm_kernel<
      ReferenceKernel<KernelFormat<KernelSideFormat<CellFormat<1, 1>, 1>,
                                   KernelSideFormat<CellFormat<1, 1>, 1>>>>(
      &context);

  test_gemm_kernel<
      ReferenceKernel<KernelFormat<KernelSideFormat<CellFormat<4, 2>, 1>,
                                   KernelSideFormat<CellFormat<4, 2>, 2>>>>(
      &context);

  test_gemm_kernel<
      ReferenceKernel<KernelFormat<KernelSideFormat<CellFormat<4, 2>, 4>,
                                   KernelSideFormat<CellFormat<4, 2>, 5>>>>(
      &context);

  test_gemm_kernel<ReferenceKernel<KernelFormat<
      KernelSideFormat<CellFormat<3, 4, CellOrder::DepthMajor>, 2>,
      KernelSideFormat<CellFormat<5, 4, CellOrder::DepthMajor>, 3>>>>(&context);

  test_gemm_kernel<ReferenceKernel<KernelFormat<
      KernelSideFormat<CellFormat<3, 4, CellOrder::WidthMajor>, 2>,
      KernelSideFormat<CellFormat<5, 4, CellOrder::WidthMajor>, 3>>>>(&context);

  test_gemm_kernel<ReferenceKernel<KernelFormat<
      KernelSideFormat<CellFormat<5, 2, CellOrder::WidthMajor>, 3>,
      KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2>>>>(&context);

  test_gemm_kernel<ReferenceKernel<KernelFormat<
      KernelSideFormat<CellFormat<5, 2, CellOrder::DepthMajor>, 3>,
      KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 2>>>>(&context);

  test_gemm_kernel<ReferenceKernel<KernelFormat<
      KernelSideFormat<CellFormat<8, 8, CellOrder::Diagonal>, 2>,
      KernelSideFormat<CellFormat<3, 8, CellOrder::WidthMajor>, 1>>>>(&context);

  test_gemm_kernel<ReferenceKernel<KernelFormat<
      KernelSideFormat<CellFormat<1, 4, CellOrder::DepthMajor>, 1>,
      KernelSideFormat<CellFormat<4, 4, CellOrder::Diagonal>, 1>>>>(&context);
}

#endif  // not GEMMLOWP_SKIP_EXHAUSTIVE_TESTS

template <typename BitDepthParams>
void TestOutputStages() {
  // Test non-default output pipelines with various combinations of
  // output stages.
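  // The arguments here are (rows, depth, cols, result_offset,
  // result_mult_int, result_shift) -- the parameters referenced throughout
  // the body of TestOutputStages above.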
  TestOutputStages<BitDepthParams, MapOrder::RowMajor>(63, 10, 127, 5, 17, 14);
  TestOutputStages<BitDepthParams, MapOrder::ColMajor>(63, 10, 127, 5, 17, 14);
  TestOutputStages<BitDepthParams, MapOrder::RowMajor>(630, 10, 1270, 5, 17,
                                                       14);
  TestOutputStages<BitDepthParams, MapOrder::ColMajor>(630, 10, 1270, 5, 17,
                                                       14);
}

void test() {
#ifdef GEMMLOWP_TEST_PROFILE
  RegisterCurrentThreadForProfiling();
  StartProfiling();
#endif

  // Run a first quick test against hand-calculated data.
  TestWithSmallData();

#ifndef GEMMLOWP_SKIP_EXHAUSTIVE_TESTS
  TestExhaustively<DefaultL8R8BitDepthParams>();
  TestExhaustively<L8R8WithLhsNonzeroBitDepthParams>();
  TestExhaustively<DefaultL7R5BitDepthParams>();  // legacy, same as L8R8
  TestExhaustivelyEightBitIntGemm<eight_bit_int_gemm::BitDepthSetting::A8B8>();
  TestExhaustivelyEightBitIntGemm<eight_bit_int_gemm::BitDepthSetting::A5B7>();
  TestKernels();
#endif

  // Run against actual data from a network evaluation.
  TestWithRealData(eight_bit_int_gemm::BitDepthSetting::A8B8, 0, 0);
  TestWithRealData(eight_bit_int_gemm::BitDepthSetting::A5B7, 2, 10);

  // Test non-default output pipelines with various combinations of
  // output stages.
  TestOutputStages<DefaultL8R8BitDepthParams>();
  TestOutputStages<L8R8WithLhsNonzeroBitDepthParams>();

  // Test per-channel quantization.
  TestWithSmallDataPerChannelQuantization();
  TestWithLargeDataPerChannelQuantization();
  TestMultithreadedPerChannelQuantization();
#ifdef GEMMLOWP_TEST_PROFILE
  FinishProfiling();
#endif

  std::cerr << "All tests passed." << std::endl;

  // We have been testing the eight_bit_int_gemm interface, so we should free
  // its persistent resources now to avoid having leak-checking tools report
  // leaks.
  eight_bit_int_gemm::FreePersistentResources();
}

}  // end namespace gemmlowp

// For iOS, we need to define our own main(), so skip it here.
#if !(defined(__APPLE__) && (TARGET_OS_IPHONE || TARGET_IPHONE_SIMULATOR))
int main() { gemmlowp::test(); }
#endif