/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <algorithm>
#include <cfloat>
#include <chrono>
#include <cmath>
#include <functional>
#include <iostream>
#include <random>
#include <vector>

#include <pack_block_sparse.h>
#include <qnnpack/AlignedAllocator.h>
#include <qnnpack/pack.h>
#include <qnnpack/params.h>
#include <qnnpack/q8gemm.h>
#include <qnnpack/q8gemm_sparse.h>
#include <qnnpack/requantization.h>

#include <benchmark/benchmark.h>

namespace {
inline uint32_t divideRoundUp(uint32_t x, uint32_t q) {
  return x / q + uint32_t(x % q != 0);
}

inline uint32_t roundUp(uint32_t x, uint32_t q) {
  return q * divideRoundUp(x, q);
}

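// With probability `sparsity`, overwrites each row_block_size x col_block_size
// block of the dense weight matrix with the per-output-channel zero point, so
// that block dequantizes to zero and can be dropped when the weights are
// converted to the block-CSR format below.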
void fillBlockSparseWeights(
    uint8_t* b,
    size_t N,
    size_t K,
    size_t row_block_size,
    size_t col_block_size,
    float sparsity,
    const uint8_t* zero_points) {
  std::random_device randomDevice;
  auto rng = std::mt19937(randomDevice());
  std::bernoulli_distribution dist{sparsity};
  for (uint32_t n = 0; n < N; n += row_block_size) {
    for (uint32_t k = 0; k < K; k += col_block_size) {
      if (dist(rng)) {
        for (uint32_t nb = 0; (nb < row_block_size) && (n + nb < N); ++nb) {
          for (uint32_t kb = 0; (kb < col_block_size) && (k + kb < K); ++kb) {
            *(b + (n + nb) * K + k + kb) = zero_points[n + nb];
          }
        }
      }
    }
  }
}

} // namespace

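// Dense (non-sparse) quantized GEMM fixture: weights are packed with
// pytorch_pack_q8gemm_w and the regular q8gemm ukernels are benchmarked as a
// baseline for the sparse kernels below.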
class Q8GEMM : public benchmark::Fixture {
 public:
  inline Q8GEMM(uint32_t mr, uint32_t nr, uint32_t np, uint32_t kr)
      : mr_(mr), nr_(nr), np_(np), kr_(kr), mc_(mr), nc_(nr), kc_(kr) {}

  void SetUp(const benchmark::State&) override {
    std::random_device randomDevice;
    auto rng = std::mt19937(randomDevice());
    auto s32rng =
        std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), rng);
    auto u8rng = std::bind(std::uniform_int_distribution<uint8_t>(), rng);

    a_.resize(mc() * kc());
    std::generate(a_.begin(), a_.end(), std::ref(u8rng));
    k_.resize(nc() * kc());
    std::generate(k_.begin(), k_.end(), std::ref(u8rng));
    b_.resize(nc());
    std::generate(b_.begin(), b_.end(), std::ref(s32rng));
    w_.resize(
        kcStride() * ncStride() +
        ncStride() * sizeof(int32_t) / sizeof(uint8_t));
    std::fill(w_.begin(), w_.end(), 127);
    // nc rounded up to a multiple of nr (nr is a power of two): one zero point
    // and one requantization scale per packed output channel.
    size_t num_zero_points_kernel = (nc_ + (nr_ - 1)) & -nr_;
    // Keep the zero points and scales in member vectors: quantizationParams_
    // stores raw pointers into them, so they must outlive SetUp().
    kernel_zero_points_.assign(num_zero_points_kernel, 127);
    requantization_scales_.assign(num_zero_points_kernel, 0.75f);
    pytorch_pack_q8gemm_w(
        nc(),
        kc(),
        nr(),
        np(),
        kr(),
#if !PYTORCH_QNNPACK_RUNTIME_QUANTIZATION
        127,
        127,
#endif
        k(),
        b(),
#if PYTORCH_QNNPACK_RUNTIME_QUANTIZATION
        kernel_zero_points_.data(),
#endif
        w());
    c_.resize(mc() * nc());
    std::fill(c_.begin(), c_.end(), 0xA5);

    quantizationParams_ = pytorch_qnnp_compute_conv_quantization_params(
        127,
        kernel_zero_points_.data(),
        requantization_scales_.data(),
        127,
        1,
        254);
  }

  void TearDown(benchmark::State& state) override {
    state.SetItemsProcessed(
        uint64_t(state.iterations()) * 2 * mc() * nc() * kc());
    a_.clear();
    k_.clear();
    b_.clear();
    w_.clear();
    c_.clear();
  }

  inline const uint8_t* a() const {
    return a_.data();
  }

  inline const uint8_t* k() const {
    return k_.data();
  }

  inline const int32_t* b() const {
    return b_.data();
  }

  inline uint8_t* w() {
    return w_.data();
  }

  inline const uint8_t* w() const {
    return w_.data();
  }

  inline uint8_t* c() {
    return c_.data();
  }

  inline uint32_t mr() const {
    return mr_;
  }

  inline uint32_t mc() const {
    return mc_;
  }

  inline uint32_t nr() const {
    return nr_;
  }

  inline uint32_t np() const {
    return np_;
  }

  inline uint32_t nc() const {
    return nc_;
  }

  inline uint32_t ncStride() const {
    return roundUp(nc(), nr());
  }

  inline uint32_t kr() const {
    return kr_;
  }

  inline uint32_t kc() const {
    return kc_;
  }

  inline uint32_t kcStride() const {
    return roundUp(kc(), kr());
  }

  inline const pytorch_qnnp_conv_quantization_params* quantizationParams()
      const {
    return &quantizationParams_;
  }

 protected:
  std::vector<uint8_t> a_;
  std::vector<uint8_t> k_;
  std::vector<int32_t> b_;
  std::vector<uint8_t, AlignedAllocator<uint8_t, 32>> w_;
  std::vector<uint8_t> c_;
  // Backing storage for the pointers held by quantizationParams_.
  std::vector<uint8_t> kernel_zero_points_;
  std::vector<float> requantization_scales_;
  uint32_t mr_{0};
  uint32_t nr_{0};
  uint32_t np_{0};
  uint32_t kr_{0};
  uint32_t mc_{mr_};
  uint32_t nc_{nr_};
  uint32_t kc_{kr_};
  pytorch_qnnp_conv_quantization_params quantizationParams_;
};

template <uint32_t MR, uint32_t NR, uint32_t NP, uint32_t KR>
class Q8GEMM_Op : public Q8GEMM {
 public:
  inline Q8GEMM_Op() : Q8GEMM(MR, NR, NP, KR) {}

  void SetUp(const benchmark::State& state) override {
    mc_ = state.range(0);
    nc_ = state.range(1);
    kc_ = state.range(2);

    Q8GEMM::SetUp(state);
  }
};

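// Block-sparse quantized GEMM fixture: weights are made block-sparse with
// fillBlockSparseWeights, converted to a block-CSR matrix, and multiplied by
// the sparse dq (dequantizing) ukernels, which write float output.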
class Q8GEMMSparse : public benchmark::Fixture {
 public:
  inline Q8GEMMSparse(
      uint32_t mr, uint32_t nr, uint32_t kr, uint32_t rbs, uint32_t cbs)
      : mr_(mr),
        nr_(nr),
        kr_(kr),
        mc_(mr),
        nc_(nr),
        kc_(kr),
        row_block_size_(rbs),
        col_block_size_(cbs) {}

  void SetUp(const benchmark::State&) override {
    std::random_device randomDevice;
    auto rng = std::mt19937(randomDevice());
    auto s32rng =
        std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), rng);
    auto u8rng = std::bind(std::uniform_int_distribution<uint8_t>(), rng);
    auto f32rng =
        std::bind(std::uniform_real_distribution<float>(1, 5), rng);

    a_.resize(mc() * kc());
    std::generate(a_.begin(), a_.end(), std::ref(u8rng));
    k_.resize(nc() * kc());
    b_.resize(nc());
    std::generate(b_.begin(), b_.end(), std::ref(f32rng));
    size_t num_zero_points_kernel = (nc_ + (nr_ - 1)) & -nr_;
    // Keep the zero points and scales in member vectors: quantizationParams_
    // stores raw pointers into them, so they must outlive SetUp().
    kernel_zero_points_.assign(num_zero_points_kernel, 127);
    dequantization_scales_.assign(num_zero_points_kernel, 0.75f);

    std::generate(k_.begin(), k_.end(), std::ref(u8rng));
    fillBlockSparseWeights(
        k_.data(),
        nc(),
        kc(),
        rowBlockSize(),
        colBlockSize(),
        sparsity(),
        kernel_zero_points_.data());
    bcsr_matrix_ = qnnpack::generateBlockCSRMatrix<uint32_t>(
        k_.data(),
        nc(),
        kc(),
        rowBlockSize(),
        colBlockSize(),
        kernel_zero_points_.data());
    c_.resize(mc() * nc());
    std::fill(c_.begin(), c_.end(), 0xA5);

    quantizationParams_ = pytorch_qnnp_conv_dynamic_quantization_params{
        127,
        kernel_zero_points_.data(),
        dequantization_scales_.data(),
    };
  }

  void TearDown(benchmark::State& state) override {
    state.SetItemsProcessed(
        uint64_t(state.iterations()) * 2 * mc() * nc() * kc());
    a_.clear();
    k_.clear();
    b_.clear();
    c_.clear();
  }

  inline const uint8_t* a() const {
    return a_.data();
  }

  inline const uint8_t* k() const {
    return k_.data();
  }

  inline const float* b() const {
    return b_.data();
  }

  inline float* c() {
    return c_.data();
  }

  inline uint32_t mr() const {
    return mr_;
  }

  inline uint32_t mc() const {
    return mc_;
  }

  inline uint32_t nr() const {
    return nr_;
  }

  inline uint32_t nc() const {
    return nc_;
  }

  inline uint32_t ncStride() const {
    return roundUp(nc(), nr());
  }

  inline uint32_t kr() const {
    return kr_;
  }

  inline uint32_t kc() const {
    return kc_;
  }

  inline uint32_t kcStride() const {
    return roundUp(kc(), kr());
  }

  inline size_t rowBlockSize() const {
    return this->row_block_size_;
  }

  inline size_t colBlockSize() const {
    return this->col_block_size_;
  }

  inline float sparsity() const {
    return this->sparsity_;
  }

  inline const pytorch_qnnp_conv_dynamic_quantization_params*
  quantizationParams() const {
    return &quantizationParams_;
  }

 protected:
  std::vector<uint8_t> a_;
  std::vector<uint8_t> k_;
  std::vector<float> b_;
  std::unique_ptr<qnnpack::BCSRMatrix> bcsr_matrix_;
  std::vector<float> c_;
  // Backing storage for the pointers held by quantizationParams_.
  std::vector<uint8_t> kernel_zero_points_;
  std::vector<float> dequantization_scales_;
  uint32_t mr_{0};
  uint32_t nr_{0};
  uint32_t kr_{0};
  uint32_t mc_{mr_};
  uint32_t nc_{nr_};
  uint32_t kc_{kr_};
  uint32_t row_block_size_{1};
  uint32_t col_block_size_{4};
  float sparsity_{0.7f};
  pytorch_qnnp_conv_dynamic_quantization_params quantizationParams_;
};

template <uint32_t MR, uint32_t NR, uint32_t KR, uint32_t RBS, uint32_t CBS>
class Q8GEMMSparse_Op : public Q8GEMMSparse {
 public:
  inline Q8GEMMSparse_Op() : Q8GEMMSparse(MR, NR, KR, RBS, CBS) {}

  void SetUp(const benchmark::State& state) override {
    mc_ = state.range(0);
    nc_ = state.range(1);
    kc_ = state.range(2);

    Q8GEMMSparse::SetUp(state);
  }
};

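// GEMM problem sizes (M, N, K) shared by all of the dense and sparse
// benchmarks registered below.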
static void SparseGEMMBenchGemmArguments(benchmark::internal::Benchmark* b) {
  b->ArgNames({"M", "N", "K"});

  b->Args({5, 4096, 640});
  b->Args({20, 4096, 640});
  b->Args({4, 4096, 1024});
  b->Args({3, 4096, 1024});
  b->Args({5, 1024, 640});
  b->Args({5, 4096, 1280});
  b->Args({20, 4096, 880});
  b->Args({10, 4096, 640});
  b->Args({10, 4096, 1280});
  b->Args({5, 4096, 1024});
  b->Args({6, 4096, 1024});
  b->Args({7, 4096, 1024});
  b->Args({8, 4096, 1024});
  b->Args({9, 4096, 1024});
  b->Args({7, 4096, 640});
  b->Args({4, 4096, 640});
  b->Args({28, 4096, 640});
  b->Args({16, 4096, 640});
  b->Args({10, 4096, 1024});
  b->Args({8, 4096, 640});
  b->Args({8, 4096, 1280});
  b->Args({7, 1024, 640});
  b->Args({7, 4096, 1280});
  b->Args({4, 1024, 640});
  b->Args({4, 4096, 1280});
  b->Args({28, 4096, 880});
  b->Args({16, 4096, 880});
  b->Args({14, 4096, 640});
  b->Args({14, 4096, 1280});
}

#if CPUINFO_ARCH_ARM
BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMM_Op, 4x8__aarch32_neon, 4, 8, 8, 1)
(benchmark::State& state) {
  for (auto _ : state) {
    for (uint32_t m = 0; m < mc(); m += mr()) {
      const uint32_t mrr = min(mc() - m, mr());
      for (uint32_t n = 0, channel_offset = 0; n < nc();
           n += nr(), channel_offset += nr()) {
        const uint32_t nrr = min(nc() - n, nr());
        pytorch_q8gemm_ukernel_4x8__aarch32_neon(
            mrr,
            nrr,
            kc(),
            a() + m * kc(),
            kc() * sizeof(uint8_t),
            w() + n * (kcStride() * sizeof(uint8_t) + sizeof(int32_t)),
            c() + m * nc() + n,
            nc() * sizeof(uint8_t),
            channel_offset,
            quantizationParams());
      }
    }
  }
}

BENCHMARK_REGISTER_F(Q8GEMM_Op, 4x8__aarch32_neon)
    ->Apply(SparseGEMMBenchGemmArguments);

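// The prepacked sparse benchmarks time two phases per iteration: first the
// activation matrix A is packed into mr-row tiles with the packA ukernel,
// then the sparse dq ukernel multiplies the packed A by the block-CSR weight
// matrix and writes float results.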
BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMMSparse_Op, 4x8c1x4_prepacked__aarch32_neon, 4, 8, 4, 1, 4)
(benchmark::State& state) {
  for (auto _ : state) {
    auto m_blocks = (mc() + mr() - 1) / mr();
    auto k_blocks = (kc() + 4 - 1) / 4;
    std::vector<uint8_t> a_packed(m_blocks * k_blocks * mr() * 4 + 8);
    for (uint32_t m = 0; m < mc(); m += mr()) {
      const uint32_t mrr = min(mc() - m, mr());
      for (uint32_t n = 0, channel_offset = 0; n < nc();
           n += nr(), channel_offset += nr()) {
        const uint32_t nrr = min(nc() - n, nr());
        pytorch_q8gemm_sparse_packA_ukernel_4x4__aarch32_neon(
            mrr,
            kc(),
            a() + m * kc(),
            kc() * sizeof(uint8_t),
            a_packed.data() + (m >> 2) * (k_blocks << 2) * mr());
      }
    }
    for (uint32_t m = 0; m < mc(); m += mr()) {
      const uint32_t mrr = min(mc() - m, mr());
      for (uint32_t n = 0, channel_offset = 0; n < nc();
           n += nr(), channel_offset += nr()) {
        const uint32_t nrr = min(nc() - n, nr());
        pytorch_q8gemm_dq_sparse_1x4_ukernel_4x8_packedA_w32__aarch32_neon(
            mrr,
            nrr,
            a_packed.data() + (m >> 2) * (k_blocks << 2) * mr(),
            bcsr_matrix_->values.data(),
            static_cast<const uint32_t*>(bcsr_matrix_->row_values_data_ptr()) +
                n,
            static_cast<const uint32_t*>(bcsr_matrix_->col_indices_data_ptr()),
            b() + n,
            c() + m * nc() + n,
            nc(),
            channel_offset,
            quantizationParams());
      }
    }
  }
}
BENCHMARK_REGISTER_F(Q8GEMMSparse_Op, 4x8c1x4_prepacked__aarch32_neon)
    ->Apply(SparseGEMMBenchGemmArguments);

BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMMSparse_Op, 4x8c8x1_prepacked__aarch32_neon, 4, 8, 1, 8, 1)
(benchmark::State& state) {
  for (auto _ : state) {
    auto m_blocks = (mc() + mr() - 1) / mr();
    // Still use kr of 4 because we use the 4x4 packing kernel
    auto k_blocks = (kc() + 4 - 1) / 4;
    std::vector<uint8_t> a_packed(m_blocks * k_blocks * mr() * 4 + 8);
    for (uint32_t m = 0; m < mc(); m += mr()) {
      const uint32_t mrr = min(mc() - m, mr());
      for (uint32_t n = 0, channel_offset = 0; n < nc();
           n += nr(), channel_offset += nr()) {
        const uint32_t nrr = min(nc() - n, nr());
        pytorch_q8gemm_sparse_packA_ukernel_4x4__aarch32_neon(
            mrr,
            kc(),
            a() + m * kc(),
            kc() * sizeof(uint8_t),
            a_packed.data() + (m >> 2) * (k_blocks << 2) * mr());
      }
    }
    for (uint32_t m = 0; m < mc(); m += mr()) {
      const uint32_t mrr = min(mc() - m, mr());
      for (uint32_t n = 0, channel_offset = 0; n < nc();
           n += nr(), channel_offset += nr()) {
        const uint32_t nrr = min(nc() - n, nr());
        pytorch_q8gemm_dq_sparse_8x1_ukernel_4x8_packedA_w32__aarch32_neon(
            mrr,
            nrr,
            a_packed.data() + (m >> 2) * (k_blocks << 2) * mr(),
            bcsr_matrix_->values.data(),
            static_cast<const uint32_t*>(bcsr_matrix_->row_values_data_ptr()) +
                (n >> 3),
            static_cast<const uint32_t*>(bcsr_matrix_->col_indices_data_ptr()),
            b() + n,
            c() + m * nc() + n,
            nc(),
            channel_offset,
            quantizationParams());
      }
    }
  }
}
BENCHMARK_REGISTER_F(Q8GEMMSparse_Op, 4x8c8x1_prepacked__aarch32_neon)
    ->Apply(SparseGEMMBenchGemmArguments);
#endif

#if CPUINFO_ARCH_ARM64
BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMM_Op, 8x8__aarch64_neon, 8, 8, 8, 1)
(benchmark::State& state) {
  for (auto _ : state) {
    for (uint32_t m = 0; m < mc(); m += mr()) {
      const uint32_t mrr = min(mc() - m, mr());
      for (uint32_t n = 0, channel_offset = 0; n < nc();
           n += nr(), channel_offset += nr()) {
        const uint32_t nrr = min(nc() - n, nr());
        pytorch_q8gemm_ukernel_8x8__aarch64_neon(
            mrr,
            nrr,
            kc(),
            a() + m * kc(),
            kc() * sizeof(uint8_t),
            w() + n * (kcStride() * sizeof(uint8_t) + sizeof(int32_t)),
            c() + m * nc() + n,
            nc() * sizeof(uint8_t),
            channel_offset,
            quantizationParams());
      }
    }
  }
}

BENCHMARK_REGISTER_F(Q8GEMM_Op, 8x8__aarch64_neon)
    ->Apply(SparseGEMMBenchGemmArguments);

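// The aarch64 sparse variants mirror the aarch32 ones above, but pack A with
// the 8x4 packA ukernel and use 8x8 sparse dq GEMM ukernels.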
BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMMSparse_Op, 8x8c1x4_prepacked__aarch64_neon, 8, 8, 4, 1, 4)
(benchmark::State& state) {
  for (auto _ : state) {
    auto m_blocks = (mc() + mr() - 1) / mr();
    auto k_blocks = (kc() + 4 - 1) / 4;
    std::vector<uint8_t> a_packed(m_blocks * k_blocks * mr() * 4 + 8);
    for (uint32_t m = 0; m < mc(); m += mr()) {
      const uint32_t mrr = min(mc() - m, mr());
      for (uint32_t n = 0, channel_offset = 0; n < nc();
           n += nr(), channel_offset += nr()) {
        const uint32_t nrr = min(nc() - n, nr());
        pytorch_q8gemm_sparse_packA_ukernel_8x4__aarch64_neon(
            mrr,
            kc(),
            a() + m * kc(),
            kc() * sizeof(uint8_t),
            a_packed.data() + (m >> 3) * (k_blocks << 2) * mr());
      }
    }
    for (uint32_t m = 0; m < mc(); m += mr()) {
      const uint32_t mrr = min(mc() - m, mr());
      for (uint32_t n = 0, channel_offset = 0; n < nc();
           n += nr(), channel_offset += nr()) {
        const uint32_t nrr = min(nc() - n, nr());
        pytorch_q8gemm_dq_sparse_1x4_ukernel_8x8_packedA_w32__aarch64_neon(
            mrr,
            nrr,
            a_packed.data() + (m >> 3) * (k_blocks << 2) * mr(),
            bcsr_matrix_->values.data(),
            static_cast<const uint32_t*>(bcsr_matrix_->row_values_data_ptr()) +
                n,
            static_cast<const uint32_t*>(bcsr_matrix_->col_indices_data_ptr()),
            b() + n,
            c() + m * nc() + n,
            nc(),
            channel_offset,
            quantizationParams());
      }
    }
  }
}
BENCHMARK_REGISTER_F(Q8GEMMSparse_Op, 8x8c1x4_prepacked__aarch64_neon)
    ->Apply(SparseGEMMBenchGemmArguments);

BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMMSparse_Op, 8x8c8x1_prepacked__aarch64_neon, 8, 8, 4, 8, 1)
(benchmark::State& state) {
  for (auto _ : state) {
    auto m_blocks = (mc() + mr() - 1) / mr();
    // Still use kr of 4 because the 8x4 packing kernel packs K in groups of 4
    auto k_blocks = (kc() + 4 - 1) / 4;
    std::vector<uint8_t> a_packed(m_blocks * k_blocks * mr() * 4 + 8);
    for (uint32_t m = 0; m < mc(); m += mr()) {
      const uint32_t mrr = min(mc() - m, mr());
      for (uint32_t n = 0, channel_offset = 0; n < nc();
           n += nr(), channel_offset += nr()) {
        const uint32_t nrr = min(nc() - n, nr());
        pytorch_q8gemm_sparse_packA_ukernel_8x4__aarch64_neon(
            mrr,
            kc(),
            a() + m * kc(),
            kc() * sizeof(uint8_t),
            a_packed.data() + (m >> 3) * (k_blocks << 2) * mr());
      }
    }
    for (uint32_t m = 0; m < mc(); m += mr()) {
      const uint32_t mrr = min(mc() - m, mr());
      for (uint32_t n = 0, channel_offset = 0; n < nc();
           n += nr(), channel_offset += nr()) {
        const uint32_t nrr = min(nc() - n, nr());
        pytorch_q8gemm_dq_sparse_8x1_ukernel_8x8_packedA_w32__aarch64_neon(
            mrr,
            nrr,
            a_packed.data() + (m >> 3) * (k_blocks << 2) * mr(),
            bcsr_matrix_->values.data(),
            static_cast<const uint32_t*>(bcsr_matrix_->row_values_data_ptr()) +
                (n >> 3),
            static_cast<const uint32_t*>(bcsr_matrix_->col_indices_data_ptr()),
            b() + n,
            c() + m * nc() + n,
            nc(),
            channel_offset,
            quantizationParams());
      }
    }
  }
}
BENCHMARK_REGISTER_F(Q8GEMMSparse_Op, 8x8c8x1_prepacked__aarch64_neon)
    ->Apply(SparseGEMMBenchGemmArguments);

#endif

#ifndef PYTORCH_QNNPACK_BENCHMARK_NO_MAIN
BENCHMARK_MAIN();
#endif