// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "test.h"

#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <limits>
#include <memory>
#include <random>
#include <string>
#include <tuple>
#include <vector>
#ifdef __APPLE__
#include <TargetConditionals.h>
#endif

#include "../eight_bit_int_gemm/eight_bit_int_gemm.h"
#include "../internal/kernel_reference.h"
#include "test_data.h"

namespace gemmlowp {
void ReferenceEightBitIntGemm(bool transpose_a, bool transpose_b,
                              bool transpose_c, int m, int n, int k,
                              const std::uint8_t* a, std::int32_t a_offset,
                              int lda, const std::uint8_t* b,
                              std::int32_t b_offset, int ldb, std::uint8_t* c,
                              std::int32_t c_offset, std::int32_t c_mult_int,
                              std::int32_t c_shift, int ldc) {
  ScopedProfilingLabel("ReferenceEightBitIntGemm");
  assert((c_shift >= 0) && (c_shift <= 32));

  assert(a != nullptr);
  assert(b != nullptr);
  assert(c != nullptr);

  int a_i_stride;
  int a_l_stride;
  if (transpose_a) {
    a_i_stride = lda;
    a_l_stride = 1;
  } else {
    a_i_stride = 1;
    a_l_stride = lda;
  }
  int b_j_stride;
  int b_l_stride;
  if (transpose_b) {
    b_j_stride = 1;
    b_l_stride = ldb;
  } else {
    b_j_stride = ldb;
    b_l_stride = 1;
  }
  int c_i_stride;
  int c_j_stride;
  if (transpose_c) {
    c_i_stride = ldc;
    c_j_stride = 1;
  } else {
    c_i_stride = 1;
    c_j_stride = ldc;
  }
  int i, j, l;

  const std::int32_t kRoundingTerm = (c_shift < 1) ? 0 : (1 << (c_shift - 1));
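  // Illustrative example (hypothetical values): with c_shift == 4 the rounding
  // term is 1 << 3 == 8, so a scaled value of 23 becomes (23 + 8) >> 4 == 1
  // while 25 becomes (25 + 8) >> 4 == 2, i.e. the shift below rounds to
  // nearest rather than truncating toward zero.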

  for (j = 0; j < n; j++) {
    for (i = 0; i < m; i++) {
      std::int32_t total = 0;
      for (l = 0; l < k; l++) {
        const int a_index = i * a_i_stride + l * a_l_stride;
        const std::uint8_t a_as_byte = a[a_index];
        const std::int32_t a_as_int =
            static_cast<std::int32_t>(a_as_byte) + a_offset;
        const int b_index = j * b_j_stride + l * b_l_stride;
        const std::uint8_t b_as_byte = b[b_index];
        const std::int32_t b_as_int =
            static_cast<std::int32_t>(b_as_byte) + b_offset;
        const std::int32_t mult_as_int = a_as_int * b_as_int;
        total += mult_as_int;
      }
      std::int32_t output =
          (((total + c_offset) * c_mult_int) + kRoundingTerm) >> c_shift;
      if (output > 255) {
        output = 255;
      }
      if (output < 0) {
        output = 0;
      }
      const int c_index = i * c_i_stride + j * c_j_stride;
      c[c_index] = static_cast<std::uint8_t>(output);
    }
  }
}
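
// As a sanity check on the formula above (illustrative values, not used by any
// test): for a 1x1x1 problem with a = 2, a_offset = 1, b = 3, b_offset = 2,
// c_offset = 0, c_mult_int = 1 and c_shift = 0, the accumulator is
// (2 + 1) * (3 + 2) = 15 and the stored result is 15.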

typedef VectorMap<const std::int32_t, VectorShape::Col> OffsetColMap;
typedef VectorMap<const std::int32_t, VectorShape::Row> OffsetRowMap;
typedef VectorDup<const std::int32_t, VectorShape::Col> OffsetColDup;
typedef VectorDup<const std::int32_t, VectorShape::Row> OffsetRowDup;
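// (A VectorMap wraps actual per-row / per-column offset data, while a
// VectorDup represents a vector whose entries all share one constant value,
// as used for the uniform-offset cases below.)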

// The *GemmWrapper classes wrap various GEMM functions in a uniform
// interface, so that the same testing code can exercise all of them.

template <typename Kernel, typename Scalar, typename tBitDepthParams>
struct SingleThreadGemmWrapper {
  typedef tBitDepthParams BitDepthParams;

  static const char* Name() {
    static char buf[256];
    snprintf(buf, sizeof(buf), "SingleThreadGemm, Kernel: %s", Kernel().Name());
    return buf;
  }

  typedef SingleThreadGemmContext Context;

  template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
  static bool Gemm(Context* context,
                   const MatrixMap<const Scalar, LhsOrder>& lhs,
                   const MatrixMap<const Scalar, RhsOrder>& rhs,
                   MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
                   int rhs_offset, int result_offset, int result_mult_int,
                   int result_shift) {
    ScopedProfilingLabel("SingleThreadGemmWrapper::Gemm");
    const int rows = lhs.rows();
    const int cols = rhs.cols();
    if (rows < cols) {
      // SingleThreadGemm is never called with rows < cols.
      // That case is handled earlier.
      return false;
    }
    const OffsetColDup lhs_offset_vector(lhs_offset, rows);
    const OffsetRowDup rhs_offset_vector(rhs_offset, cols);
    SingleThreadGemm<typename Kernel::Format, Scalar, Scalar, BitDepthParams,
                     LhsOrder, RhsOrder, ResultOrder, OffsetColDup,
                     OffsetRowDup>(
        context, Kernel(), lhs, rhs, result, lhs_offset_vector,
        rhs_offset_vector,
        MakeStandardOutputPipeline(result_offset, result_mult_int,
                                   result_shift));
    return true;
  }
};

template <typename Kernel, typename Scalar, typename tBitDepthParams>
struct MultiThreadGemmWrapper {
  typedef tBitDepthParams BitDepthParams;

  static const char* Name() {
    static char buf[256];
    snprintf(buf, sizeof(buf), "MultiThreadGemm, Kernel: %s", Kernel().Name());
    return buf;
  }

  typedef MultiThreadGemmContext Context;

  template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
  static bool Gemm(Context* context,
                   const MatrixMap<const Scalar, LhsOrder>& lhs,
                   const MatrixMap<const Scalar, RhsOrder>& rhs,
                   MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
                   int rhs_offset, int result_offset, int result_mult_int,
                   int result_shift) {
    ScopedProfilingLabel("MultiThreadGemmWrapper::Gemm");
    context->set_max_num_threads(0);
    const int rows = lhs.rows();
    const int cols = rhs.cols();
    if (rows < cols) {
      // MultiThreadGemm is never called with rows < cols.
      // That case is handled earlier.
      return false;
    }
    const OffsetColDup lhs_offset_vector(lhs_offset, rows);
    const OffsetRowDup rhs_offset_vector(rhs_offset, cols);
    MultiThreadGemm<typename Kernel::Format, Scalar, Scalar, BitDepthParams,
                    LhsOrder, RhsOrder, ResultOrder, OffsetColDup,
                    OffsetRowDup>(
        context, Kernel(), lhs, rhs, result, lhs_offset_vector,
        rhs_offset_vector,
        MakeStandardOutputPipeline(result_offset, result_mult_int,
                                   result_shift));
    return true;
  }
};

template <typename Scalar, typename tBitDepthParams>
struct PublicGemmWrapper {
  typedef tBitDepthParams BitDepthParams;

  static const char* Name() { return "public Gemm"; }

  typedef GemmContext Context;

  template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
  static bool Gemm(Context* context,
                   const MatrixMap<const Scalar, LhsOrder>& lhs,
                   const MatrixMap<const Scalar, RhsOrder>& rhs,
                   MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
                   int rhs_offset, int result_offset, int result_mult_int,
                   int result_shift) {
    ScopedProfilingLabel("PublicGemmWrapper::Gemm");
    gemmlowp::Gemm<std::uint8_t, BitDepthParams, LhsOrder, RhsOrder,
                   ResultOrder>(context, lhs, rhs, result, lhs_offset,
                                rhs_offset, result_offset, result_mult_int,
                                result_shift);
    return true;
  }
};

template <eight_bit_int_gemm::BitDepthSetting BitDepth>
struct BitDepthParamsForSettings {};

template <>
struct BitDepthParamsForSettings<eight_bit_int_gemm::BitDepthSetting::A8B8>
    : DefaultL8R8BitDepthParams {};

template <>
struct BitDepthParamsForSettings<eight_bit_int_gemm::BitDepthSetting::A5B7>
    : DefaultL7R5BitDepthParams {};

template <typename Scalar, eight_bit_int_gemm::BitDepthSetting BitDepth>
struct EightBitIntGemmWrapper {
  typedef BitDepthParamsForSettings<BitDepth> BitDepthParams;

  static const char* Name() { return "EightBitIntGemm"; }

  typedef void Context;

  template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
  static bool Gemm(Context*, const MatrixMap<const Scalar, LhsOrder>& lhs,
                   const MatrixMap<const Scalar, RhsOrder>& rhs,
                   MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
                   int rhs_offset, int result_offset, int result_mult_int,
                   int result_shift) {
    ScopedProfilingLabel("EightBitIntGemmWrapper::Gemm");
    const bool transpose_c = ResultOrder == MapOrder::RowMajor;
    const bool transpose_a = LhsOrder == MapOrder::RowMajor;
    const bool transpose_b = RhsOrder == MapOrder::RowMajor;
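    // A row-major map is handed to EightBitIntGemm as the "transposed" flavor
    // of a column-major matrix, which is why each storage order maps directly
    // to a transpose flag here.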
    eight_bit_int_gemm::EightBitIntGemm(
        transpose_a, transpose_b, transpose_c, lhs.rows(), rhs.cols(),
        lhs.cols(), lhs.data(), lhs_offset, lhs.stride(), rhs.data(),
        rhs_offset, rhs.stride(), result->data(), result_offset,
        result_mult_int, result_shift, result->stride(), BitDepth);
    return true;
  }
};

template <typename Scalar>
struct ReferenceEightBitIntGemmWrapper {
  typedef DefaultL8R8BitDepthParams BitDepthParams;

  static const char* Name() { return "ReferenceEightBitIntGemm"; }

  template <MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
  static bool Gemm(bool transpose_a, bool transpose_b, bool transpose_c,
                   const MatrixMap<const Scalar, LhsOrder>& lhs,
                   const MatrixMap<const Scalar, RhsOrder>& rhs,
                   MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
                   int rhs_offset, int result_offset, int result_mult_int,
                   int result_shift) {
    ScopedProfilingLabel("ReferenceEightBitIntGemmWrapper::Gemm");
    ReferenceEightBitIntGemm(transpose_a, transpose_b, transpose_c, lhs.rows(),
                             rhs.cols(), lhs.cols(), lhs.data(), lhs_offset,
                             lhs.stride(), rhs.data(), rhs_offset, rhs.stride(),
                             result->data(), result_offset, result_mult_int,
                             result_shift, result->stride());
    return true;
  }
};

const char* OrderName(MapOrder order) {
  return order == MapOrder::ColMajor ? "ColMajor" : "RowMajor";
}

struct ResultStats {
  ResultStats()
      : count(0),
        med_val(0),
        mean_signed_diff(0),
        med_signed_diff(0),
        med_unsigned_diff(0),
        max_unsigned_diff(0) {}

  int count;
  int med_val;
  float mean_signed_diff;
  int med_signed_diff;
  int med_unsigned_diff;
  int max_unsigned_diff;

  std::vector<int> count_diff_by_pot_slice;
};

void GetResultStats(const std::uint8_t* actual, const std::uint8_t* expected,
                    size_t count, ResultStats* stats) {
  ScopedProfilingLabel("GetResultStats");
  std::vector<std::uint8_t> results;
  std::vector<std::int16_t> signed_diffs;
  std::vector<std::uint8_t> unsigned_diffs;
  std::int64_t signed_diffs_sum = 0;
  for (size_t i = 0; i < count; i++) {
    results.push_back(actual[i]);
    std::int16_t signed_diff = actual[i] - expected[i];
    signed_diffs.push_back(signed_diff);
    unsigned_diffs.push_back(std::abs(signed_diff));
    signed_diffs_sum += signed_diff;
  }

  std::sort(results.begin(), results.end());
  std::sort(signed_diffs.begin(), signed_diffs.end());
  std::sort(unsigned_diffs.begin(), unsigned_diffs.end());

  const size_t middle = count / 2;

  stats->count = count;
  stats->med_val = results[middle];
  stats->mean_signed_diff = float(signed_diffs_sum) / count;
  stats->med_signed_diff = signed_diffs[middle];
  stats->med_unsigned_diff = unsigned_diffs[middle];
  stats->max_unsigned_diff = unsigned_diffs.back();

  // Size 9 for 9 different POT values: 2^0, ..., 2^8
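  // Slice 0 counts exact matches (diff == 0), slice 1 counts diff == 1,
  // slice 2 counts diffs in [2 .. 3], and so on up to [128 .. 255], since
  // each lower_bound below cuts the sorted diffs at the next power of two.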
  stats->count_diff_by_pot_slice.resize(9);
  auto cur = unsigned_diffs.begin();
  size_t checksum = 0;
  for (int exponent = 0; exponent < 9; exponent++) {
    int pot = 1 << exponent;
    auto next = std::lower_bound(cur, unsigned_diffs.end(), pot);
    checksum += stats->count_diff_by_pot_slice[exponent] = next - cur;
    cur = next;
  }
  assert(checksum == count);
}

struct ResultStatsBounds {
  ResultStatsBounds()
      : mean_signed_diff(0),
        med_signed_diff(0),
        med_unsigned_diff(0),
        max_unsigned_diff(0) {}

  float mean_signed_diff;
  int med_signed_diff;
  int med_unsigned_diff;
  int max_unsigned_diff;
};

bool CheckResultStatsBounds(const ResultStats& stats,
                            const ResultStatsBounds& bounds) {
  return stats.max_unsigned_diff <= bounds.max_unsigned_diff &&
         stats.med_unsigned_diff <= bounds.med_unsigned_diff &&
         std::abs(stats.med_signed_diff) <= bounds.med_signed_diff &&
         std::abs(stats.mean_signed_diff) <= bounds.mean_signed_diff;
}
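
// Note that a default-constructed ResultStatsBounds is all zeros, so passing
// one to CheckResultStatsBounds demands bit-exact agreement with the
// reference result.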

void ReportResultStats(const ResultStats& stats,
                       const ResultStatsBounds& bounds) {
  printf(" number of matrix entries: %d\n", stats.count);
  printf(" median value: %d\n", stats.med_val);
  printf(" median unsigned diff: %d (tolerating %d)\n",
         stats.med_unsigned_diff, bounds.med_unsigned_diff);
  printf(" max unsigned diff: %d (tolerating %d)\n", stats.max_unsigned_diff,
         bounds.max_unsigned_diff);
  printf(" median signed diff: %d (tolerating %d)\n", stats.med_signed_diff,
         bounds.med_signed_diff);
  printf(" mean signed diff: %.3g (tolerating %.3g)\n",
         stats.mean_signed_diff, bounds.mean_signed_diff);

  printf("No error: %.2f %% of entries\n",
         100.f * stats.count_diff_by_pot_slice[0] / stats.count);
  for (int exponent = 1; exponent < 9; exponent++) {
    printf("Error in %d..%d range: %.2f %% of entries\n", 1 << (exponent - 1),
           (1 << exponent) - 1,
           100.f * stats.count_diff_by_pot_slice[exponent] / stats.count);
  }
}

// Our approach to choosing result_shift values for testing is bisection.
// This function takes an interval, [result_shift_min .. result_shift_max].
// If too much saturation occurred in either direction, it bisects accordingly,
// recursing until the interval contains only one value.
// The primary reason why we prefer this over computing optimal shift values
// is that we actually want to exercise some saturation, as there is nontrivial
// code handling that in gemmlowp.
// Secondarily, this is faster than computing optimal shifts, since in 90% of
// cases the first-tried shift value 16 turns out to be good enough.
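// For example, starting from the interval [0 .. 32], the first shift tried is
// (0 + 32) / 2 == 16. If the median output value then comes out below 32 (too
// much downscaling), the recursion continues on [0 .. 16]; if it comes out
// above 224 (too much saturation toward 255), it continues on [16 .. 32].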
template <typename GemmWrapper, typename LhsType, typename RhsType,
          typename ResultType>
void test_gemm_impl(typename GemmWrapper::Context* context, const LhsType& lhs,
                    const RhsType& rhs, ResultType* result, int lhs_offset,
                    int rhs_offset, int result_offset, int result_mult_int,
                    int result_shift_min, int result_shift_max) {
  const int rows = lhs.rows();
  const int cols = rhs.cols();
  Check(lhs.cols() == rhs.rows());
  const int depth = lhs.cols();

  const int result_shift = (result_shift_min + result_shift_max) / 2;

  if (!GemmWrapper::Gemm(context, lhs.const_map(), rhs.const_map(),
                         &result->map(), lhs_offset, rhs_offset, result_offset,
                         result_mult_int, result_shift)) {
    // Internal GEMM functions are not required to handle all cases
    // (e.g. rows < cols) as these are supposed to have been handled
    // ahead of them. Their test wrappers return false in that case.
    return;
  }

  typedef typename ResultType::Scalar Scalar;
  static const MapOrder kLhsOrder = LhsType::kOrder;
  static const MapOrder kRhsOrder = RhsType::kOrder;
  static const MapOrder kResultOrder = ResultType::kOrder;
  ResultType ref_result(rows, cols);
  const bool transpose_c = kResultOrder == MapOrder::RowMajor;
  const bool transpose_a = kLhsOrder == MapOrder::RowMajor;
  const bool transpose_b = kRhsOrder == MapOrder::RowMajor;
  ReferenceEightBitIntGemmWrapper<Scalar>::Gemm(
      transpose_a, transpose_b, transpose_c, lhs.const_map(), rhs.const_map(),
      &ref_result.map(), lhs_offset, rhs_offset, result_offset, result_mult_int,
      result_shift);

  typedef typename GemmWrapper::BitDepthParams BitDepthParams;

  ResultStats stats;
  GetResultStats(result->data(), ref_result.data(), rows * cols, &stats);

  // Adjust shifts until we get meaningful results
  int new_result_shift_min = result_shift_min;
  int new_result_shift_max = result_shift_max;
  bool retry = false;

  if (stats.med_val < 32) {
    new_result_shift_max = (result_shift_min + result_shift_max) / 2;
    retry = true;
  }

  if (stats.med_val > 224) {
    new_result_shift_min = (result_shift_min + result_shift_max) / 2;
    retry = true;
  }

  if (retry) {
    if (result_shift_min != result_shift_max) {
      test_gemm_impl<GemmWrapper>(context, lhs, rhs, result, lhs_offset,
                                  rhs_offset, result_offset, result_mult_int,
                                  new_result_shift_min, new_result_shift_max);
    }
    return;
  }

  ResultStatsBounds bounds;

  // Check results
  const bool good = CheckResultStatsBounds(stats, bounds);

  printf(
      "%s: %dx%dx%d %s x %s -> %s, %s, offsets %d/%d/%d, mult %d, shift %d\n",
      good ? "PASS" : "FAIL", rows, depth, cols, OrderName(kLhsOrder),
      OrderName(kRhsOrder), OrderName(kResultOrder), GemmWrapper::Name(),
      lhs_offset, rhs_offset, result_offset, result_mult_int, result_shift);

  if (!good) {
    ReportResultStats(stats, bounds);

    int bad_coeffs_printed = 0;
    for (int c = 0; c < result->cols() && bad_coeffs_printed < 200; c++) {
      for (int r = 0; r < result->rows() && bad_coeffs_printed < 200; r++) {
        if (ref_result(r, c) != (*result)(r, c)) {
          printf("bad coeff: at (%d, %d), expected %d, got %d\n", r, c,
                 ref_result(r, c), (*result)(r, c));
          bad_coeffs_printed++;
        }
      }
    }
  }

  Check(good);
}

template <typename GemmWrapper, typename LhsType, typename RhsType,
          typename ResultType>
void test_gemm(typename GemmWrapper::Context* context, const LhsType& lhs,
               const RhsType& rhs, ResultType* result, int lhs_offset,
               int rhs_offset, int result_offset, int result_mult_int) {
  test_gemm_impl<GemmWrapper>(context, lhs, rhs, result, lhs_offset, rhs_offset,
                              result_offset, result_mult_int, 0, 32);
}

enum class WhatParamsToTest {
  All,
  OnlyGenericCase,
};

template <typename GemmWrapper, MapOrder LhsOrder, MapOrder RhsOrder,
          MapOrder ResultOrder>
void test_gemm(typename GemmWrapper::Context* context, int rows, int depth,
               int cols, WhatParamsToTest params_to_test) {
  typedef std::uint8_t Scalar;
  typedef Matrix<Scalar, LhsOrder> LhsType;
  using BitDepthParams = typename GemmWrapper::BitDepthParams;
  LhsType lhs(rows, depth);
  MakeRandom<typename BitDepthParams::LhsRange>(&lhs);
  typedef Matrix<Scalar, RhsOrder> RhsType;
  RhsType rhs(depth, cols);
  MakeRandom<typename BitDepthParams::RhsRange>(&rhs);
  typedef Matrix<Scalar, ResultOrder> ResultType;
  ResultType result(rows, cols);
  MakeZero(&result);

  if (params_to_test == WhatParamsToTest::All) {
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 0, 0, 1);
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 10, 0, 0, 1);
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 10, 0, 1);
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 0, 10, 1);
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 0, 0, 0, 10);
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 10, 10, 10, 10);
    test_gemm<GemmWrapper>(context, lhs, rhs, &result, 256, 1, 17, 4);
  }
  test_gemm<GemmWrapper>(context, lhs, rhs, &result, -75, -91, 74980, 123);
}

enum class WhatOrdersToTest { All, OnlyRCC };

template <typename GemmWrapper>
void test_gemm(typename GemmWrapper::Context* context, int rows, int depth,
               int cols, WhatParamsToTest params_to_test,
               WhatOrdersToTest orders_to_test) {
#define GEMMLOWP_ONE_TEST(LhsOrder, RhsOrder, ResultOrder)         \
  do {                                                             \
    test_gemm<GemmWrapper, MapOrder::LhsOrder, MapOrder::RhsOrder, \
              MapOrder::ResultOrder>(context, rows, depth, cols,   \
                                     params_to_test);              \
  } while (false)

  if (orders_to_test == WhatOrdersToTest::All) {
    GEMMLOWP_ONE_TEST(ColMajor, ColMajor, ColMajor);
    GEMMLOWP_ONE_TEST(RowMajor, ColMajor, ColMajor);
    GEMMLOWP_ONE_TEST(ColMajor, RowMajor, ColMajor);
    GEMMLOWP_ONE_TEST(RowMajor, RowMajor, ColMajor);

    GEMMLOWP_ONE_TEST(ColMajor, ColMajor, RowMajor);
    GEMMLOWP_ONE_TEST(RowMajor, ColMajor, RowMajor);
    GEMMLOWP_ONE_TEST(ColMajor, RowMajor, RowMajor);
    GEMMLOWP_ONE_TEST(RowMajor, RowMajor, RowMajor);
  } else {
    GEMMLOWP_ONE_TEST(RowMajor, ColMajor, ColMajor);
  }

#undef GEMMLOWP_ONE_TEST
}

template <typename Kernel>
void test_gemm_kernel(MultiThreadGemmContext* context) {
  typedef MultiThreadGemmWrapper<Kernel, std::uint8_t,
                                 DefaultL8R8BitDepthParams>
      GemmWrapper;
  test_gemm<GemmWrapper>(context, 1, 1, 1, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 2, 2, 2, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 3, 3, 3, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 4, 4, 4, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 5, 5, 5, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 9, 11, 13, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 50, 50, 50, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 200, 200, 200,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::All);
  test_gemm<GemmWrapper>(context, 50, 5000, 50,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
}

template <typename GemmWrapper>
void test_gemm(typename GemmWrapper::Context* context) {
  test_gemm<GemmWrapper>(context, 1, 1, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 2, 1, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1, 2, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1, 1, 2, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 2, 2, 2, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 3, 3, 3, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 4, 4, 4, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 5, 5, 5, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 6, 6, 6, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 3, 5, 7, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 7, 3, 5, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 5, 7, 3, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 8, 8, 8, WhatParamsToTest::All,
                         WhatOrdersToTest::All);
  test_gemm<GemmWrapper>(context, 16, 16, 16, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 32, 32, 32, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 64, 64, 64, WhatParamsToTest::All,
                         WhatOrdersToTest::All);
  test_gemm<GemmWrapper>(context, 128, 128, 128, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);

  test_gemm<GemmWrapper>(context, 16, 17, 16, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 37, 55, 73, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 57, 87, 117, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 93, 83, 73, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 109, 89, 99, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 78, 101, 82, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);

  test_gemm<GemmWrapper>(context, 512, 512, 512,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1024, 1024, 1024,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 567, 2345, 123,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 100, 5000, 100,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1, 1, 1000, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1000, 1, 1, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1, 1000, 1, WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1, 1000, 1000,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1000, 1, 1000,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 1000, 1000, 1,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 777, 3456, 1,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 4567, 555, 1,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::OnlyRCC);

  // Test all storage orders
  test_gemm<GemmWrapper>(context, 70, 90, 110, WhatParamsToTest::All,
                         WhatOrdersToTest::All);
  test_gemm<GemmWrapper>(context, 300, 400, 500,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::All);
}

template <typename GemmWrapper>
void test_gemv(typename GemmWrapper::Context* context) {
  test_gemm<GemmWrapper>(context, 2, 2, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 3, 3, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 4, 4, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 5, 5, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 6, 6, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 3, 5, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 7, 3, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 5, 7, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 8, 8, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 32, 32, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 128, 128, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);
  test_gemm<GemmWrapper>(context, 321, 123, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::OnlyRCC);

  // Test all storage orders
  test_gemm<GemmWrapper>(context, 70, 90, 1, WhatParamsToTest::All,
                         WhatOrdersToTest::All);
  test_gemm<GemmWrapper>(context, 300, 400, 1,
                         WhatParamsToTest::OnlyGenericCase,
                         WhatOrdersToTest::All);
}

const char* GetBitDepthName(eight_bit_int_gemm::BitDepthSetting b) {
  switch (b) {
    case eight_bit_int_gemm::BitDepthSetting::A8B8:
      return "Lhs: 8 bit, Rhs: 8 bit";
    case eight_bit_int_gemm::BitDepthSetting::A5B7:
      return "(legacy, no longer requantizing) Lhs: 7 bit, Rhs: 5 bit";
    default:
      abort();
      return nullptr;
  }
}

// Runs a small set of hand-picked data for per-channel quantized data.
// This test case comes from a set of 2 2x2 convolution filters run over a 3x3
// image.
void TestWithSmallDataPerChannelQuantization() {
  const int m = 2;
  const int n = 9;
  const int k = 12;

  // 12 x 2, columnwise.
  const std::uint8_t a_data[] = {0,  0,  0,  0,  0,  0,  0,   0,
                                 0,  255, 255, 255, 64, 64, 64, 64,
                                 64, 64, 0,  0,  0,  255, 255, 255};
  const int lda = k;
  int a_offset[] = {0, -64};
  MatrixMap<const std::uint8_t, MapOrder::RowMajor> lhs(a_data, m, k, lda);
  const OffsetColMap lhs_offset(a_offset, m);

  // 12 x 9, columnwise.
  const std::uint8_t b_data[] = {
      0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,
      0,   0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,   127,
      127, 127, 0,   0,   0,   127, 127, 127, 0,   0,   0,   255, 255, 255,
      0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   127, 127, 127, 0,   0,   0,   127,
      127, 127, 0,   0,   0,   0,   0,   0,   127, 127, 127, 127, 127, 127,
      0,   0,   0,   0,   0,   0,   127, 127, 127, 127, 127, 127, 0,   0,
      0,   127, 127, 127, 127, 127, 127, 127, 127, 127};
  const int ldb = k;
  int b_offset = -127;
  MatrixMap<const std::uint8_t, MapOrder::ColMajor> rhs(b_data, k, n, ldb);
  const OffsetRowDup rhs_offset(b_offset, rhs.cols());

  // 2 x 9, columnwise.
  const std::uint8_t expected_c_data[] = {255, 255, 0,   0,   127, 159,
                                          0,   64,  0,   64,  127, 159,
                                          127, 127, 127, 127, 127, 127};
  const int ldc = m;
  int c_offset[] = {97155, 97346};
  int c_mult_int[] = {2741, 2741};
  const int c_shift = 21;

  const int c_count = m * n;
  std::unique_ptr<std::uint8_t[]> output_data(new std::uint8_t[c_count]);
  MatrixMap<std::uint8_t, MapOrder::ColMajor> result(output_data.get(), m, n,
                                                     ldc);
  const OffsetColMap result_offset(c_offset, m);
  const OffsetColMap result_mult_int(c_mult_int, m);
  const int result_shift = c_shift;

  GemmContext gemm_context;
  auto output_pipeline = MakeStandardOutputPipeline<VectorShape::Col>(
      result_offset, result_mult_int, result_shift);
  GemmWithOutputPipelinePC<std::uint8_t, std::uint8_t,
                           DefaultL8R8BitDepthParams>(
      &gemm_context, lhs, rhs, &result, lhs_offset, rhs_offset,
      output_pipeline);

  ResultStats stats;
  GetResultStats(output_data.get(), expected_c_data, c_count, &stats);

  ResultStatsBounds bounds;
  const bool good = CheckResultStatsBounds(stats, bounds);
  printf("TestWithSmallDataPerChannelQuantization: %s\n",
         good ? "PASS" : "FAIL");
  ReportResultStats(stats, bounds);
  Check(good);
}

// Runs a larger set of hand-picked data for per-channel quantized data.
// This test case comes from a set of 22 3x3 convolution filters run over a 5x5
// image. Right now, I have 7 different filters and 15 copies of the first
// filter to make sure the NEON code path that processes 16 rows at a time is
// covered.
void TestWithLargeDataPerChannelQuantization() {
  const int m = 22;
  const int n = 25;
  const int k = 27;

  // 27 x 22, column-wise.
  const std::uint8_t a_data[] = {
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   127, 127, 127, 255, 255, 255, 127, 127, 127,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   127, 127, 127,
      0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,
      127, 127, 127, 0,   0,   0,   51,  51,  51,  51,  51,  51,  51,  51,  51,
      0,   0,   0,   255, 255, 255, 0,   0,   0,   51,  51,  51,  51,  51,  51,
      51,  51,  51,  51,  51,  51,  0,   0,   0,   51,  51,  51,  51,  51,  51,
      255, 255, 255, 51,  51,  51,  51,  51,  51,  0,   0,   0,   51,  51,  51,
      0,   0,   0,   64,  64,  64,  0,   0,   0,   64,  64,  64,  255, 255, 255,
      64,  64,  64,  0,   0,   0,   64,  64,  64,  0,   0,   0,   36,  36,  36,
      0,   0,   0,   36,  36,  36,  0,   0,   0,   255, 255, 255, 0,   0,   0,
      36,  36,  36,  0,   0,   0,   36,  36,  36,  0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      255, 255, 255, 0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      255, 255, 255, 0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   255, 255, 255, 0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      255, 255, 255, 0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,   255, 255, 255, 0,   0,   0,
      0,   0,   0,   0,   0,   0,   0,   0,   0,
  };
  const int lda = k;
  int a_offset[] = {0, 0, 0, -51, -51, 0, -36, 0, 0, 0, 0,
                    0, 0, 0, 0,   0,   0, 0,   0, 0, 0, 0};
  MatrixMap<const std::uint8_t, MapOrder::RowMajor> lhs(a_data, m, k, lda);
  const OffsetColMap lhs_offset(a_offset, m);

  // 27 x 25, column-wise.
  const std::uint8_t b_data[] = {
      127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 119, 119,
      119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127,
      127, 127, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127,
      127, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127,
      127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127, 127,
      127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127,
      119, 119, 119, 119, 119, 119, 127, 127, 127, 127, 127, 127, 119, 119,
      119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127,
      127, 127, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 136, 136, 136, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      136, 136, 136, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 136, 136, 136, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127,
      119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119,
      119, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127,
      127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 136, 136, 136, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      136, 136, 136, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 136, 136, 136, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119,
      119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 127,
      127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119, 119, 119,
      119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 136, 136, 136, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      136, 136, 136, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 136, 136, 136, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119,
      119, 119, 119, 119, 119, 127, 127, 127, 127, 127, 127, 119, 119, 119,
      119, 119, 119, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127,
      127, 127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127,
      127, 127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127, 127,
      127, 127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 119, 119, 119,
      119, 119, 119, 119, 119, 119, 119, 119, 119, 127, 127, 127, 127, 127,
      127, 127, 127, 127, 119, 119, 119, 119, 119, 119, 127, 127, 127, 119,
      119, 119, 119, 119, 119, 127, 127, 127, 127, 127, 127, 127, 127, 127,
      127, 127, 127};
  const int ldb = k;
  int b_offset = -127;
  MatrixMap<const std::uint8_t, MapOrder::ColMajor> rhs(b_data, k, n, ldb);
  const OffsetRowDup rhs_offset(b_offset, rhs.cols());

  // 22 x 25, column-wise.
  const std::uint8_t expected_c_data[] = {
      7,   37,  37,  67,  67,  39,  79,  7,   7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   37,  87,  67,  23,  91,  7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   37,  87,  67,  23,  91,  7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   37,  87,  67,  23,  91,  7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   37,
      37,  67,  67,  39,  79,  7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   37,  7,   67,  87,  23,  91,  7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
      87,  87,  7,   103, 7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   71,  87,  45,  41,  77,  7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   87,
      87,  7,   103, 7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   37,  7,   67,  87,  23,  91,  7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   37,  7,   67,  87,
      23,  91,  7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   71,  7,   45,  87,  41,  77,  7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   255, 135, 135, 255, 255, 143,
      255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
      255, 7,   71,  7,   45,  87,  41,  77,  7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   37,  7,   67,  87,  23,  91,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   37,  7,   67,  87,  23,  91,  7,   7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   87,  87,  7,   103, 7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   71,  87,  45,  41,  77,  7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   87,  87,  7,   103, 7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   37,
      7,   67,  87,  23,  91,  7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   37,  37,  67,  67,  39,  79,  7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   37,
      87,  67,  23,  91,  7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   37,  87,  67,  23,  91,  7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   37,  87,
      67,  23,  91,  7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   7,
      7,   7,   7,   7,   37,  37,  67,  67,  39,  79,  7,   7,   7,   7,   7,
      7,   7,   7,   7,   7,   7,   7,   7,   7,   7,   99,  99,  99,  99,  99,
      99,  99,  99,  99,  99,  99,  99,  99,  99,  99,  99,  99,  99,  99,  99,
      99,  99,  111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111, 111,
      111, 111, 111, 111, 111, 111, 111, 111, 111,
  };
  const int ldc = m;
  int c_offset[] = {
      6477, 12954, 12954, 7793, 7793, 12954, 9282, 6477, 6477, 6477, 6477,
      6477, 6477,  6477,  6477, 6477, 6477,  6477, 6477, 6477, 6477, 6477,
  };
  int c_mult_int[] = {
      41121, 20560, 20560, 34267, 34267, 21937, 28784, 41121,
      41121, 41121, 41121, 41121, 41121, 41121, 41121, 41121,
      41121, 41121, 41121, 41121, 41121, 41121,
  };
  const int c_shift = 21;

  const int c_count = m * n;
  std::unique_ptr<std::uint8_t[]> output_data(new std::uint8_t[c_count]);
  MatrixMap<std::uint8_t, MapOrder::ColMajor> result(output_data.get(), m, n,
                                                     ldc);
  const OffsetColMap result_offset(c_offset, m);
  const OffsetColMap result_mult_int(c_mult_int, m);
  const int result_shift = c_shift;

  GemmContext gemm_context;
  auto output_pipeline = MakeStandardOutputPipeline<VectorShape::Col>(
      result_offset, result_mult_int, result_shift);
  GemmWithOutputPipelinePC<std::uint8_t, std::uint8_t,
                           DefaultL8R8BitDepthParams>(
      &gemm_context, lhs, rhs, &result, lhs_offset, rhs_offset,
      output_pipeline);

  ResultStats stats;
  GetResultStats(output_data.get(), expected_c_data, c_count, &stats);

  ResultStatsBounds bounds;
  const bool good = CheckResultStatsBounds(stats, bounds);
  printf("TestWithLargeDataPerChannelQuantization: %s\n",
         good ? "PASS" : "FAIL");
  ReportResultStats(stats, bounds);
  Check(good);
}

// Multithreading only activates when the result has more than 16 rows, and
// also (result rows) * (result cols) * depth >= 2 x 65 x 1024. Size was
// selected to run in 3 threads.
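// (With the sizes below: 64 rows > 16, and 64 * 20 * 160 = 204800, which is
// comfortably above 2 x 65 x 1024 = 133120.)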
//
// Based on the following floating point data:
// LHS: all zeros except 10.0, 20.0 at the beginning of first 16 rows;
// 1.0, 2.0 at the beginning of next 16 rows; 0.1, 0.2 in next 16 rows;
// 0.01, 0.02 in last 16 rows.
// RHS: all zeros except 1.0 in (0, 0) and 2.0 in (1, 0).
// Varying boundaries were used for each 16 rows block of LHS, to test for
// correct indexing into offsets.
// Expected result: all zeros, except 50.0 at the beginning of first 16 rows;
// 5.0 at the beginning of next 16 rows; 0.5 in next 16 rows; 0.05 in last
// 16 rows.
void TestMultithreadedPerChannelQuantization() {
  const int m = 64;
  const int n = 20;
  const int k = 160;

  // LHS, m x k.
  const std::array<std::int32_t, 4> lhs_offsets_terse{{
      0, -51, -85, -109,
  }};
  assert(lhs_offsets_terse.size() * 16 == m);
  const std::array<std::uint8_t, 4> lhs_first_el{{
      128, 153, 170, 182,
  }};
  assert(lhs_first_el.size() * 16 == m);

  // lhs_first_el at (i, 0) and 255 at (i, 1), other values are all -offset.
  std::vector<std::uint8_t> a_data(m * k, 0);
  for (int i = 0; i < m; ++i) {
    a_data[i * k] = lhs_first_el[i / 16];
    a_data[i * k + 1] = 255;
    for (int j = 2; j < k; ++j) {
      a_data[i * k + j] = std::uint8_t(-lhs_offsets_terse[i / 16]);
    }
  }

  const int lda = k;
  // Given values at [i / 16].
  std::vector<std::int32_t> a_offset(m, 0);
  for (int i = 0; i < m; ++i) {
    a_offset[i] = lhs_offsets_terse[i / 16];
  }

  MatrixMap<const std::uint8_t, MapOrder::RowMajor> lhs(&a_data[0], m, k, lda);
  const OffsetColMap lhs_offset(&a_offset[0], m);

  // RHS, k x n.
  // All zeros, except 128 at (0, 0) and 255 at (1, 0).
  std::vector<std::uint8_t> b_data(k * n, 0);
  b_data[0] = 128;
  b_data[1] = 255;

  const int ldb = k;
  std::int32_t b_offset = 0;
  MatrixMap<const std::uint8_t, MapOrder::ColMajor> rhs(&b_data[0], k, n, ldb);
  const OffsetRowDup rhs_offset(b_offset, rhs.cols());

  // Result, m x n.
  // All zeros, except given values at (i / 16, 0).
  const std::array<std::uint8_t, 4> expected_c_terse{{
      142, 159, 182, 213,
  }};
  assert(expected_c_terse.size() * 16 == m);
  std::vector<std::uint8_t> expected_c_data(m * n, 0);
  for (int i = 0; i < m; ++i) {
    expected_c_data[i] = expected_c_terse[i / 16];
  }

  const int ldc = m;
  // All zeros.
  std::vector<std::int32_t> c_offset(m, 0);
  // Given values at [i / 16].
  const std::array<std::int32_t, 4> c_mult_int_terse{{
      3655, 5140, 7049, 9595,
  }};
  assert(c_mult_int_terse.size() * 16 == m);
  std::vector<std::int32_t> c_mult_int(m);
  for (int i = 0; i < m; ++i) {
    c_mult_int[i] = c_mult_int_terse[i / 16];
  }

  const int c_shift = 21;

  const int c_count = m * n;
  std::unique_ptr<std::uint8_t[]> output_data(new std::uint8_t[c_count]);
  MatrixMap<std::uint8_t, MapOrder::ColMajor> result(output_data.get(), m, n,
                                                     ldc);
  const OffsetColMap result_offset(&c_offset[0], m);
  const OffsetColMap result_mult_int(&c_mult_int[0], m);
  const int result_shift = c_shift;

  GemmContext gemm_context;
  auto output_pipeline = MakeStandardOutputPipeline<VectorShape::Col>(
      result_offset, result_mult_int, result_shift);
  GemmWithOutputPipelinePC<std::uint8_t, std::uint8_t,
                           DefaultL8R8BitDepthParams>(
      &gemm_context, lhs, rhs, &result, lhs_offset, rhs_offset,
      output_pipeline);

  ResultStats stats;
  GetResultStats(output_data.get(), &expected_c_data[0], c_count, &stats);

  ResultStatsBounds bounds;
  const bool good = CheckResultStatsBounds(stats, bounds);
  printf("TestMultithreadedPerChannelQuantization: %s\n",
         good ? "PASS" : "FAIL");
  ReportResultStats(stats, bounds);
  Check(good);
}

// Runs a small set of hand-calculated data through the implementation.
void TestWithSmallData() {
  const int m = 4;
  const int n = 2;
  const int k = 3;
  // Matrix A (LHS) is:
  // |  7 | 10 | 13 | 16 |
  // |  8 | 11 | 14 | 17 |
  // |  9 | 12 | 15 | 18 |
  const std::uint8_t a_data[] = {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18};
  // Matrix B (RHS) is:
  // | 1 | 3 | 5 |
  // | 2 | 4 | 6 |
  const std::uint8_t b_data[] = {1, 2, 3, 4, 5, 6};
  // Here are the results we expect, from hand calculations:
  // (1 * 7) + (3 * 8) + (5 * 9) = 76
  // (2 * 7) + (4 * 8) + (6 * 9) = 100
  // (1 * 10) + (3 * 11) + (5 * 12) = 103
  // (2 * 10) + (4 * 11) + (6 * 12) = 136
  // (1 * 13) + (3 * 14) + (5 * 15) = 130
  // (2 * 13) + (4 * 14) + (6 * 15) = 172
  // (1 * 16) + (3 * 17) + (5 * 18) = 157
  // (2 * 16) + (4 * 17) + (6 * 18) = 208
  // That means matrix C should be:
  // |  76 | 103 | 130 | 157 |
  // | 100 | 136 | 172 | 208 |
  const std::uint8_t expected_data[] = {76, 100, 103, 136, 130, 172, 157, 208};

  const int c_count = m * n;
  std::unique_ptr<std::uint8_t[]> output_data(new std::uint8_t[c_count]);

  const bool is_a_transposed = true;
  const bool is_b_transposed = true;
  const bool is_c_transposed = true;
  const int lda = k;
  const int ldb = n;
  const int ldc = n;

  const int a_offset = 0;
  const int b_offset = 0;
  const int c_offset = 0;
  const int c_mult = 1;
  const int c_shift = 0;

  gemmlowp::eight_bit_int_gemm::EightBitIntGemm(
      is_a_transposed, is_b_transposed, is_c_transposed, m, n, k, a_data,
      a_offset, lda, b_data, b_offset, ldb, output_data.get(), c_offset, c_mult,
      c_shift, ldc, eight_bit_int_gemm::BitDepthSetting::A8B8);

  ResultStats stats;
  GetResultStats(output_data.get(), expected_data, c_count, &stats);

  ResultStatsBounds bounds;
  const bool good = CheckResultStatsBounds(stats, bounds);
  printf("TestWithSmallData: %s\n", good ? "PASS" : "FAIL");
  ReportResultStats(stats, bounds);
  Check(good);
}

// This is the most realistic test of how we'll be using the low-precision GEMM
// function in applications. It takes in large input matrices that have been
// captured from an actual neural network run.
void TestWithRealData(eight_bit_int_gemm::BitDepthSetting BitDepth,
                      int tolerance_median, int tolerance_max) {
  std::unique_ptr<std::uint8_t[]> output_data(
      new std::uint8_t[test_data::c_count]);
  gemmlowp::eight_bit_int_gemm::EightBitIntGemm(
      test_data::is_a_transposed, test_data::is_b_transposed,
      test_data::is_c_transposed, test_data::m, test_data::n, test_data::k,
      test_data::a_data, test_data::a_offset, test_data::k, test_data::b_data,
      test_data::b_offset, test_data::k, output_data.get(), test_data::c_offset,
      test_data::c_mult_int, test_data::c_shift, test_data::m, BitDepth);

  ResultStats stats;
  GetResultStats(output_data.get(), test_data::expected_c_data,
                 test_data::c_count, &stats);

  ResultStatsBounds bounds;
  if (BitDepth == eight_bit_int_gemm::BitDepthSetting::A5B7) {
    bounds.med_unsigned_diff = tolerance_median;
    bounds.max_unsigned_diff = tolerance_max;
    bounds.med_signed_diff = 0;
    bounds.mean_signed_diff = 0.2f;
  }

  const bool good = CheckResultStatsBounds(stats, bounds);
  printf("TestWithRealData: %s with %s\n", good ? "PASS" : "FAIL",
         GetBitDepthName(BitDepth));
  ReportResultStats(stats, bounds);
  Check(good);
}
1199
1200 template <typename BitDepthParams, MapOrder ResultOrder>
TestOutputStages(int rows,int depth,int cols,int result_offset,int result_mult_int,int result_shift)1201 void TestOutputStages(int rows, int depth, int cols, int result_offset,
1202 int result_mult_int, int result_shift) {
1203 Matrix<std::uint8_t, MapOrder::RowMajor> lhs(rows, depth);
1204 Matrix<std::uint8_t, MapOrder::ColMajor> rhs(depth, cols);
1205 Matrix<std::int32_t, ResultOrder> result_raw_int32(rows, cols);
1206 MakeRandom<typename BitDepthParams::LhsRange>(&lhs);
1207 MakeRandom<typename BitDepthParams::RhsRange>(&rhs);
1208 const int lhs_offset = 12;
1209 const int rhs_offset = -34;
1210
1211 // Test an empty pipeline, i.e. returning raw int32 accumulators.
1212 auto empty_pipeline = std::make_tuple();
1213 GemmContext context;
1214 GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1215 &context, lhs.const_map(), rhs.const_map(), &result_raw_int32, lhs_offset,
1216 rhs_offset, empty_pipeline);
1217
1218 for (int r = 0; r < rows; r++) {
1219 for (int c = 0; c < cols; c++) {
1220 std::int32_t expected = 0;
1221 for (int d = 0; d < depth; d++) {
1222 std::int32_t lhs_val =
1223 static_cast<std::int32_t>(lhs(r, d)) + lhs_offset;
1224 std::int32_t rhs_val =
1225 static_cast<std::int32_t>(rhs(d, c)) + rhs_offset;
1226 expected += lhs_val * rhs_val;
1227 }
1228 Check(expected == result_raw_int32(r, c));
1229 }
1230 }
1231
1232 // Test a pipeline with only the quantize-down stage, still returning
1233 // unclamped (but scaled) int32's
1234 OutputStageQuantizeDownInt32ToUint8Scale quantize_down_stage;
1235 quantize_down_stage.result_offset = result_offset;
1236 quantize_down_stage.result_mult_int = result_mult_int;
1237 quantize_down_stage.result_shift = result_shift;
1238 auto quantize_down_pipeline = std::make_tuple(quantize_down_stage);
1239 Matrix<std::int32_t, ResultOrder> result_quantized_down_int32(rows, cols);
1240 GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1241 &context, lhs.const_map(), rhs.const_map(), &result_quantized_down_int32,
1242 lhs_offset, rhs_offset, quantize_down_pipeline);
1243
1244 std::int64_t sum = 0;
1245 for (int r = 0; r < rows; r++) {
1246 for (int c = 0; c < cols; c++) {
1247 std::int32_t raw = result_raw_int32(r, c);
1248 std::int32_t expected = RoundingDivideByPOT(
1249 (raw + result_offset) * result_mult_int, result_shift);
1250 Check(expected == result_quantized_down_int32(r, c));
1251 sum += expected;
1252 }
1253 }
1254 std::int64_t avg = sum / (rows * cols);
1255 // Test that the average quantized-down value falls reasonably in the
1256 // middle of the [0..255] range. Otherwise, the multiplier / shift need to be
1257 // adjusted.
1258 Check(avg >= 64 && avg <= 192);
1259
1260 // Test the familiar default pipeline consisting of quantize-down and
1261 // clamp-and-cast-to-uint8.
1262 OutputStageSaturatingCastToUint8 saturating_cast_stage;
1263 auto quantize_down_and_saturating_cast_pipeline =
1264 std::make_tuple(quantize_down_stage, saturating_cast_stage);
1265 Matrix<std::uint8_t, ResultOrder> result_quantized_down_saturated_uint8(rows,
1266 cols);
1267 GemmWithOutputPipeline<std::uint8_t, std::uint8_t, DefaultL8R8BitDepthParams>(
1268 &context, lhs.const_map(), rhs.const_map(),
1269 &result_quantized_down_saturated_uint8, lhs_offset, rhs_offset,
1270 quantize_down_and_saturating_cast_pipeline);
1271
1272 for (int r = 0; r < rows; r++) {
1273 for (int c = 0; c < cols; c++) {
1274 std::int32_t quantized = result_quantized_down_int32(r, c);
1275 std::uint8_t expected = std::min(std::max(quantized, 0), 255);
1276 Check(expected == result_quantized_down_saturated_uint8(r, c));
1277 }
1278 }
1279
1280 // Test a bias-addition with row-vector
1281 std::vector<std::int32_t> row_vector_data(cols);
1282 std::uniform_int_distribution<std::int32_t> uniform_minus_500_plus_500(-500,
1283 500);
1284 for (int i = 0; i < cols; i++) {
1285 row_vector_data[i] = uniform_minus_500_plus_500(RandomEngine());
1286 }
1287 typedef VectorMap<std::int32_t, VectorShape::Row> RowVectorMap;
1288 RowVectorMap row_vector_map(row_vector_data.data(), cols);
1289 OutputStageBiasAddition<RowVectorMap> row_bias_addition_stage;
1290 row_bias_addition_stage.bias_vector = row_vector_map;
1291 auto row_bias_addition_pipeline = std::make_tuple(row_bias_addition_stage);
1292 Matrix<std::int32_t, ResultOrder> result_of_row_bias_addition(rows, cols);
1293 GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
1294 &context, lhs.const_map(), rhs.const_map(), &result_of_row_bias_addition,
1295 lhs_offset, rhs_offset, row_bias_addition_pipeline);
1296 for (int r = 0; r < rows; r++) {
1297 for (int c = 0; c < cols; c++) {
1298 std::int32_t expected = result_raw_int32(r, c) + row_vector_data[c];
1299 Check(expected == result_of_row_bias_addition(r, c));
1300 }
1301 }
1302
1303 // Test a bias-addition with column-vector
  std::vector<std::int32_t> col_vector_data(rows);
  for (int i = 0; i < rows; i++) {
    col_vector_data[i] = uniform_minus_500_plus_500(RandomEngine());
  }
  typedef VectorMap<std::int32_t, VectorShape::Col> ColVectorMap;
  ColVectorMap col_vector_map(col_vector_data.data(), rows);
  OutputStageBiasAddition<ColVectorMap> col_bias_addition_stage;
  col_bias_addition_stage.bias_vector = col_vector_map;
  auto col_bias_addition_pipeline = std::make_tuple(col_bias_addition_stage);
  Matrix<std::int32_t, ResultOrder> result_of_col_bias_addition(rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(), &result_of_col_bias_addition,
      lhs_offset, rhs_offset, col_bias_addition_pipeline);
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t expected = result_raw_int32(r, c) + col_vector_data[r];
      Check(expected == result_of_col_bias_addition(r, c));
    }
  }

  // Test a clamp
  OutputStageClamp clamp_stage;
  // Determine min and max of raw int32 accumulators
  std::int32_t raw_min = std::numeric_limits<std::int32_t>::max();
  std::int32_t raw_max = std::numeric_limits<std::int32_t>::min();
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      raw_min = std::min(raw_min, result_raw_int32(r, c));
      raw_max = std::max(raw_max, result_raw_int32(r, c));
    }
  }
  // Pick some interesting clamp min/max bounds
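  // Interpolating at 30% / 70% of the raw accumulator range guarantees that
  // some, but not all, values get clamped, so both code paths are exercised.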
  clamp_stage.min = static_cast<std::int32_t>(raw_min * 0.7 + raw_max * 0.3);
  clamp_stage.max = static_cast<std::int32_t>(raw_min * 0.3 + raw_max * 0.7);
  assert(raw_min <= clamp_stage.min && clamp_stage.min <= clamp_stage.max &&
         clamp_stage.max <= raw_max);
  auto clamp_pipeline = std::make_tuple(clamp_stage);
  Matrix<std::int32_t, ResultOrder> result_clamped(rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(), &result_clamped, lhs_offset,
      rhs_offset, clamp_pipeline);
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t raw = result_raw_int32(r, c);
      std::int32_t expected =
          std::min(std::max(raw, clamp_stage.min), clamp_stage.max);
      Check(expected == result_clamped(r, c));
    }
  }

  // Test tanh
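  // OutputStageTanh interprets each int32 value x as the real number
  //   (x - real_zero_as_int32) / real_amplitude_as_int32
  // and returns tanh of that real value under the same int32 encoding.
  // Choosing the amplitude as 1/16 of the raw range means the inputs span
  // roughly [-8, 8], reaching well into tanh's saturated regions.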
  OutputStageTanh tanh_stage;
  const std::int32_t real_zero_as_int32 = (raw_max + raw_min) / 2;
  const std::int32_t real_amplitude_as_int32 = (raw_max - raw_min) / 16;
  tanh_stage.real_zero_as_int32 = real_zero_as_int32;
  tanh_stage.real_amplitude_as_int32 = real_amplitude_as_int32;
  auto tanh_pipeline = std::make_tuple(tanh_stage);
  Matrix<std::int32_t, ResultOrder> result_tanh(rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(), &result_tanh, lhs_offset,
      rhs_offset, tanh_pipeline);
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t raw = result_raw_int32(r, c);
      double real_input =
          double(raw - real_zero_as_int32) / real_amplitude_as_int32;
      double expected = std::tanh(real_input);
      std::int32_t actual_int32 = result_tanh(r, c);
      double actual =
          double(actual_int32 - real_zero_as_int32) / real_amplitude_as_int32;
      Check(std::abs(expected - actual) < 2e-4);
    }
  }

  // Test a pipeline with bias and clamp
  auto bias_clamp_pipeline =
      std::make_tuple(col_bias_addition_stage, clamp_stage);
  Matrix<std::int32_t, ResultOrder> result_biased_clamped(rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(), &result_biased_clamped,
      lhs_offset, rhs_offset, bias_clamp_pipeline);
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t raw = result_raw_int32(r, c);
      std::int32_t biased = raw + col_vector_data[r];
      std::int32_t expected =
          std::min(std::max(biased, clamp_stage.min), clamp_stage.max);
      Check(expected == result_biased_clamped(r, c));
    }
  }

  // Test a full pipeline with bias, clamp, and quantization down to an
  // 8-bit result.
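  // Output stages are applied in tuple order: bias first, then clamp, then
  // quantize-down, then the saturating cast to uint8.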
  auto bias_clamp_quantize_cast_pipeline =
      std::make_tuple(col_bias_addition_stage, clamp_stage, quantize_down_stage,
                      saturating_cast_stage);
  Matrix<std::uint8_t, ResultOrder> result_biased_clamped_quantized_casted(
      rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::uint8_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(),
      &result_biased_clamped_quantized_casted, lhs_offset, rhs_offset,
      bias_clamp_quantize_cast_pipeline);
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t quantized = RoundingDivideByPOT(
          (result_biased_clamped(r, c) + result_offset) * result_mult_int,
          result_shift);
      std::uint8_t expected = std::min(std::max(quantized, 0), 255);
      Check(expected == result_biased_clamped_quantized_casted(r, c));
    }
  }

  // Test a pipeline with the fixed-point-multiplier variant stage for
  // quantizing down the 32-bit accumulators.
  //
  // First, figure out appropriate fixed-point multiplier and shift values.
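  // The fixed-point multiplier is an int32 in [2^30, 2^31), read as a Q0.31
  // fixed-point number in [0.5, 1). The loop below doubles the multiplier
  // and increments the shift in lockstep, preserving the invariant
  //   (multiplier / 2^31) * 2^-shift == result_mult_int * 2^-result_shift.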
  Check(result_mult_int > 0);
  Check(result_shift > 0);
  std::int32_t result_fixedpoint_multiplier = result_mult_int;
  std::int32_t result_fixedpoint_shift = result_shift - 31;
  while (result_fixedpoint_multiplier < (1 << 30)) {
    result_fixedpoint_multiplier <<= 1;
    result_fixedpoint_shift++;
  }
  Check(result_fixedpoint_shift >= 0);
  // Now test OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint.
  OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint
      quantize_down_by_fixedpoint_stage;
  quantize_down_by_fixedpoint_stage.result_offset_after_shift =
      static_cast<std::int32_t>(
          round(static_cast<double>(result_offset * result_mult_int) /
                (1 << result_shift)));
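  // result_offset_after_shift converts the old pre-multiply offset to this
  // stage's convention, where the offset is added after the multiply and
  // shift: offset * mult / 2^shift, rounded to the nearest integer.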
  quantize_down_by_fixedpoint_stage.result_fixedpoint_multiplier =
      result_fixedpoint_multiplier;
  quantize_down_by_fixedpoint_stage.result_shift = result_fixedpoint_shift;
  auto quantize_down_by_fixedpoint_pipeline =
      std::make_tuple(quantize_down_by_fixedpoint_stage);
  Matrix<std::int32_t, ResultOrder> result_quantized_down_by_fixedpoint_int32(
      rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::int32_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(),
      &result_quantized_down_by_fixedpoint_int32, lhs_offset, rhs_offset,
      quantize_down_by_fixedpoint_pipeline);

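  // SaturatingRoundingDoublingHighMul(a, b) returns the high 32 bits of
  // 2*a*b with rounding, saturating only the a == b == INT32_MIN case (the
  // semantics of ARM's VQRDMULH); RoundingDivideByPOT then divides by a
  // power of two, rounding to nearest.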
  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      const std::int32_t actual =
          result_quantized_down_by_fixedpoint_int32(r, c);
      const std::int32_t raw = result_raw_int32(r, c);
      const std::int32_t expected =
          quantize_down_by_fixedpoint_stage.result_offset_after_shift +
          RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
                                  raw, result_fixedpoint_multiplier),
                              result_fixedpoint_shift);
      Check(actual == expected);
    }
  }

  // Test the variant of the familiar default pipeline consisting of
  // quantize-down and clamp-and-cast-to-uint8, where we use fixed-point
  // multipliers for the downscaling.
  auto quantize_down_by_fixedpoint_and_saturating_cast_pipeline =
      std::make_tuple(quantize_down_by_fixedpoint_stage, saturating_cast_stage);
  Matrix<std::uint8_t, ResultOrder>
      result_quantized_down_by_fixedpoint_saturated_uint8(rows, cols);
  GemmWithOutputPipeline<std::uint8_t, std::uint8_t, DefaultL8R8BitDepthParams>(
      &context, lhs.const_map(), rhs.const_map(),
      &result_quantized_down_by_fixedpoint_saturated_uint8, lhs_offset,
      rhs_offset, quantize_down_by_fixedpoint_and_saturating_cast_pipeline);

  for (int r = 0; r < rows; r++) {
    for (int c = 0; c < cols; c++) {
      std::int32_t quantized = result_quantized_down_by_fixedpoint_int32(r, c);
      std::uint8_t expected = std::min(std::max(quantized, 0), 255);
      Check(expected ==
            result_quantized_down_by_fixedpoint_saturated_uint8(r, c));
    }
  }

  printf("TestOutputStages: PASS with ResultOrder=%s\n",
         OrderName(ResultOrder));
}

#ifndef GEMMLOWP_SKIP_EXHAUSTIVE_TESTS
template <typename BitDepthParams>
void TestExhaustively() {
  GemmContext context;

  // Test the internal GEMM interfaces
  test_gemm<SingleThreadGemmWrapper<DefaultKernel<BitDepthParams>,
                                    std::uint8_t, BitDepthParams>>(&context);

  test_gemm<MultiThreadGemmWrapper<DefaultKernel<BitDepthParams>,
                                   std::uint8_t, BitDepthParams>>(&context);

  // Test the public GEMM interfaces
  test_gemm<PublicGemmWrapper<std::uint8_t, BitDepthParams>>(&context);

  // Test GEMV cases (internal interfaces)
  test_gemv<SingleThreadGemmWrapper<DefaultKernel<BitDepthParams>,
                                    std::uint8_t, BitDepthParams>>(&context);

  test_gemv<MultiThreadGemmWrapper<DefaultKernel<BitDepthParams>,
                                   std::uint8_t, BitDepthParams>>(&context);

  // Test GEMV cases (public interfaces)
  test_gemv<PublicGemmWrapper<std::uint8_t, BitDepthParams>>(&context);
}

template <eight_bit_int_gemm::BitDepthSetting BitDepthSetting>
void TestExhaustivelyEightBitIntGemm() {
  GemmContext context;
  test_gemv<EightBitIntGemmWrapper<std::uint8_t, BitDepthSetting>>(&context);
  test_gemm<EightBitIntGemmWrapper<std::uint8_t, BitDepthSetting>>(&context);
}

void TestKernels() {
  GemmContext context;

  // Test specific kernels with various different formats,
  // to exercise corner cases, especially in the packing code.
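  // In each KernelFormat, a side is described by
  // KernelSideFormat<CellFormat, NumCells>, and CellFormat<Width, Depth,
  // Order> gives the dimensions and storage order of one cell. The reference
  // kernel accepts any format, so we can stress unusual shapes here.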
  test_gemm_kernel<
      ReferenceKernel<KernelFormat<KernelSideFormat<CellFormat<1, 1>, 1>,
                                   KernelSideFormat<CellFormat<1, 1>, 1>>>>(
      &context);

  test_gemm_kernel<
      ReferenceKernel<KernelFormat<KernelSideFormat<CellFormat<4, 2>, 1>,
                                   KernelSideFormat<CellFormat<4, 2>, 2>>>>(
      &context);

  test_gemm_kernel<
      ReferenceKernel<KernelFormat<KernelSideFormat<CellFormat<4, 2>, 4>,
                                   KernelSideFormat<CellFormat<4, 2>, 5>>>>(
      &context);

  test_gemm_kernel<ReferenceKernel<KernelFormat<
      KernelSideFormat<CellFormat<3, 4, CellOrder::DepthMajor>, 2>,
      KernelSideFormat<CellFormat<5, 4, CellOrder::DepthMajor>, 3>>>>(&context);

  test_gemm_kernel<ReferenceKernel<KernelFormat<
      KernelSideFormat<CellFormat<3, 4, CellOrder::WidthMajor>, 2>,
      KernelSideFormat<CellFormat<5, 4, CellOrder::WidthMajor>, 3>>>>(&context);

  test_gemm_kernel<ReferenceKernel<KernelFormat<
      KernelSideFormat<CellFormat<5, 2, CellOrder::WidthMajor>, 3>,
      KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2>>>>(&context);

  test_gemm_kernel<ReferenceKernel<KernelFormat<
      KernelSideFormat<CellFormat<5, 2, CellOrder::DepthMajor>, 3>,
      KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 2>>>>(&context);

  test_gemm_kernel<ReferenceKernel<KernelFormat<
      KernelSideFormat<CellFormat<8, 8, CellOrder::Diagonal>, 2>,
      KernelSideFormat<CellFormat<3, 8, CellOrder::WidthMajor>, 1>>>>(&context);

  test_gemm_kernel<ReferenceKernel<KernelFormat<
      KernelSideFormat<CellFormat<1, 4, CellOrder::DepthMajor>, 1>,
      KernelSideFormat<CellFormat<4, 4, CellOrder::Diagonal>, 1>>>>(&context);
}

#endif  // not GEMMLOWP_SKIP_EXHAUSTIVE_TESTS

template <typename BitDepthParams>
void TestOutputStages() {
  // Test non-default output pipelines with various combinations of
  // output stages.
  TestOutputStages<BitDepthParams, MapOrder::RowMajor>(63, 10, 127, 5, 17, 14);
  TestOutputStages<BitDepthParams, MapOrder::ColMajor>(63, 10, 127, 5, 17, 14);
  TestOutputStages<BitDepthParams, MapOrder::RowMajor>(630, 10, 1270, 5, 17,
                                                       14);
  TestOutputStages<BitDepthParams, MapOrder::ColMajor>(630, 10, 1270, 5, 17,
                                                       14);
}

void test() {
#ifdef GEMMLOWP_TEST_PROFILE
  RegisterCurrentThreadForProfiling();
  StartProfiling();
#endif

  // Run a first quick test against hand-calculated data.
  TestWithSmallData();

#ifndef GEMMLOWP_SKIP_EXHAUSTIVE_TESTS
  TestExhaustively<DefaultL8R8BitDepthParams>();
  TestExhaustively<L8R8WithLhsNonzeroBitDepthParams>();
  TestExhaustively<DefaultL7R5BitDepthParams>();  // legacy, same as L8R8
  TestExhaustivelyEightBitIntGemm<eight_bit_int_gemm::BitDepthSetting::A8B8>();
  TestExhaustivelyEightBitIntGemm<eight_bit_int_gemm::BitDepthSetting::A5B7>();
  TestKernels();
#endif

  // Run against actual data from a network evaluation.
  TestWithRealData(eight_bit_int_gemm::BitDepthSetting::A8B8, 0, 0);
  TestWithRealData(eight_bit_int_gemm::BitDepthSetting::A5B7, 2, 10);

  // Test non-default output pipelines with various combinations of
  // output stages.
  TestOutputStages<DefaultL8R8BitDepthParams>();
  TestOutputStages<L8R8WithLhsNonzeroBitDepthParams>();

  // Test per-channel quantization.
  TestWithSmallDataPerChannelQuantization();
  TestWithLargeDataPerChannelQuantization();
  TestMultithreadedPerChannelQuantization();
#ifdef GEMMLOWP_TEST_PROFILE
  FinishProfiling();
#endif

  std::cerr << "All tests passed." << std::endl;

  // We have been testing the eight_bit_int_gemm interface, so we should free
  // its persistent resources now to avoid having leak-checking tools report
  // leaks.
  eight_bit_int_gemm::FreePersistentResources();
}

}  // end namespace gemmlowp

// On iOS, the test harness defines its own main(), so skip defining one here.
#if !(defined(__APPLE__) && (TARGET_OS_IPHONE || TARGET_IPHONE_SIMULATOR))
int main() { gemmlowp::test(); }
#endif