// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "test.h"

#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <limits>
#include <memory>
#include <random>
#include <string>
#include <tuple>
#include <vector>
#ifdef __APPLE__
#include <TargetConditionals.h>
#endif

#include "../eight_bit_int_gemm/eight_bit_int_gemm.h"
#include "../internal/kernel_reference.h"
#include "test_data.h"

namespace gemmlowp {
void ReferenceEightBitIntGemm(bool transpose_a, bool transpose_b,
                              bool transpose_c, int m, int n, int k,
                              const std::uint8_t* a, std::int32_t a_offset,
                              int lda, const std::uint8_t* b,
                              std::int32_t b_offset, int ldb, std::uint8_t* c,
                              std::int32_t c_offset, std::int32_t c_mult_int,
                              std::int32_t c_shift, int ldc) {
  ScopedProfilingLabel label("ReferenceEightBitIntGemm");
  assert((c_shift >= 0) && (c_shift <= 32));

  assert(a != nullptr);
  assert(b != nullptr);
  assert(c != nullptr);

  int a_i_stride;
  int a_l_stride;
  if (transpose_a) {
    a_i_stride = lda;
    a_l_stride = 1;
  } else {
    a_i_stride = 1;
    a_l_stride = lda;
  }
  int b_j_stride;
  int b_l_stride;
  if (transpose_b) {
    b_j_stride = 1;
    b_l_stride = ldb;
  } else {
    b_j_stride = ldb;
    b_l_stride = 1;
  }
  int c_i_stride;
  int c_j_stride;
  if (transpose_c) {
    c_i_stride = ldc;
    c_j_stride = 1;
  } else {
    c_i_stride = 1;
    c_j_stride = ldc;
  }
  int i, j, l;

  const std::int32_t kRoundingTerm = (c_shift < 1) ? 0 : (1 << (c_shift - 1));

  for (j = 0; j < n; j++) {
    for (i = 0; i < m; i++) {
      std::int32_t total = 0;
      for (l = 0; l < k; l++) {
        const int a_index = i * a_i_stride + l * a_l_stride;
        const std::uint8_t a_as_byte = a[a_index];
        const std::int32_t a_as_int =
            static_cast<std::int32_t>(a_as_byte) + a_offset;
        const int b_index = j * b_j_stride + l * b_l_stride;
        const std::uint8_t b_as_byte = b[b_index];
        const std::int32_t b_as_int =
            static_cast<std::int32_t>(b_as_byte) + b_offset;
        const std::int32_t mult_as_int = a_as_int * b_as_int;
        total += mult_as_int;
      }
      std::int32_t output =
          (((total + c_offset) * c_mult_int) + kRoundingTerm) >> c_shift;
      if (output > 255) {
        output = 255;
      }
      if (output < 0) {
        output = 0;
      }
      const int c_index = i * c_i_stride + j * c_j_stride;
      c[c_index] = static_cast<std::uint8_t>(output);
    }
  }
}
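
// A hand-checked example of the quantize-down arithmetic above, assuming
// c_offset = 0, c_mult_int = 1 and c_shift = 4 (so kRoundingTerm == 8):
//   total = 1000  ->  (1000 + 8) >> 4 = 63
//   total = 5000  ->  (5000 + 8) >> 4 = 313,  clamped to 255
//   total = -100  ->  (-100 + 8) >> 4 = -6,   clamped to 0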

typedef VectorMap<const std::int32_t, VectorShape::Col> OffsetColMap;
typedef VectorMap<const std::int32_t, VectorShape::Row> OffsetRowMap;
typedef VectorDup<const std::int32_t, VectorShape::Col> OffsetColDup;
typedef VectorDup<const std::int32_t, VectorShape::Row> OffsetRowDup;

// The *GemmWrapper classes wrap various GEMM functions in a uniform
// interface, so that the same testing code can exercise all of them.

template <typename Kernel, typename Scalar, typename tBitDepthParams>
struct SingleThreadGemmWrapper {
  typedef tBitDepthParams BitDepthParams;

  static const char* Name() {
    static char buf[256];
    snprintf(buf, sizeof(buf), "SingleThreadGemm, Kernel: %s", Kernel().Name());
    return buf;
  }

  typedef SingleThreadGemmContext Context;

  template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
  static bool Gemm(Context* context,
                   const MatrixMap<const Scalar, LhsOrder>& lhs,
                   const MatrixMap<const Scalar, RhsOrder>& rhs,
                   MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
                   int rhs_offset, int result_offset, int result_mult_int,
                   int result_shift) {
    ScopedProfilingLabel label("SingleThreadGemmWrapper::Gemm");
    const int rows = lhs.rows();
    const int cols = rhs.cols();
    if (rows < cols) {
      // SingleThreadGemm is never called with rows < cols.
      // That case is handled earlier.
      return false;
    }
    const OffsetColDup lhs_offset_vector(lhs_offset, rows);
    const OffsetRowDup rhs_offset_vector(rhs_offset, cols);
    SingleThreadGemm<typename Kernel::Format, Scalar, Scalar, BitDepthParams,
                     LhsOrder, RhsOrder, ResultOrder, OffsetColDup,
                     OffsetRowDup>(
        context, Kernel(), lhs, rhs, result, lhs_offset_vector,
        rhs_offset_vector,
        MakeStandardOutputPipeline(result_offset, result_mult_int,
                                   result_shift));
    return true;
  }
};

template <typename Kernel, typename Scalar, typename tBitDepthParams>
struct MultiThreadGemmWrapper {
  typedef tBitDepthParams BitDepthParams;

  static const char* Name() {
    static char buf[256];
    snprintf(buf, sizeof(buf), "MultiThreadGemm, Kernel: %s", Kernel().Name());
    return buf;
  }

  typedef MultiThreadGemmContext Context;

  template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
  static bool Gemm(Context* context,
                   const MatrixMap<const Scalar, LhsOrder>& lhs,
                   const MatrixMap<const Scalar, RhsOrder>& rhs,
                   MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
                   int rhs_offset, int result_offset, int result_mult_int,
                   int result_shift) {
    ScopedProfilingLabel label("MultiThreadGemmWrapper::Gemm");
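    // Passing 0 to set_max_num_threads means "automatic": gemmlowp then
    // picks the number of worker threads itself (typically from the number
    // of hardware cores).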
    context->set_max_num_threads(0);
    const int rows = lhs.rows();
    const int cols = rhs.cols();
    if (rows < cols) {
      // MultiThreadGemm is never called with rows < cols.
      // That case is handled earlier.
      return false;
    }
    const OffsetColDup lhs_offset_vector(lhs_offset, rows);
    const OffsetRowDup rhs_offset_vector(rhs_offset, cols);
    MultiThreadGemm<typename Kernel::Format, Scalar, Scalar, BitDepthParams,
                    LhsOrder, RhsOrder, ResultOrder, OffsetColDup,
                    OffsetRowDup>(
        context, Kernel(), lhs, rhs, result, lhs_offset_vector,
        rhs_offset_vector,
        MakeStandardOutputPipeline(result_offset, result_mult_int,
                                   result_shift));
    return true;
  }
};

template <typename Scalar, typename tBitDepthParams>
struct PublicGemmWrapper {
  typedef tBitDepthParams BitDepthParams;

  static const char* Name() { return "public Gemm"; }

  typedef GemmContext Context;

  template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
  static bool Gemm(Context* context,
                   const MatrixMap<const Scalar, LhsOrder>& lhs,
                   const MatrixMap<const Scalar, RhsOrder>& rhs,
                   MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
                   int rhs_offset, int result_offset, int result_mult_int,
                   int result_shift) {
    ScopedProfilingLabel label("PublicGemmWrapper::Gemm");
    gemmlowp::Gemm<std::uint8_t, BitDepthParams, LhsOrder, RhsOrder,
                   ResultOrder>(context, lhs, rhs, result, lhs_offset,
                                rhs_offset, result_offset, result_mult_int,
                                result_shift);
    return true;
  }
};

template <eight_bit_int_gemm::BitDepthSetting BitDepth>
struct BitDepthParamsForSettings {};

template <>
struct BitDepthParamsForSettings<eight_bit_int_gemm::BitDepthSetting::A8B8>
    : DefaultL8R8BitDepthParams {};

template <>
struct BitDepthParamsForSettings<eight_bit_int_gemm::BitDepthSetting::A5B7>
    : DefaultL7R5BitDepthParams {};

template <typename Scalar, eight_bit_int_gemm::BitDepthSetting BitDepth>
struct EightBitIntGemmWrapper {
  typedef BitDepthParamsForSettings<BitDepth> BitDepthParams;

  static const char* Name() { return "EightBitIntGemm"; }

  typedef void Context;

  template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
  static bool Gemm(Context*, const MatrixMap<const Scalar, LhsOrder>& lhs,
                   const MatrixMap<const Scalar, RhsOrder>& rhs,
                   MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
                   int rhs_offset, int result_offset, int result_mult_int,
                   int result_shift) {
    ScopedProfilingLabel label("EightBitIntGemmWrapper::Gemm");
    const bool transpose_c = ResultOrder == MapOrder::RowMajor;
    const bool transpose_a = LhsOrder == MapOrder::RowMajor;
    const bool transpose_b = RhsOrder == MapOrder::RowMajor;
    eight_bit_int_gemm::EightBitIntGemm(
        transpose_a, transpose_b, transpose_c, lhs.rows(), rhs.cols(),
        lhs.cols(), lhs.data(), lhs_offset, lhs.stride(), rhs.data(),
        rhs_offset, rhs.stride(), result->data(), result_offset,
        result_mult_int, result_shift, result->stride(), BitDepth);
    return true;
  }
};

template <typename Scalar>
struct ReferenceEightBitIntGemmWrapper {
  typedef DefaultL8R8BitDepthParams BitDepthParams;

  static const char* Name() { return "ReferenceEightBitIntGemm"; }

  template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
  static bool Gemm(bool transpose_a, bool transpose_b, bool transpose_c,
                   const MatrixMap<const Scalar, LhsOrder>& lhs,
                   const MatrixMap<const Scalar, RhsOrder>& rhs,
                   MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
                   int rhs_offset, int result_offset, int result_mult_int,
                   int result_shift) {
    ScopedProfilingLabel label("ReferenceEightBitIntGemmWrapper::Gemm");
    ReferenceEightBitIntGemm(transpose_a, transpose_b, transpose_c, lhs.rows(),
                             rhs.cols(), lhs.cols(), lhs.data(), lhs_offset,
                             lhs.stride(), rhs.data(), rhs_offset, rhs.stride(),
                             result->data(), result_offset, result_mult_int,
                             result_shift, result->stride());
    return true;
  }
};

const char* OrderName(MapOrder order) {
  return order == MapOrder::ColMajor ? "ColMajor" : "RowMajor";
}

struct ResultStats {
  ResultStats()
      : count(0),
        med_val(0),
        mean_signed_diff(0),
        med_signed_diff(0),
        med_unsigned_diff(0),
        max_unsigned_diff(0) {}

  int count;
  int med_val;
  float mean_signed_diff;
  int med_signed_diff;
  int med_unsigned_diff;
  int max_unsigned_diff;

  std::vector<int> count_diff_by_pot_slice;
};

void GetResultStats(const std::uint8_t* actual, const std::uint8_t* expected,
                    size_t count, ResultStats* stats) {
  ScopedProfilingLabel label("GetResultStats");
  std::vector<std::uint8_t> results;
  std::vector<std::int16_t> signed_diffs;
  std::vector<std::uint8_t> unsigned_diffs;
  std::int64_t signed_diffs_sum = 0;
  for (size_t i = 0; i < count; i++) {
    results.push_back(actual[i]);
    std::int16_t signed_diff = actual[i] - expected[i];
    signed_diffs.push_back(signed_diff);
    unsigned_diffs.push_back(std::abs(signed_diff));
    signed_diffs_sum += signed_diff;
  }

  std::sort(results.begin(), results.end());
  std::sort(signed_diffs.begin(), signed_diffs.end());
  std::sort(unsigned_diffs.begin(), unsigned_diffs.end());

  const size_t middle = count / 2;

  stats->count = count;
  stats->med_val = results[middle];
  stats->mean_signed_diff = float(signed_diffs_sum) / count;
  stats->med_signed_diff = signed_diffs[middle];
  stats->med_unsigned_diff = unsigned_diffs[middle];
  stats->max_unsigned_diff = unsigned_diffs.back();

  // Size 9 for 9 different POT values: 2^0, ..., 2^8
  stats->count_diff_by_pot_slice.resize(9);
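  // Slice 0 counts exact matches (diff == 0); slice e >= 1 counts diffs in
  // the range [2^(e-1) .. 2^e - 1], matching the ranges that
  // ReportResultStats prints.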
  auto cur = unsigned_diffs.begin();
  size_t checksum = 0;
  for (int exponent = 0; exponent < 9; exponent++) {
    int pot = 1 << exponent;
    auto next = std::lower_bound(cur, unsigned_diffs.end(), pot);
    checksum += stats->count_diff_by_pot_slice[exponent] = next - cur;
    cur = next;
  }
  assert(checksum == count);
}

struct ResultStatsBounds {
  ResultStatsBounds()
      : mean_signed_diff(0),
        med_signed_diff(0),
        med_unsigned_diff(0),
        max_unsigned_diff(0) {}

  float mean_signed_diff;
  int med_signed_diff;
  int med_unsigned_diff;
  int max_unsigned_diff;
};

bool CheckResultStatsBounds(const ResultStats& stats,
                            const ResultStatsBounds& bounds) {
  return stats.max_unsigned_diff <= bounds.max_unsigned_diff &&
         stats.med_unsigned_diff <= bounds.med_unsigned_diff &&
         std::abs(stats.med_signed_diff) <= bounds.med_signed_diff &&
         std::abs(stats.mean_signed_diff) <= bounds.mean_signed_diff;
}

void ReportResultStats(const ResultStats& stats,
                       const ResultStatsBounds& bounds) {
  printf("  number of matrix entries: %d\n", stats.count);
  printf("  median value: %d\n", stats.med_val);
  printf("  median unsigned diff: %d (tolerating %d)\n",
         stats.med_unsigned_diff, bounds.med_unsigned_diff);
  printf("  max unsigned diff: %d (tolerating %d)\n", stats.max_unsigned_diff,
         bounds.max_unsigned_diff);
  printf("  median signed diff: %d (tolerating %d)\n", stats.med_signed_diff,
         bounds.med_signed_diff);
  printf("  mean signed diff: %.3g (tolerating %.3g)\n",
         stats.mean_signed_diff, bounds.mean_signed_diff);

  printf("No error: %.2f %% of entries\n",
         100.f * stats.count_diff_by_pot_slice[0] / stats.count);
  for (int exponent = 1; exponent < 9; exponent++) {
    printf("Error in %d..%d range: %.2f %% of entries\n", 1 << (exponent - 1),
           (1 << exponent) - 1,
           100.f * stats.count_diff_by_pot_slice[exponent] / stats.count);
  }
}

// Our approach to choosing result_shift values for testing is bisection.
// This function takes an interval, [result_shift_min .. result_shift_max].
// If too much saturation occurred in either direction, it bisects accordingly,
// recursing until the interval contains only one value.
// The primary reason why we prefer this over computing optimal shift values
// is that we actually want to exercise some saturation, as there is nontrivial
// code handling that in gemmlowp.
// Secondarily, this is faster than computing optimal shifts, since in 90% of
// cases the first-tried shift value of 16 turns out to be good enough.
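// For example, starting from the interval [0 .. 32], the first shift tried is
// 16. If the median result value then saturates toward 0, the search recurses
// on [0 .. 16]; if it saturates toward 255, it recurses on [16 .. 32].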
template <typename GemmWrapper, typename LhsType, typename RhsType,
          typename ResultType>
void test_gemm_impl(typename GemmWrapper::Context* context, const LhsType& lhs,
                    const RhsType& rhs, ResultType* result, int lhs_offset,
                    int rhs_offset, int result_offset, int result_mult_int,
                    int result_shift_min, int result_shift_max) {
  const int rows = lhs.rows();
  const int cols = rhs.cols();
  Check(lhs.cols() == rhs.rows());
  const int depth = lhs.cols();

  const int result_shift = (result_shift_min + result_shift_max) / 2;

  if (!GemmWrapper::Gemm(context, lhs.const_map(), rhs.const_map(),
                         &result->map(), lhs_offset, rhs_offset, result_offset,
                         result_mult_int, result_shift)) {
    // Internal GEMM functions are not required to handle all cases
    // (e.g. rows < cols) as these are supposed to have been handled
    // ahead of them. Their test wrappers return false in that case.
    return;
  }

  typedef typename ResultType::Scalar Scalar;
  static const MapOrder kLhsOrder = LhsType::kOrder;
  static const MapOrder kRhsOrder = RhsType::kOrder;
  static const MapOrder kResultOrder = ResultType::kOrder;
  ResultType ref_result(rows, cols);
  const bool transpose_c = kResultOrder == MapOrder::RowMajor;
  const bool transpose_a = kLhsOrder == MapOrder::RowMajor;
  const bool transpose_b = kRhsOrder == MapOrder::RowMajor;
  ReferenceEightBitIntGemmWrapper<Scalar>::Gemm(
      transpose_a, transpose_b, transpose_c, lhs.const_map(), rhs.const_map(),
      &ref_result.map(), lhs_offset, rhs_offset, result_offset, result_mult_int,
      result_shift);

  typedef typename GemmWrapper::BitDepthParams BitDepthParams;

  ResultStats stats;
  GetResultStats(result->data(), ref_result.data(), rows * cols, &stats);

  // Adjust shifts until we get meaningful results
  int new_result_shift_min = result_shift_min;
  int new_result_shift_max = result_shift_max;
  bool retry = false;

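  // A median output value below 32 means that results collapsed toward 0,
  // i.e. the shift was too large, so we retry in the lower half of the
  // interval; a median above 224 means saturation toward 255, i.e. the shift
  // was too small, so we retry in the upper half.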
  if (stats.med_val < 32) {
    new_result_shift_max = (result_shift_min + result_shift_max) / 2;
    retry = true;
  }

  if (stats.med_val > 224) {
    new_result_shift_min = (result_shift_min + result_shift_max) / 2;
    retry = true;
  }

  if (retry) {
    if (result_shift_min != result_shift_max) {
      test_gemm_impl<GemmWrapper>(context, lhs, rhs, result, lhs_offset,
                                  rhs_offset, result_offset, result_mult_int,
                                  new_result_shift_min, new_result_shift_max);
    }
    return;
  }

  ResultStatsBounds bounds;

  // Check results
  const bool good = CheckResultStatsBounds(stats, bounds);

  printf(
      "%s: %dx%dx%d %s x %s -> %s, %s, offsets %d/%d/%d, mult %d, shift %d\n",
      good ? "PASS" : "FAIL", rows, depth, cols, OrderName(kLhsOrder),
      OrderName(kRhsOrder), OrderName(kResultOrder), GemmWrapper::Name(),
      lhs_offset, rhs_offset, result_offset, result_mult_int, result_shift);

  if (!good) {
    ReportResultStats(stats, bounds);

    int bad_coeffs_printed = 0;
    for (int c = 0; c < result->cols() && bad_coeffs_printed < 200; c++) {
      for (int r = 0; r < result->rows() && bad_coeffs_printed < 200; r++) {
        if (ref_result(r, c) != (*result)(r, c)) {
          printf("bad coeff: at (%d, %d), expected %d, got %d\n", r, c,
                 ref_result(r, c), (*result)(r, c));
          bad_coeffs_printed++;
        }
      }
    }
  }

  Check(good);
}

template <typename GemmWrapper, typename LhsType, typename RhsType,
          typename ResultType>
void test_gemm(typename GemmWrapper::Context* context, const LhsType& lhs,
               const RhsType& rhs, ResultType* result, int lhs_offset,
               int rhs_offset, int result_offset, int result_mult_int) {
  test_gemm_impl<GemmWrapper>(context, lhs, rhs, result, lhs_offset, rhs_offset,
                              result_offset, result_mult_int, 0, 32);
}

enum class WhatParamsToTest {
  All,
  OnlyGenericCase,
};

template <typename GemmWrapper, MapOrder LhsOrder, MapOrder RhsOrder,
          MapOrder ResultOrder>
void test_gemm(typename GemmWrapper::Context* context, int rows, int depth,
               int cols, WhatParamsToTest params_to_test) {
  typedef std::uint8_t Scalar;
  typedef Matrix<Scalar, LhsOrder> LhsType;
  using BitDepthParams = typename GemmWrapper::BitDepthParams;
  LhsType lhs(rows, depth);
  MakeRandom<typename BitDepthParams::LhsRange>(&lhs);
  typedef Matrix<Scalar, RhsOrder> RhsType;
  RhsType rhs(depth, cols);
  MakeRandom<typename BitDepthParams::RhsRange>(&rhs);
  typedef Matrix<Scalar, ResultOrder> ResultType;
  ResultType result(rows, cols);
  MakeZero(&result);

  if (params_to_test == WhatParamsToTest::All) {
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 0, 0, 1);
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 10, 0, 0, 1);
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 10, 0, 1);
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 0, 10, 1);
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 0, 0, 10);
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 10, 10, 10, 10);
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 256, 1, 17, 4);
  }
  test_gemm<GemmWrapper>(context, lhs, rhs, &result, -75, -91, 74980, 123);
}

enum class WhatOrdersToTest { All, OnlyRCC };

template <typename GemmWrapper>
void test_gemm(typename GemmWrapper::Context* context, int rows, int depth,
               int cols, WhatParamsToTest params_to_test,
               WhatOrdersToTest orders_to_test) {
#define GEMMLOWP_ONE_TEST(LhsOrder, RhsOrder, ResultOrder)         \
  do {                                                             \
    test_gemm<GemmWrapper, MapOrder::LhsOrder, MapOrder::RhsOrder, \
              MapOrder::ResultOrder>(context, rows, depth, cols,   \
                                     params_to_test);              \
  } while (false)

  if (orders_to_test == WhatOrdersToTest::All) {
    GEMMLOWP_ONE_TEST(ColMajor, ColMajor, ColMajor);
    GEMMLOWP_ONE_TEST(RowMajor, ColMajor, ColMajor);
    GEMMLOWP_ONE_TEST(ColMajor, RowMajor, ColMajor);
    GEMMLOWP_ONE_TEST(RowMajor, RowMajor, ColMajor);

    GEMMLOWP_ONE_TEST(ColMajor, ColMajor, RowMajor);
    GEMMLOWP_ONE_TEST(RowMajor, ColMajor, RowMajor);
    GEMMLOWP_ONE_TEST(ColMajor, RowMajor, RowMajor);
    GEMMLOWP_ONE_TEST(RowMajor, RowMajor, RowMajor);
  } else {
    GEMMLOWP_ONE_TEST(RowMajor, ColMajor, ColMajor);
  }

#undef GEMMLOWP_ONE_TEST
}

template <typename Kernel>
void test_gemm_kernel(MultiThreadGemmContext* context) {
  typedef MultiThreadGemmWrapper<Kernel, std::uint8_t,
                                 DefaultL8R8BitDepthParams>
      GemmWrapper;
  test_gemm<GemmWrapper>(context, 1, 1, 1, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 2, 2, 2, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 3, 3, 3, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 4, 4, 4, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 5, 5, 5, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 9, 11, 13, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 50, 50, 50, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 200, 200, 200,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::All);
  test_gemm<GemmWrapper>(context, 50, 5000, 50,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
}

template <typename GemmWrapper>
void test_gemm(typename GemmWrapper::Context* context) {
  test_gemm<GemmWrapper>(context, 1, 1, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 2, 1, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1, 2, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1, 1, 2, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 2, 2, 2, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 3, 3, 3, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 4, 4, 4, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 5, 5, 5, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 6, 6, 6, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 3, 5, 7, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 7, 3, 5, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 5, 7, 3, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 8, 8, 8, WhatParamsToTest::All,
                         WhatOrdersToTest::All);
  test_gemm<GemmWrapper>(context, 16, 16, 16, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 32, 32, 32, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 64, 64, 64, WhatParamsToTest::All,
                         WhatOrdersToTest::All);
  test_gemm<GemmWrapper>(context, 128, 128, 128, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);

  test_gemm<GemmWrapper>(context, 16, 17, 16, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 37, 55, 73, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 57, 87, 117, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 93, 83, 73, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 109, 89, 99, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 78, 101, 82, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);

  test_gemm<GemmWrapper>(context, 512, 512, 512,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1024, 1024, 1024,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 567, 2345, 123,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 100, 5000, 100,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1, 1, 1000, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1000, 1, 1, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1, 1000, 1, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1, 1000, 1000,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1000, 1, 1000,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1000, 1000, 1,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 777, 3456, 1,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 4567, 555, 1,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);

  // Test all storage orders
  test_gemm<GemmWrapper>(context, 70, 90, 110, WhatParamsToTest::All,
                         WhatOrdersToTest::All);
  test_gemm<GemmWrapper>(context, 300, 400, 500,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::All);
}

template <typename GemmWrapper>
void test_gemv(typename GemmWrapper::Context* context) {
  test_gemm<GemmWrapper>(context, 2, 2, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 3, 3, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 4, 4, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 5, 5, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 6, 6, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 3, 5, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 7, 3, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 5, 7, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 8, 8, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 32, 32, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 128, 128, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 321, 123, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);

  // Test all storage orders
  test_gemm<GemmWrapper>(context, 70, 90, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::All);
  test_gemm<GemmWrapper>(context, 300, 400, 1,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::All);
}

const char* GetBitDepthName(eight_bit_int_gemm::BitDepthSetting b) {
  switch (b) {
    case eight_bit_int_gemm::BitDepthSetting::A8B8:
      return "Lhs: 8 bit, Rhs: 8 bit";
    case eight_bit_int_gemm::BitDepthSetting::A5B7:
      return "(legacy, no longer requantizing) Lhs: 7 bit, Rhs: 5 bit";
    default:
      abort();
      return nullptr;
  }
}

// Runs a small set of hand-picked data for per-channel quantized data.
// This test case comes from a set of 2 2x2 convolution filters run over a 3x3
// image.
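// With the Col-shaped offset and multiplier vectors used below, each entry in
// output row i is computed -- as a sketch, up to round-to-nearest in the
// shift -- as
//   c(i, j) = ((sum + c_offset[i]) * c_mult_int[i]) >> c_shift,
// then clamped to [0, 255]: the same formula as in ReferenceEightBitIntGemm,
// but with per-row ("per channel") offsets and multipliers.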
void TestWithSmallDataPerChannelQuantization() {
  const int m = 2;
  const int n = 9;
  const int k = 12;

  // 12 x 2, columnwise.
  const std::uint8_t a_data[] = {0, 0, 0, 0, 0, 0, 0, 0,
                                 0, 255, 255, 255, 64, 64, 64, 64,
                                 64, 64, 0, 0, 0, 255, 255, 255};
  const int lda = k;
  int a_offset[] = {0, -64};
  MatrixMap<const std::uint8_t, MapOrder::RowMajor> lhs(a_data, m, k, lda);
  const OffsetColMap lhs_offset(a_offset, m);

  // 12 x 9, columnwise.
  const std::uint8_t b_data[] = {
      0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 0, 0,
      0, 0, 0, 0, 255, 255, 255, 0, 0, 0, 0, 0, 0, 127,
      127, 127, 0, 0, 0, 127, 127, 127, 0, 0, 0, 255, 255, 255,
      0, 0, 0, 0, 0, 0, 255, 255, 255, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 127, 127, 127, 0, 0, 0, 127,
      127, 127, 0, 0, 0, 0, 0, 0, 127, 127, 127, 127, 127, 127,
      0, 0, 0, 0, 0, 0, 127, 127, 127, 127, 127, 127, 0, 0,
      0, 127, 127, 127, 127, 127, 127, 127, 127, 127};
  const int ldb = k;
  int b_offset = -127;
  MatrixMap<const std::uint8_t, MapOrder::ColMajor> rhs(b_data, k, n, ldb);
  const OffsetRowDup rhs_offset(b_offset, rhs.cols());

  // 2 x 9, columnwise.
  const std::uint8_t expected_c_data[] = {255, 255, 0, 0, 127, 159,
                                          0, 64, 0, 64, 127, 159,
                                          127, 127, 127, 127, 127, 127};
  const int ldc = m;
  int c_offset[] = {97155, 97346};
  int c_mult_int[] = {2741, 2741};
  const int c_shift = 21;

  const int c_count = m * n;
  std::unique_ptr<std::uint8_t[]> output_data(new std::uint8_t[c_count]);
  MatrixMap<std::uint8_t, MapOrder::ColMajor> result(output_data.get(), m, n,
                                                     ldc);
  const OffsetColMap result_offset(c_offset, m);
  const OffsetColMap result_mult_int(c_mult_int, m);
  const int result_shift = c_shift;

  GemmContext gemm_context;
  auto output_pipeline = MakeStandardOutputPipeline<VectorShape::Col>(
      result_offset, result_mult_int, result_shift);
  GemmWithOutputPipelinePC<std::uint8_t, std::uint8_t,
                           DefaultL8R8BitDepthParams>(
      &gemm_context, lhs, rhs, &result, lhs_offset, rhs_offset,
      output_pipeline);

  ResultStats stats;
  GetResultStats(output_data.get(), expected_c_data, c_count, &stats);

  ResultStatsBounds bounds;
  const bool good = CheckResultStatsBounds(stats, bounds);
  printf("TestWithSmallDataPerChannelQuantization: %s\n",
         good ? "PASS" : "FAIL");
  ReportResultStats(stats, bounds);
  Check(good);
}

// Runs a larger set of hand-picked data for per-channel quantized data.
// This test case comes from a set of 22 3x3 convolution filters run over a 5x5
// image. Right now, I have 7 different filters and 15 copies of the first
// filter, to make sure the NEON code path that processes 16 rows at a time is
// covered.
void TestWithLargeDataPerChannelQuantization() {
  const int m = 22;
  const int n = 25;
  const int k = 27;

  // 27 x 22, column-wise.
  const std::uint8_t a_data[] = {
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 127, 127, 127, 255, 255, 255, 127, 127, 127,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 127, 127, 127,
      0, 0, 0, 0, 0, 0, 255, 255, 255, 0, 0, 0, 0, 0, 0,
      127, 127, 127, 0, 0, 0, 51, 51, 51, 51, 51, 51, 51, 51, 51,
      0, 0, 0, 255, 255, 255, 0, 0, 0, 51, 51, 51, 51, 51, 51,
      51, 51, 51, 51, 51, 51, 0, 0, 0, 51, 51, 51, 51, 51, 51,
      255, 255, 255, 51, 51, 51, 51, 51, 51, 0, 0, 0, 51, 51, 51,
      0, 0, 0, 64, 64, 64, 0, 0, 0, 64, 64, 64, 255, 255, 255,
      64, 64, 64, 0, 0, 0, 64, 64, 64, 0, 0, 0, 36, 36, 36,
      0, 0, 0, 36, 36, 36, 0, 0, 0, 255, 255, 255, 0, 0, 0,
      36, 36, 36, 0, 0, 0, 36, 36, 36, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 255, 255, 255, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 255, 255, 255, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 255, 255, 255, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0,
  };
  const int lda = k;
  int a_offset[] = {0, 0, 0, -51, -51, 0, -36, 0, 0, 0, 0,
                    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
  MatrixMap<const std::uint8_t, MapOrder::RowMajor> lhs(a_data, m, k, lda);
  const OffsetColMap lhs_offset(a_offset, m);

  // 27 x 25, column-wise.
  const std::uint8_t b_data[] = {
      127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 119, 119,
      119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127,
      127, 127, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127,
      127, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127,
      127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127, 127,
      127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127,
      119, 119, 119, 119, 119, 119, 127, 127, 127, 127, 127, 127, 119, 119,
      119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127,
      127, 127, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 136, 136, 136, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      136, 136, 136, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 136, 136, 136, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127,
      119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119,
      119, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127,
      127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 136, 136, 136, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      136, 136, 136, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 136, 136, 136, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119,
      119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 127,
      127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119,
      119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 136, 136, 136, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      136, 136, 136, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 136, 136, 136, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119,
      119, 119, 119, 119, 119, 127, 127, 127, 127, 127, 127, 119, 119, 119,
      119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127,
      127, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127,
      127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127, 127,
      127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127, 127, 127,
      127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119,
      119, 119, 119, 119, 119, 127, 127, 127, 127, 127, 127, 127, 127, 127,
      127, 127, 127};
  const int ldb = k;
  int b_offset = -127;
  MatrixMap<const std::uint8_t, MapOrder::ColMajor> rhs(b_data, k, n, ldb);
  const OffsetRowDup rhs_offset(b_offset, rhs.cols());

  // 22 x 25, column-wise.
  const std::uint8_t expected_c_data[] = {
      7, 37, 37, 67, 67, 39, 79, 7, 7, 7, 7, 7, 7, 7, 7,
      7, 7, 7, 7, 7, 7, 7, 7, 7, 37, 87, 67, 23, 91, 7,
      7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
      7, 37, 87, 67, 23, 91, 7, 7, 7, 7, 7, 7, 7, 7, 7,
      7, 7, 7, 7, 7, 7, 7, 7, 37, 87, 67, 23, 91, 7, 7,
      7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 37,
      37, 67, 67, 39, 79, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
      7, 7, 7, 7, 7, 7, 37, 7, 67, 87, 23, 91, 7, 7, 7,
      7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
      87, 87, 7, 103, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
      7, 7, 7, 7, 7, 7, 71, 87, 45, 41, 77, 7, 7, 7, 7,
      7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 87,
      87, 7, 103, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
      7, 7, 7, 7, 37, 7, 67, 87, 23, 91, 7, 7, 7, 7, 7,
      7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 37, 7, 67, 87,
      23, 91, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
      7, 7, 7, 71, 7, 45, 87, 41, 77, 7, 7, 7, 7, 7, 7,
      7, 7, 7, 7, 7, 7, 7, 7, 7, 255, 135, 135, 255, 255, 143,
      255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
      255, 7, 71, 7, 45, 87, 41, 77, 7, 7, 7, 7, 7, 7, 7,
      7, 7, 7, 7, 7, 7, 7, 7, 7, 37, 7, 67, 87, 23, 91,
      7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
      7, 37, 7, 67, 87, 23, 91, 7, 7, 7, 7, 7, 7, 7, 7,
      7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 87, 87, 7, 103, 7,
      7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
      7, 71, 87, 45, 41, 77, 7, 7, 7, 7, 7, 7, 7, 7, 7,
      7, 7, 7, 7, 7, 7, 7, 7, 7, 87, 87, 7, 103, 7, 7,
      7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 37,
      7, 67, 87, 23, 91, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
      7, 7, 7, 7, 7, 7, 37, 37, 67, 67, 39, 79, 7, 7, 7,
      7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 37,
      87, 67, 23, 91, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
      7, 7, 7, 7, 7, 7, 37, 87, 67, 23, 91, 7, 7, 7, 7,
      7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 37, 87,
      67, 23, 91, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
      7, 7, 7, 7, 37, 37, 67, 67, 39, 79, 7, 7, 7, 7, 7,
      7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 99, 99, 99, 99, 99,
      99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99, 99,
      99, 99, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111,
      111, 111, 111, 111, 111, 111, 111, 111, 111,
  };
  const int ldc = m;
  int c_offset[] = {
      6477, 12954, 12954, 7793, 7793, 12954, 9282, 6477, 6477, 6477, 6477,
      6477, 6477, 6477, 6477, 6477, 6477, 6477, 6477, 6477, 6477, 6477,
  };
  int c_mult_int[] = {
      41121, 20560, 20560, 34267, 34267, 21937, 28784, 41121,
      41121, 41121, 41121, 41121, 41121, 41121, 41121, 41121,
      41121, 41121, 41121, 41121, 41121, 41121,
  };
  const int c_shift = 21;

  const int c_count = m * n;
  std::unique_ptr<std::uint8_t[]> output_data(new std::uint8_t[c_count]);
  MatrixMap<std::uint8_t, MapOrder::ColMajor> result(output_data.get(), m, n,
                                                     ldc);
  const OffsetColMap result_offset(c_offset, m);
  const OffsetColMap result_mult_int(c_mult_int, m);
  const int result_shift = c_shift;

  GemmContext gemm_context;
  auto output_pipeline = MakeStandardOutputPipeline<VectorShape::Col>(
      result_offset, result_mult_int, result_shift);
  GemmWithOutputPipelinePC<std::uint8_t, std::uint8_t,
                           DefaultL8R8BitDepthParams>(
      &gemm_context, lhs, rhs, &result, lhs_offset, rhs_offset,
      output_pipeline);

  ResultStats stats;
  GetResultStats(output_data.get(), expected_c_data, c_count, &stats);

  ResultStatsBounds bounds;
  const bool good = CheckResultStatsBounds(stats, bounds);
  printf("TestWithLargeDataPerChannelQuantization: %s\n",
         good ? "PASS" : "FAIL");
  ReportResultStats(stats, bounds);
  Check(good);
}

// Multithreading only activates when the result has more than 16 rows, and
// also (result rows) * (result cols) * depth >= 2 x 65 x 1024. The size was
// selected to run in 3 threads.
//
// Based on the following floating point data:
// LHS: all zeros except 10.0, 20.0 at the beginning of the first 16 rows;
// 1.0, 2.0 at the beginning of the next 16 rows; 0.1, 0.2 in the next 16
// rows; 0.01, 0.02 in the last 16 rows.
// RHS: all zeros except 1.0 in (0, 0) and 2.0 in (1, 0).
// Varying boundaries were used for each 16-row block of the LHS, to test for
// correct indexing into the offsets.
// Expected result: all zeros, except 50.0 at the beginning of the first 16
// rows; 5.0 at the beginning of the next 16 rows; 0.5 in the next 16 rows;
// 0.05 in the last 16 rows.
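// For instance, in the second 16-row block, 1.0 is encoded as the uint8 value
// 153 together with the offset -51 (153 - 51 = 102 quantized units), and 2.0
// as 255 (255 - 51 = 204); each block thus exercises a different offset.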
void TestMultithreadedPerChannelQuantization() {
  const int m = 64;
  const int n = 20;
  const int k = 160;

  // LHS, m x k.
  const std::array<std::int32_t, 4> lhs_offsets_terse{{
      0, -51, -85, -109,
  }};
  assert(lhs_offsets_terse.size() * 16 == m);
  const std::array<std::uint8_t, 4> lhs_first_el{{
      128, 153, 170, 182,
  }};
  assert(lhs_first_el.size() * 16 == m);

  // lhs_first_el at (i, 0) and 255 at (i, 1); all other values are -offset.
  std::vector<std::uint8_t> a_data(m * k, 0);
  for (int i = 0; i < m; ++i) {
    a_data[i * k] = lhs_first_el[i / 16];
    a_data[i * k + 1] = 255;
    for (int j = 2; j < k; ++j) {
      a_data[i * k + j] = std::uint8_t(-lhs_offsets_terse[i / 16]);
    }
  }

  const int lda = k;
  // Given values at [i / 16].
  std::vector<std::int32_t> a_offset(m, 0);
  for (int i = 0; i < m; ++i) {
    a_offset[i] = lhs_offsets_terse[i / 16];
  }

  MatrixMap<const std::uint8_t, MapOrder::RowMajor> lhs(&a_data[0], m, k, lda);
  const OffsetColMap lhs_offset(&a_offset[0], m);

  // RHS, k x n.
  // All zeros, except 128 at (0, 0) and 255 at (1, 0).
  std::vector<std::uint8_t> b_data(k * n, 0);
  b_data[0] = 128;
  b_data[1] = 255;

  const int ldb = k;
  std::int32_t b_offset = 0;
  MatrixMap<const std::uint8_t, MapOrder::ColMajor> rhs(&b_data[0], k, n, ldb);
  const OffsetRowDup rhs_offset(b_offset, rhs.cols());

  // Result, m x n.
  // All zeros, except given values at (i / 16, 0).
  const std::array<std::uint8_t, 4> expected_c_terse{{
      142, 159, 182, 213,
  }};
  assert(expected_c_terse.size() * 16 == m);
  std::vector<std::uint8_t> expected_c_data(m * n, 0);
  for (int i = 0; i < m; ++i) {
    expected_c_data[i] = expected_c_terse[i / 16];
  }

  const int ldc = m;
  // All zeros.
  std::vector<std::int32_t> c_offset(m, 0);
  // Given values at [i / 16].
  const std::array<std::int32_t, 4> c_mult_int_terse{{
      3655, 5140, 7049, 9595,
  }};
  assert(c_mult_int_terse.size() * 16 == m);
  std::vector<std::int32_t> c_mult_int(m);
  for (int i = 0; i < m; ++i) {
    c_mult_int[i] = c_mult_int_terse[i / 16];
  }

  const int c_shift = 21;

  const int c_count = m * n;
  std::unique_ptr<std::uint8_t[]> output_data(new std::uint8_t[c_count]);
  MatrixMap<std::uint8_t, MapOrder::ColMajor> result(output_data.get(), m, n,
                                                     ldc);
  const OffsetColMap result_offset(&c_offset[0], m);
  const OffsetColMap result_mult_int(&c_mult_int[0], m);
  const int result_shift = c_shift;

  GemmContext gemm_context;
  auto output_pipeline = MakeStandardOutputPipeline<VectorShape::Col>(
      result_offset, result_mult_int, result_shift);
  GemmWithOutputPipelinePC<std::uint8_t, std::uint8_t,
                           DefaultL8R8BitDepthParams>(
      &gemm_context, lhs, rhs, &result, lhs_offset, rhs_offset,
      output_pipeline);

  ResultStats stats;
  GetResultStats(output_data.get(), &expected_c_data[0], c_count, &stats);

  ResultStatsBounds bounds;
  const bool good = CheckResultStatsBounds(stats, bounds);
  printf("TestMultithreadedPerChannelQuantization: %s\n",
         good ? "PASS" : "FAIL");
  ReportResultStats(stats, bounds);
  Check(good);
}

// Runs a small set of hand-calculated data through the implementation.
void TestWithSmallData() {
  const int m = 4;
  const int n = 2;
  const int k = 3;
  // Matrix A (LHS) is:
  // | 7 | 10 | 13 | 16 |
  // | 8 | 11 | 14 | 17 |
  // | 9 | 12 | 15 | 18 |
  const std::uint8_t a_data[] = {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18};
  // Matrix B (RHS) is:
  // | 1 | 3 | 5 |
  // | 2 | 4 | 6 |
  const std::uint8_t b_data[] = {1, 2, 3, 4, 5, 6};
  // Here are the results we expect, from hand calculations:
  // (1 * 7) + (3 * 8) + (5 * 9) = 76
  // (2 * 7) + (4 * 8) + (6 * 9) = 100
  // (1 * 10) + (3 * 11) + (5 * 12) = 103
  // (2 * 10) + (4 * 11) + (6 * 12) = 136
  // (1 * 13) + (3 * 14) + (5 * 15) = 130
  // (2 * 13) + (4 * 14) + (6 * 15) = 172
  // (1 * 16) + (3 * 17) + (5 * 18) = 157
  // (2 * 16) + (4 * 17) + (6 * 18) = 208
  // That means matrix C should be:
  // | 76  | 103 | 130 | 157 |
  // | 100 | 136 | 172 | 208 |
  const std::uint8_t expected_data[] = {76, 100, 103, 136, 130, 172, 157, 208};

  const int c_count = m * n;
  std::unique_ptr<std::uint8_t[]> output_data(new std::uint8_t[c_count]);

  const bool is_a_transposed = true;
  const bool is_b_transposed = true;
  const bool is_c_transposed = true;
  const int lda = k;
  const int ldb = n;
  const int ldc = n;

  const int a_offset = 0;
  const int b_offset = 0;
  const int c_offset = 0;
  const int c_mult = 1;
  const int c_shift = 0;

  gemmlowp::eight_bit_int_gemm::EightBitIntGemm(
      is_a_transposed, is_b_transposed, is_c_transposed, m, n, k, a_data,
      a_offset, lda, b_data, b_offset, ldb, output_data.get(), c_offset, c_mult,
      c_shift, ldc, eight_bit_int_gemm::BitDepthSetting::A8B8);

  ResultStats stats;
  GetResultStats(output_data.get(), expected_data, c_count, &stats);

  ResultStatsBounds bounds;
  const bool good = CheckResultStatsBounds(stats, bounds);
  printf("TestWithSmallData: %s\n", good ? "PASS" : "FAIL");
  ReportResultStats(stats, bounds);
  Check(good);
}

// This is the most realistic test of how we'll be using the low-precision GEMM
// function in applications. It takes in large input matrices that have been
// captured from an actual neural network run.
void TestWithRealData(eight_bit_int_gemm::BitDepthSetting BitDepth,
                      int tolerance_median, int tolerance_max) {
  std::unique_ptr<std::uint8_t[]> output_data(
      new std::uint8_t[test_data::c_count]);
  gemmlowp::eight_bit_int_gemm::EightBitIntGemm(
      test_data::is_a_transposed, test_data::is_b_transposed,
      test_data::is_c_transposed, test_data::m, test_data::n, test_data::k,
      test_data::a_data, test_data::a_offset, test_data::k, test_data::b_data,
      test_data::b_offset, test_data::k, output_data.get(), test_data::c_offset,
      test_data::c_mult_int, test_data::c_shift, test_data::m, BitDepth);

  ResultStats stats;
  GetResultStats(output_data.get(), test_data::expected_c_data,
                 test_data::c_count, &stats);

  ResultStatsBounds bounds;
  if (BitDepth == eight_bit_int_gemm::BitDepthSetting::A5B7) {
    bounds.med_unsigned_diff = tolerance_median;
    bounds.max_unsigned_diff = tolerance_max;
    bounds.med_signed_diff = 0;
    bounds.mean_signed_diff = 0.2f;
  }

  const bool good = CheckResultStatsBounds(stats, bounds);
  printf("TestWithRealData: %s with %s\n", good ? "PASS" : "FAIL",
         GetBitDepthName(BitDepth));
  ReportResultStats(stats, bounds);
  Check(good);
}

1200 template <typename BitDepthParams, MapOrder ResultOrder>
TestOutputStages(int rows,int depth,int cols,int result_offset,int result_mult_int,int result_shift)1201 void TestOutputStages(int rows, int depth, int cols, int result_offset,
1202 int result_mult_int, int result_shift) {
1203 Matrix<std::uint8_t, MapOrder::RowMajor> lhs(rows, depth);
1204 Matrix<std::uint8_t, MapOrder::ColMajor> rhs(depth, cols);
1205 Matrix<std::int32_t, ResultOrder> result_raw_int32(rows, cols);
1206 MakeRandom<typename BitDepthParams::LhsRange>(&lhs);
1207 MakeRandom<typename BitDepthParams::RhsRange>(&rhs);
1208 const int lhs_offset = 12;
1209 const int rhs_offset = -34;
1210
1211 // Test an empty pipeline, i.e. returning raw int32 accumulators.
1212 auto empty_pipeline = std::make_tuple();
1213 GemmContext context;
1214 GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1215 &context, lhs.const_map(), rhs.const_map(), &result_raw_int32, lhs_offset,
1216 rhs_offset, empty_pipeline);
1217
1218 for (int r = 0; r < rows; r++) {
1219 for (int c = 0; c < cols; c++) {
1220 std::int32_t expected = 0;
1221 for (int d = 0; d < depth; d++) {
1222 std::int32_t lhs_val =
1223 static_cast<std::int32_t>(lhs(r, d)) + lhs_offset;
1224 std::int32_t rhs_val =
1225 static_cast<std::int32_t>(rhs(d, c)) + rhs_offset;
1226 expected += lhs_val * rhs_val;
1227 }
1228 Check(expected == result_raw_int32(r, c));
1229 }
1230 }
1231
1232 // Test a pipeline with only the quantize-down stage, still returning
1233 // unclamped (but scaled) int32's
1234 OutputStageQuantizeDownInt32ToUint8Scale quantize_down_stage;
1235 quantize_down_stage.result_offset = result_offset;
1236 quantize_down_stage.result_mult_int = result_mult_int;
1237 quantize_down_stage.result_shift = result_shift;
1238 auto quantize_down_pipeline = std::make_tuple(quantize_down_stage);
1239 Matrix<std::int32_t, ResultOrder> result_quantized_down_int32(rows, cols);
1240 GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1241 &context, lhs.const_map(), rhs.const_map(), &result_quantized_down_int32,
1242 lhs_offset, rhs_offset, quantize_down_pipeline);
1243
1244 std::int64_t sum = 0;
1245 for (int r = 0; r < rows; r++) {
1246 for (int c = 0; c < cols; c++) {
1247 std::int32_t raw = result_raw_int32(r, c);
1248 std::int32_t expected = RoundingDivideByPOT(
1249 (raw + result_offset) * result_mult_int, result_shift);
1250 Check(expected == result_quantized_down_int32(r, c));
1251 sum += expected;
1252 }
1253 }
1254 std::int64_t avg = sum / (rows * cols);
1255 // Test that the average quantized-down value falls reasonably in the
1256 // middle of the [0..255] range. Otherwise, the multiplier / shift need to be
1257 // adjusted.
1258 Check(avg >= 64 && avg <= 192);
1259
1260 // Test the familiar default pipeline consisting of quantize-down and
1261 // clamp-and-cast-to-uint8.
1262 OutputStageSaturatingCastToUint8 saturating_cast_stage;
1263 auto quantize_down_and_saturating_cast_pipeline =
1264 std::make_tuple(quantize_down_stage, saturating_cast_stage);
1265 Matrix<std::uint8_t, ResultOrder> result_quantized_down_saturated_uint8(rows,
1266 cols);
1267 GemmWithOutputPipeline<std::uint8_t, std::uint8_t, DefaultL8R8BitDepthParams>(
1268 &context, lhs.const_map(), rhs.const_map(),
1269 &result_quantized_down_saturated_uint8, lhs_offset, rhs_offset,
1270 quantize_down_and_saturating_cast_pipeline);
1271
1272 for (int r = 0; r < rows; r++) {
1273 for (int c = 0; c < cols; c++) {
1274 std::int32_t quantized = result_quantized_down_int32(r, c);
1275 std::uint8_t expected = std::min(std::max(quantized, 0), 255);
1276 Check(expected == result_quantized_down_saturated_uint8(r, c));
1277 }
1278 }
1279
1280 // Test a variant of the familiar default pipeline consisting of quantize-down
1281 // and clamp-and-cast-to-int16.
1282 OutputStageSaturatingCastToInt16 saturating_cast_int16_stage;
1283 auto quantize_down_and_saturating_cast_int16_pipeline =
1284 std::make_tuple(quantize_down_stage, saturating_cast_int16_stage);
1285 Matrix<std::int16_t, ResultOrder> result_quantized_down_saturated_int16(rows,
1286 cols);
1287 GemmWithOutputPipeline<std::uint8_t, std::int16_t, DefaultL8R8BitDepthParams>(
1288 &context, lhs.const_map(), rhs.const_map(),
1289 &result_quantized_down_saturated_int16, lhs_offset, rhs_offset,
1290 quantize_down_and_saturating_cast_int16_pipeline);
1291
1292 for (int r = 0; r < rows; r++) {
1293 for (int c = 0; c < cols; c++) {
1294 std::int32_t quantized = result_quantized_down_int32(r, c);
1295 std::int16_t expected = std::min(std::max(quantized, -32768), 32767);
1296 Check(expected == result_quantized_down_saturated_int16(r, c));
1297 }
1298 }
1299
1300 #ifdef GEMMLOWP_MSA
1301 // Test a pipeline consisting of quantize-down and truncating-cast-to-uint8.
  OutputStageTruncatingCastToUint8 truncating_cast_stage;
  auto quantize_down_and_truncating_cast_pipeline =
      std::make_tuple(quantize_down_stage, truncating_cast_stage);
  Matrix<std::uint8_t, ResultOrder> result_quantized_down_truncated_uint8(
      rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::uint8_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(),
      &result_quantized_down_truncated_uint8, lhs_offset, rhs_offset,
      quantize_down_and_truncating_cast_pipeline);

  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t quantized = result_quantized_down_int32(r, c);
      std::uint8_t expected = quantized & 255;
      Check(expected == result_quantized_down_truncated_uint8(r, c));
    }
  }
#endif

  // Test a bias-addition with a row-vector.
  std::vector<std::int32_t> row_vector_data(cols);
  std::uniform_int_distribution<std::int32_t> uniform_minus_500_plus_500(-500,
                                                                         500);
  for (int i = 0; i < cols; i++) {
    row_vector_data[i] = uniform_minus_500_plus_500(RandomEngine());
  }
  typedef VectorMap<std::int32_t, VectorShape::Row> RowVectorMap;
  RowVectorMap row_vector_map(row_vector_data.data(), cols);
  OutputStageBiasAddition<RowVectorMap> row_bias_addition_stage;
  row_bias_addition_stage.bias_vector = row_vector_map;
  auto row_bias_addition_pipeline = std::make_tuple(row_bias_addition_stage);
  Matrix<std::int32_t, ResultOrder> result_of_row_bias_addition(rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(), &result_of_row_bias_addition,
      lhs_offset, rhs_offset, row_bias_addition_pipeline);
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t expected = result_raw_int32(r, c) + row_vector_data[c];
      Check(expected == result_of_row_bias_addition(r, c));
    }
  }

  // Test a bias-addition with a column-vector.
  std::vector<std::int32_t> col_vector_data(rows);
  for (int i = 0; i < rows; i++) {
    col_vector_data[i] = uniform_minus_500_plus_500(RandomEngine());
  }
  typedef VectorMap<std::int32_t, VectorShape::Col> ColVectorMap;
  ColVectorMap col_vector_map(col_vector_data.data(), rows);
  OutputStageBiasAddition<ColVectorMap> col_bias_addition_stage;
  col_bias_addition_stage.bias_vector = col_vector_map;
  auto col_bias_addition_pipeline = std::make_tuple(col_bias_addition_stage);
  Matrix<std::int32_t, ResultOrder> result_of_col_bias_addition(rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(), &result_of_col_bias_addition,
      lhs_offset, rhs_offset, col_bias_addition_pipeline);
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t expected = result_raw_int32(r, c) + col_vector_data[r];
      Check(expected == result_of_col_bias_addition(r, c));
    }
  }

  // Test a clamp.
  OutputStageClamp clamp_stage;
  // Determine the min and max of the raw int32 accumulators.
  std::int32_t raw_min = std::numeric_limits<std::int32_t>::max();
  std::int32_t raw_max = std::numeric_limits<std::int32_t>::min();
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      raw_min = std::min(raw_min, result_raw_int32(r, c));
      raw_max = std::max(raw_max, result_raw_int32(r, c));
    }
  }
  // Pick some interesting clamp min/max bounds.
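  // These bounds sit 30% and 70% of the way from raw_min to raw_max, so
  // they lie inside the raw range and the clamp genuinely modifies some
  // values; the assert below double-checks that ordering.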
  clamp_stage.min = static_cast<std::int32_t>(raw_min * 0.7 + raw_max * 0.3);
  clamp_stage.max = static_cast<std::int32_t>(raw_min * 0.3 + raw_max * 0.7);
  assert(raw_min <= clamp_stage.min && clamp_stage.min <= clamp_stage.max &&
         clamp_stage.max <= raw_max);
  auto clamp_pipeline = std::make_tuple(clamp_stage);
  Matrix<std::int32_t, ResultOrder> result_clamped(rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(), &result_clamped, lhs_offset,
      rhs_offset, clamp_pipeline);
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t raw = result_raw_int32(r, c);
      std::int32_t expected =
          std::min(std::max(raw, clamp_stage.min), clamp_stage.max);
      Check(expected == result_clamped(r, c));
    }
  }

  // Test tanh.
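  // OutputStageTanh maps int32 values on an affine scale where
  // real_zero_as_int32 represents the real value 0 and deviations from it
  // are measured in units of real_amplitude_as_int32:
  //   real(x) = (x - real_zero_as_int32) / real_amplitude_as_int32.
  // The check loop below compares the stage's fixed-point tanh against
  // std::tanh on that real scale, to within 2e-4.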
  OutputStageTanh tanh_stage;
  const std::int32_t real_zero_as_int32 = (raw_max + raw_min) / 2;
  const std::int32_t real_amplitude_as_int32 = (raw_max - raw_min) / 16;
  tanh_stage.real_zero_as_int32 = real_zero_as_int32;
  tanh_stage.real_amplitude_as_int32 = real_amplitude_as_int32;
  auto tanh_pipeline = std::make_tuple(tanh_stage);
  Matrix<std::int32_t, ResultOrder> result_tanh(rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(), &result_tanh, lhs_offset,
      rhs_offset, tanh_pipeline);
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t raw = result_raw_int32(r, c);
      double real_input =
          double(raw - real_zero_as_int32) / real_amplitude_as_int32;
      double expected = std::tanh(real_input);
      std::int32_t actual_int32 = result_tanh(r, c);
      double actual =
          double(actual_int32 - real_zero_as_int32) / real_amplitude_as_int32;
      Check(std::abs(expected - actual) < 2e-4);
    }
  }

  // Test a pipeline with bias and clamp.
  auto bias_clamp_pipeline =
      std::make_tuple(col_bias_addition_stage, clamp_stage);
  Matrix<std::int32_t, ResultOrder> result_biased_clamped(rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(), &result_biased_clamped,
      lhs_offset, rhs_offset, bias_clamp_pipeline);
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t raw = result_raw_int32(r, c);
      std::int32_t biased = raw + col_vector_data[r];
      std::int32_t expected =
          std::min(std::max(biased, clamp_stage.min), clamp_stage.max);
      Check(expected == result_biased_clamped(r, c));
    }
  }

  // Test a full pipeline with bias, clamp, and quantization down to an
  // 8-bit result.
  auto bias_clamp_quantize_cast_pipeline =
      std::make_tuple(col_bias_addition_stage, clamp_stage, quantize_down_stage,
                      saturating_cast_stage);
  Matrix<std::uint8_t, ResultOrder> result_biased_clamped_quantized_casted(
      rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::uint8_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(),
      &result_biased_clamped_quantized_casted, lhs_offset, rhs_offset,
      bias_clamp_quantize_cast_pipeline);
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t quantized = RoundingDivideByPOT(
          (result_biased_clamped(r, c) + result_offset) * result_mult_int,
          result_shift);
      std::uint8_t expected = std::min(std::max(quantized, 0), 255);
      Check(expected == result_biased_clamped_quantized_casted(r, c));
    }
  }

  // Test a pipeline with the fixed-point-multiplier variant stage for the
  // quantizing down of 32-bit accumulators.
  //
  // First, figure out appropriate fixed-point multiplier and shift values.
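  // The goal is to rewrite the scale result_mult_int * 2^-result_shift as
  // M * 2^-31 * 2^-s, where the fixed-point multiplier M is normalized into
  // [2^30, 2^31) by repeated doubling and s absorbs the doublings. For
  // example, result_mult_int = 17 and result_shift = 14 (the values used by
  // the callers below) give M = 17 * 2^26 = 1140850688 and s = 9, since
  // 17 * 2^-14 == (17 * 2^26) * 2^-31 * 2^-9.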
  Check(result_mult_int > 0);
  Check(result_shift > 0);
  std::int32_t result_fixedpoint_multiplier = result_mult_int;
  std::int32_t result_fixedpoint_shift = result_shift - 31;
  while (result_fixedpoint_multiplier < (1 << 30)) {
    result_fixedpoint_multiplier <<= 1;
    result_fixedpoint_shift++;
  }
  Check(result_fixedpoint_shift >= 0);
  // Now test OutputStageQuantizeDownInt32ByFixedPoint.
  OutputStageQuantizeDownInt32ByFixedPoint quantize_down_by_fixedpoint_stage;
  quantize_down_by_fixedpoint_stage.result_offset_after_shift =
      static_cast<std::int32_t>(
          round(static_cast<double>(result_offset * result_mult_int) /
                (1 << result_shift)));
  quantize_down_by_fixedpoint_stage.result_fixedpoint_multiplier =
      result_fixedpoint_multiplier;
  quantize_down_by_fixedpoint_stage.result_shift = result_fixedpoint_shift;
  auto quantize_down_by_fixedpoint_pipeline =
      std::make_tuple(quantize_down_by_fixedpoint_stage);
  Matrix<std::int32_t, ResultOrder> result_quantized_down_by_fixedpoint_int32(
      rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(),
      &result_quantized_down_by_fixedpoint_int32, lhs_offset, rhs_offset,
      quantize_down_by_fixedpoint_pipeline);

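  // What the check below encodes: per coefficient, the stage computes
  //   result_offset_after_shift
  //       + RoundingDivideByPOT(
  //             SaturatingRoundingDoublingHighMul(raw, M), s)
  // where SaturatingRoundingDoublingHighMul(raw, M) is a round-to-nearest
  // fixed-point multiplication by M * 2^-31 (saturating only in the corner
  // case raw == M == INT32_MIN), and RoundingDivideByPOT(x, s) is a
  // round-to-nearest division by 2^s. A scalar sketch of that high-mul,
  // assuming only <cstdint> and <limits> (this sketches the intended
  // semantics, not necessarily gemmlowp's exact implementation):
  //
  //   std::int32_t SRDHMSketch(std::int32_t a, std::int32_t b) {
  //     if (a == b && a == std::numeric_limits<std::int32_t>::min()) {
  //       return std::numeric_limits<std::int32_t>::max();  // saturate
  //     }
  //     const std::int64_t ab = static_cast<std::int64_t>(a) * b;
  //     // The sign-dependent nudge makes truncating division round to
  //     // nearest.
  //     const std::int32_t nudge = ab >= 0 ? (1 << 30) : (1 - (1 << 30));
  //     return static_cast<std::int32_t>((ab + nudge) / (1ll << 31));
  //   }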
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      const std::int32_t actual =
          result_quantized_down_by_fixedpoint_int32(r, c);
      const std::int32_t raw = result_raw_int32(r, c);
      const std::int32_t expected =
          quantize_down_by_fixedpoint_stage.result_offset_after_shift +
          RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
                                  raw, result_fixedpoint_multiplier),
                              result_fixedpoint_shift);
      Check(actual == expected);
    }
  }

  // Test OutputStageScaleInt32ByFixedPointAndExponent.
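  // This stage generalizes the fixed-point quantize-down above: the net
  // scale is M * 2^-31 * 2^exponent. As the expected-value computation
  // below spells out, a positive exponent is applied as a left shift before
  // the saturating high-mul, and a negative exponent as a rounding right
  // shift after it.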
  for (int exponent = -2; exponent <= 2; exponent++) {
    OutputStageScaleInt32ByFixedPointAndExponent
        scale_by_fixedpoint_and_exponent_stage;
    scale_by_fixedpoint_and_exponent_stage.result_offset_after_shift =
        static_cast<std::int32_t>(round(static_cast<double>(
            result_offset * result_mult_int * std::pow(2.0, exponent))));
    scale_by_fixedpoint_and_exponent_stage.result_fixedpoint_multiplier =
        result_fixedpoint_multiplier;
    scale_by_fixedpoint_and_exponent_stage.result_exponent = exponent;
    auto scale_by_fixedpoint_and_exponent_pipeline =
        std::make_tuple(scale_by_fixedpoint_and_exponent_stage);
    Matrix<std::int32_t, ResultOrder>
        result_scaled_by_fixedpoint_and_exponent_int32(rows, cols);
    GemmWithOutputPipeline<std::uint8_t, std::int32_t,
                           DefaultL8R8BitDepthParams>(
        &context, lhs.const_map(), rhs.const_map(),
        &result_scaled_by_fixedpoint_and_exponent_int32, lhs_offset,
        rhs_offset, scale_by_fixedpoint_and_exponent_pipeline);

    for (int r = 0; r < rows; r++) {
      for (int c = 0; c < cols; c++) {
        const std::int32_t actual =
            result_scaled_by_fixedpoint_and_exponent_int32(r, c);
        const std::int32_t raw = result_raw_int32(r, c);
        int left_shift = std::max(0, exponent);
        int right_shift = std::max(0, -exponent);
        const std::int32_t expected =
            scale_by_fixedpoint_and_exponent_stage.result_offset_after_shift +
            RoundingDivideByPOT(
                SaturatingRoundingDoublingHighMul((1 << left_shift) * raw,
                                                  result_fixedpoint_multiplier),
                right_shift);
        Check(actual == expected);
      }
    }
  }

  // Test the variant of the familiar default pipeline consisting of
  // quantize-down and clamp-and-cast-to-uint8, where fixed-point multipliers
  // are used for the downscaling.
  auto quantize_down_by_fixedpoint_and_saturating_cast_pipeline =
      std::make_tuple(quantize_down_by_fixedpoint_stage, saturating_cast_stage);
  Matrix<std::uint8_t, ResultOrder>
      result_quantized_down_by_fixedpoint_saturated_uint8(rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::uint8_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(),
      &result_quantized_down_by_fixedpoint_saturated_uint8, lhs_offset,
      rhs_offset, quantize_down_by_fixedpoint_and_saturating_cast_pipeline);

  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t quantized = result_quantized_down_by_fixedpoint_int32(r, c);
      std::uint8_t expected = std::min(std::max(quantized, 0), 255);
      Check(expected ==
            result_quantized_down_by_fixedpoint_saturated_uint8(r, c));
    }
  }

  printf("TestOutputStages: PASS with ResultOrder=%s\n",
         OrderName(ResultOrder));
}

#ifndef GEMMLOWP_SKIP_EXHAUSTIVE_TESTS
template <typename BitDepthParams>
void TestExhaustively() {
  GemmContext context;

  // Test the internal GEMM interfaces.
  test_gemm<SingleThreadGemmWrapper<DefaultKernel<BitDepthParams>,
                                    std::uint8_t, BitDepthParams>>(&context);

  test_gemm<MultiThreadGemmWrapper<DefaultKernel<BitDepthParams>,
                                   std::uint8_t, BitDepthParams>>(&context);

  // Test the public GEMM interfaces.
  test_gemm<PublicGemmWrapper<std::uint8_t, BitDepthParams>>(&context);

  // Test GEMV cases (internal interfaces).
  test_gemv<SingleThreadGemmWrapper<DefaultKernel<BitDepthParams>,
                                    std::uint8_t, BitDepthParams>>(&context);

  test_gemv<MultiThreadGemmWrapper<DefaultKernel<BitDepthParams>,
                                   std::uint8_t, BitDepthParams>>(&context);

  // Test GEMV cases (public interfaces).
  test_gemv<PublicGemmWrapper<std::uint8_t, BitDepthParams>>(&context);
}

template <eight_bit_int_gemm::BitDepthSetting BitDepthSetting>
void TestExhaustivelyEightBitIntGemm() {
  GemmContext context;
  test_gemv<EightBitIntGemmWrapper<std::uint8_t, BitDepthSetting>>(&context);
  test_gemm<EightBitIntGemmWrapper<std::uint8_t, BitDepthSetting>>(&context);
}

void TestKernels() {
  GemmContext context;

  // Test specific kernels with various different formats,
  // to exercise corner cases especially in the packing code.
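  // Roughly: CellFormat<Width, Depth, Order> describes one Width x Depth
  // cell of packed operand data (Order giving its storage order), and
  // KernelSideFormat<Cell, N> tiles N such cells along the width dimension
  // of one side of the kernel. These are gemmlowp's internal kernel-format
  // traits; see internal/kernel.h for the authoritative definitions.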
  test_gemm_kernel<
      ReferenceKernel<KernelFormat<KernelSideFormat<CellFormat<1, 1>, 1>,
                                   KernelSideFormat<CellFormat<1, 1>, 1>>>>(
      &context);

  test_gemm_kernel<
      ReferenceKernel<KernelFormat<KernelSideFormat<CellFormat<4, 2>, 1>,
                                   KernelSideFormat<CellFormat<4, 2>, 2>>>>(
      &context);

  test_gemm_kernel<
      ReferenceKernel<KernelFormat<KernelSideFormat<CellFormat<4, 2>, 4>,
                                   KernelSideFormat<CellFormat<4, 2>, 5>>>>(
      &context);

  test_gemm_kernel<ReferenceKernel<KernelFormat<
      KernelSideFormat<CellFormat<3, 4, CellOrder::DepthMajor>, 2>,
      KernelSideFormat<CellFormat<5, 4, CellOrder::DepthMajor>, 3>>>>(&context);

  test_gemm_kernel<ReferenceKernel<KernelFormat<
      KernelSideFormat<CellFormat<3, 4, CellOrder::WidthMajor>, 2>,
      KernelSideFormat<CellFormat<5, 4, CellOrder::WidthMajor>, 3>>>>(&context);

  test_gemm_kernel<ReferenceKernel<KernelFormat<
      KernelSideFormat<CellFormat<5, 2, CellOrder::WidthMajor>, 3>,
      KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2>>>>(&context);

  test_gemm_kernel<ReferenceKernel<KernelFormat<
      KernelSideFormat<CellFormat<5, 2, CellOrder::DepthMajor>, 3>,
      KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 2>>>>(&context);

  test_gemm_kernel<ReferenceKernel<KernelFormat<
      KernelSideFormat<CellFormat<8, 8, CellOrder::Diagonal>, 2>,
      KernelSideFormat<CellFormat<3, 8, CellOrder::WidthMajor>, 1>>>>(&context);

  test_gemm_kernel<ReferenceKernel<KernelFormat<
      KernelSideFormat<CellFormat<1, 4, CellOrder::DepthMajor>, 1>,
      KernelSideFormat<CellFormat<4, 4, CellOrder::Diagonal>, 1>>>>(&context);
}

#endif  // not GEMMLOWP_SKIP_EXHAUSTIVE_TESTS

template <typename BitDepthParams>
void TestOutputStages() {
  // Test non-default output pipelines with various combinations of
  // output stages.
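  // The arguments are (rows, depth, cols, result_offset, result_mult_int,
  // result_shift), following the parameter order of the TestOutputStages
  // overload defined earlier in this file.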
  TestOutputStages<BitDepthParams, MapOrder::RowMajor>(63, 10, 127, 5, 17, 14);
  TestOutputStages<BitDepthParams, MapOrder::ColMajor>(63, 10, 127, 5, 17, 14);
  TestOutputStages<BitDepthParams, MapOrder::RowMajor>(630, 10, 1270, 5, 17,
                                                       14);
  TestOutputStages<BitDepthParams, MapOrder::ColMajor>(630, 10, 1270, 5, 17,
                                                       14);
}

void test() {
#ifdef GEMMLOWP_TEST_PROFILE
  RegisterCurrentThreadForProfiling();
  StartProfiling();
#endif

  // Run a first quick test against hand-calculated data.
  TestWithSmallData();

#ifndef GEMMLOWP_SKIP_EXHAUSTIVE_TESTS
  TestExhaustively<DefaultL8R8BitDepthParams>();
  TestExhaustively<L8R8WithLhsNonzeroBitDepthParams>();
  TestExhaustively<DefaultL7R5BitDepthParams>();  // legacy, same as L8R8
  TestExhaustivelyEightBitIntGemm<eight_bit_int_gemm::BitDepthSetting::A8B8>();
  TestExhaustivelyEightBitIntGemm<eight_bit_int_gemm::BitDepthSetting::A5B7>();
  TestKernels();
#endif

  // Run against actual data from a network evaluation.
  TestWithRealData(eight_bit_int_gemm::BitDepthSetting::A8B8, 0, 0);
  TestWithRealData(eight_bit_int_gemm::BitDepthSetting::A5B7, 2, 10);

  // Test non-default output pipelines with various combinations of
  // output stages.
  TestOutputStages<DefaultL8R8BitDepthParams>();
  TestOutputStages<L8R8WithLhsNonzeroBitDepthParams>();

  // Test per-channel quantization.
  TestWithSmallDataPerChannelQuantization();
  TestWithLargeDataPerChannelQuantization();
  TestMultithreadedPerChannelQuantization();
#ifdef GEMMLOWP_TEST_PROFILE
  FinishProfiling();
#endif

  std::cerr << "All tests passed." << std::endl;

  // We have been testing the eight_bit_int_gemm, so we should free its
  // persistent resources now to avoid having leak-checking tools report
  // leaks.
  eight_bit_int_gemm::FreePersistentResources();
}

}  // end namespace gemmlowp

// For iOS, we need to define our own main(), so skip it here.
#if !(defined(__APPLE__) && (TARGET_OS_IPHONE || TARGET_IPHONE_SIMULATOR))
int main() { gemmlowp::test(); }
#endif