/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <algorithm>
#include <cfloat>
#include <chrono>
#include <cmath>
#include <functional>
#include <iostream>
#include <random>
#include <vector>

#include <pack_block_sparse.h>
#include <qnnpack/AlignedAllocator.h>
#include <qnnpack/pack.h>
#include <qnnpack/params.h>
#include <qnnpack/q8gemm.h>
#include <qnnpack/q8gemm_sparse.h>
#include <qnnpack/requantization.h>

#include <benchmark/benchmark.h>

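// Microbenchmarks for QNNPACK's quantized (Q8) GEMM microkernels, comparing
// the dense requantizing kernels against the block-sparse, dynamically
// quantized (dequantizing) kernels on ARM and ARM64.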
namespace {

// Integer division of x by q, rounding up.
inline uint32_t divideRoundUp(uint32_t x, uint32_t q) {
  return x / q + uint32_t(x % q != 0);
}

// Round x up to the next multiple of q.
inline uint32_t roundUp(uint32_t x, uint32_t q) {
  return q * divideRoundUp(x, q);
}

// Overwrite roughly a `sparsity` fraction of the row_block_size x
// col_block_size blocks of the N x K weight matrix with the per-output-channel
// zero point, so that those blocks become zero blocks in the block-sparse
// representation.
void fillBlockSparseWeights(
    uint8_t* b,
    size_t N,
    size_t K,
    size_t row_block_size,
    size_t col_block_size,
    float sparsity,
    const uint8_t* zero_points) {
  std::random_device randomDevice;
  auto rng = std::mt19937(randomDevice());
  std::bernoulli_distribution dist{sparsity};
  for (uint32_t n = 0; n < N; n += row_block_size) {
    for (uint32_t k = 0; k < K; k += col_block_size) {
      if (dist(rng)) {
        for (uint32_t nb = 0; (nb < row_block_size) && (n + nb < N); ++nb) {
          for (uint32_t kb = 0; (kb < col_block_size) && (k + kb < K); ++kb) {
            *(b + (n + nb) * K + k + kb) = zero_points[n + nb];
          }
        }
      }
    }
  }
}

} // namespace

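// Fixture for the dense 8-bit GEMM microkernels. SetUp() generates random
// activation, weight, and bias data, packs the weights with
// pytorch_pack_q8gemm_w, and precomputes the convolution quantization
// parameters used for requantization.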
class Q8GEMM : public benchmark::Fixture {
 public:
  inline Q8GEMM(uint32_t mr, uint32_t nr, uint32_t np, uint32_t kr)
      : mr_(mr), nr_(nr), np_(np), kr_(kr), mc_(mr), nc_(nr), kc_(kr) {}

  void SetUp(const benchmark::State&) override {
    std::random_device randomDevice;
    auto rng = std::mt19937(randomDevice());
    auto s32rng =
        std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), rng);
    auto u8rng = std::bind(std::uniform_int_distribution<uint8_t>(), rng);

    a_.resize(mc() * kc());
    std::generate(a_.begin(), a_.end(), std::ref(u8rng));
    k_.resize(nc() * kc());
    std::generate(k_.begin(), k_.end(), std::ref(u8rng));
    b_.resize(nc());
    std::generate(b_.begin(), b_.end(), std::ref(s32rng));
    w_.resize(
        kcStride() * ncStride() +
        ncStride() * sizeof(int32_t) / sizeof(uint8_t));
    std::fill(w_.begin(), w_.end(), 127);
    // Round the number of output channels up to a multiple of nr_
    // (assumes nr_ is a power of two).
    const size_t num_zero_points_kernel = (nc_ + (nr_ - 1)) & -nr_;
    // Kept as members: quantizationParams_ stores raw pointers into these
    // buffers, so they must remain valid while the benchmark runs.
    kernel_zero_points_.assign(num_zero_points_kernel, 127);
    requantization_scales_.assign(num_zero_points_kernel, 0.75f);
    pytorch_pack_q8gemm_w(
        nc(),
        kc(),
        nr(),
        np(),
        kr(),
#if !PYTORCH_QNNPACK_RUNTIME_QUANTIZATION
        127,
        127,
#endif
        k(),
        b(),
#if PYTORCH_QNNPACK_RUNTIME_QUANTIZATION
        kernel_zero_points_.data(),
#endif
        w());
    c_.resize(mc() * nc());
    std::fill(c_.begin(), c_.end(), 0xA5);

    quantizationParams_ = pytorch_qnnp_compute_conv_quantization_params(
        127, kernel_zero_points_.data(),
        requantization_scales_.data(), 127, 1, 254);
  }

  void TearDown(benchmark::State& state) override {
    // 2 * M * N * K multiply-accumulate operations per iteration.
    state.SetItemsProcessed(
        uint64_t(state.iterations()) * 2 * mc() * nc() * kc());
    a_.clear();
    k_.clear();
    b_.clear();
    w_.clear();
    c_.clear();
  }

  inline const uint8_t* a() const {
    return a_.data();
  }

  inline const uint8_t* k() const {
    return k_.data();
  }

  inline const int32_t* b() const {
    return b_.data();
  }

  inline uint8_t* w() {
    return w_.data();
  }

  inline const uint8_t* w() const {
    return w_.data();
  }

  inline uint8_t* c() {
    return c_.data();
  }

  inline uint32_t mr() const {
    return mr_;
  }

  inline uint32_t mc() const {
    return mc_;
  }

  inline uint32_t nr() const {
    return nr_;
  }

  inline uint32_t np() const {
    return np_;
  }

  inline uint32_t nc() const {
    return nc_;
  }

  inline uint32_t ncStride() const {
    return roundUp(nc(), nr());
  }

  inline uint32_t kr() const {
    return kr_;
  }

  inline uint32_t kc() const {
    return kc_;
  }

  inline uint32_t kcStride() const {
    return roundUp(kc(), kr());
  }

  inline const pytorch_qnnp_conv_quantization_params* quantizationParams()
      const {
    return &quantizationParams_;
  }

 protected:
  std::vector<uint8_t> a_;
  std::vector<uint8_t> k_;
  std::vector<int32_t> b_;
  std::vector<uint8_t, AlignedAllocator<uint8_t, 32>> w_;
  std::vector<uint8_t> c_;
  std::vector<uint8_t> kernel_zero_points_;
  std::vector<float> requantization_scales_;
  uint32_t mr_{0};
  uint32_t nr_{0};
  uint32_t np_{0};
  uint32_t kr_{0};
  uint32_t mc_{mr_};
  uint32_t nc_{nr_};
  uint32_t kc_{kr_};
  pytorch_qnnp_conv_quantization_params quantizationParams_;
};

template <uint32_t MR, uint32_t NR, uint32_t NP, uint32_t KR>
class Q8GEMM_Op : public Q8GEMM {
 public:
  inline Q8GEMM_Op() : Q8GEMM(MR, NR, NP, KR) {}

  void SetUp(const benchmark::State& state) override {
    // M, N, K come from the benchmark arguments.
    mc_ = state.range(0);
    nc_ = state.range(1);
    kc_ = state.range(2);

    Q8GEMM::SetUp(state);
  }
};

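// Fixture for the block-sparse, dynamically quantized GEMM microkernels.
// SetUp() zeroes out roughly `sparsity` of the weight blocks, converts the
// weights to a block-CSR representation, and prepares the dynamic
// quantization parameters; the kernels accumulate in int32 and write
// dequantized float output.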
class Q8GEMMSparse : public benchmark::Fixture {
 public:
  inline Q8GEMMSparse(
      uint32_t mr, uint32_t nr, uint32_t kr, uint32_t rbs, uint32_t cbs)
      : mr_(mr),
        nr_(nr),
        kr_(kr),
        mc_(mr),
        nc_(nr),
        kc_(kr),
        row_block_size_(rbs),
        col_block_size_(cbs) {}

  void SetUp(const benchmark::State&) override {
    std::random_device randomDevice;
    auto rng = std::mt19937(randomDevice());
    auto s32rng =
        std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), rng);
    auto u8rng = std::bind(std::uniform_int_distribution<uint8_t>(), rng);
    auto f32rng =
        std::bind(std::uniform_real_distribution<float>(1, 5), rng);

    a_.resize(mc() * kc());
    std::generate(a_.begin(), a_.end(), std::ref(u8rng));
    k_.resize(nc() * kc());
    b_.resize(nc());
    std::generate(b_.begin(), b_.end(), std::ref(f32rng));
    // Round the number of output channels up to a multiple of nr_
    // (assumes nr_ is a power of two).
    const size_t num_zero_points_kernel = (nc_ + (nr_ - 1)) & -nr_;
    // Kept as members: quantizationParams_ stores raw pointers into these
    // buffers, so they must remain valid while the benchmark runs.
    kernel_zero_points_.assign(num_zero_points_kernel, 127);

    std::generate(k_.begin(), k_.end(), std::ref(u8rng));
    fillBlockSparseWeights(
        k_.data(),
        nc(),
        kc(),
        rowBlockSize(),
        colBlockSize(),
        sparsity(),
        kernel_zero_points_.data());
    bcsr_matrix_ = qnnpack::generateBlockCSRMatrix<uint32_t>(
        k_.data(),
        nc(),
        kc(),
        rowBlockSize(),
        colBlockSize(),
        kernel_zero_points_.data());
    dequantization_scales_.assign(num_zero_points_kernel, 0.75f);
    c_.resize(mc() * nc());
    std::fill(c_.begin(), c_.end(), 0xA5);

    quantizationParams_ = pytorch_qnnp_conv_dynamic_quantization_params{
        127,
        kernel_zero_points_.data(),
        dequantization_scales_.data(),
    };
  }

  void TearDown(benchmark::State& state) override {
    // 2 * M * N * K multiply-accumulate operations per iteration.
    state.SetItemsProcessed(
        uint64_t(state.iterations()) * 2 * mc() * nc() * kc());
    a_.clear();
    k_.clear();
    b_.clear();
    c_.clear();
  }

  inline const uint8_t* a() const {
    return a_.data();
  }

  inline const uint8_t* k() const {
    return k_.data();
  }

  inline const float* b() const {
    return b_.data();
  }

  inline float* c() {
    return c_.data();
  }

  inline uint32_t mr() const {
    return mr_;
  }

  inline uint32_t mc() const {
    return mc_;
  }

  inline uint32_t nr() const {
    return nr_;
  }

  inline uint32_t nc() const {
    return nc_;
  }

  inline uint32_t ncStride() const {
    return roundUp(nc(), nr());
  }

  inline uint32_t kr() const {
    return kr_;
  }

  inline uint32_t kc() const {
    return kc_;
  }

  inline uint32_t kcStride() const {
    return roundUp(kc(), kr());
  }

  inline size_t rowBlockSize() const {
    return this->row_block_size_;
  }

  inline size_t colBlockSize() const {
    return this->col_block_size_;
  }

  inline float sparsity() const {
    return this->sparsity_;
  }

  inline const pytorch_qnnp_conv_dynamic_quantization_params*
  quantizationParams() const {
    return &quantizationParams_;
  }

 protected:
  std::vector<uint8_t> a_;
  std::vector<uint8_t> k_;
  std::vector<float> b_;
  std::unique_ptr<qnnpack::BCSRMatrix> bcsr_matrix_;
  std::vector<float> c_;
  std::vector<uint8_t> kernel_zero_points_;
  std::vector<float> dequantization_scales_;
  uint32_t mr_{0};
  uint32_t nr_{0};
  uint32_t kr_{0};
  uint32_t mc_{mr_};
  uint32_t nc_{nr_};
  uint32_t kc_{kr_};
  uint32_t row_block_size_{1};
  uint32_t col_block_size_{4};
  float sparsity_{0.7f};
  pytorch_qnnp_conv_dynamic_quantization_params quantizationParams_;
};

template <uint32_t MR, uint32_t NR, uint32_t KR, uint32_t RBS, uint32_t CBS>
class Q8GEMMSparse_Op : public Q8GEMMSparse {
 public:
  inline Q8GEMMSparse_Op() : Q8GEMMSparse(MR, NR, KR, RBS, CBS) {}

  void SetUp(const benchmark::State& state) override {
    // M, N, K come from the benchmark arguments.
    mc_ = state.range(0);
    nc_ = state.range(1);
    kc_ = state.range(2);

    Q8GEMMSparse::SetUp(state);
  }
};

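// GEMM shapes (M, N, K) shared by all of the benchmarks below.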
static void SparseGEMMBenchGemmArguments(benchmark::internal::Benchmark* b) {
  b->ArgNames({"M", "N", "K"});

  b->Args({5, 4096, 640});
  b->Args({20, 4096, 640});
  b->Args({4, 4096, 1024});
  b->Args({3, 4096, 1024});
  b->Args({5, 1024, 640});
  b->Args({5, 4096, 1280});
  b->Args({20, 4096, 880});
  b->Args({10, 4096, 640});
  b->Args({10, 4096, 1280});
  b->Args({5, 4096, 1024});
  b->Args({6, 4096, 1024});
  b->Args({7, 4096, 1024});
  b->Args({8, 4096, 1024});
  b->Args({9, 4096, 1024});
  b->Args({7, 4096, 640});
  b->Args({4, 4096, 640});
  b->Args({28, 4096, 640});
  b->Args({16, 4096, 640});
  b->Args({10, 4096, 1024});
  b->Args({8, 4096, 640});
  b->Args({8, 4096, 1280});
  b->Args({7, 1024, 640});
  b->Args({7, 4096, 1280});
  b->Args({4, 1024, 640});
  b->Args({4, 4096, 1280});
  b->Args({28, 4096, 880});
  b->Args({16, 4096, 880});
  b->Args({14, 4096, 640});
  b->Args({14, 4096, 1280});
}

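// AArch32 benchmarks. The dense kernel serves as the baseline; the sparse
// variants first pack the activation matrix A with a packA microkernel and
// then run the dequantizing block-sparse GEMM microkernel over the packed A
// and the block-CSR weights.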
#if CPUINFO_ARCH_ARM
BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMM_Op, 4x8__aarch32_neon, 4, 8, 8, 1)
(benchmark::State& state) {
  for (auto _ : state) {
    for (uint32_t m = 0; m < mc(); m += mr()) {
      const uint32_t mrr = std::min(mc() - m, mr());
      for (uint32_t n = 0, channel_offset = 0; n < nc();
          n += nr(), channel_offset += nr()) {
        const uint32_t nrr = std::min(nc() - n, nr());
        pytorch_q8gemm_ukernel_4x8__aarch32_neon(
            mrr,
            nrr,
            kc(),
            a() + m * kc(),
            kc() * sizeof(uint8_t),
            w() + n * (kcStride() * sizeof(uint8_t) + sizeof(int32_t)),
            c() + m * nc() + n,
            nc() * sizeof(uint8_t),
            channel_offset,
            quantizationParams());
      }
    }
  }
}

BENCHMARK_REGISTER_F(Q8GEMM_Op, 4x8__aarch32_neon)
    ->Apply(SparseGEMMBenchGemmArguments);

BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMMSparse_Op, 4x8c1x4_prepacked__aarch32_neon, 4, 8, 4, 1, 4)
(benchmark::State& state) {
  for (auto _ : state) {
    auto m_blocks = (mc() + mr() - 1) / mr();
    auto k_blocks = (kc() + 4 - 1) / 4;
    std::vector<uint8_t> a_packed(m_blocks * k_blocks * mr() * 4 + 8);
    // Pack the activation matrix A into the layout expected by the sparse
    // GEMM microkernel.
    for (uint32_t m = 0; m < mc(); m += mr()) {
      const uint32_t mrr = std::min(mc() - m, mr());
      for (uint32_t n = 0, channel_offset = 0; n < nc();
          n += nr(), channel_offset += nr()) {
        const uint32_t nrr = std::min(nc() - n, nr());
        pytorch_q8gemm_sparse_packA_ukernel_4x4__aarch32_neon(
            mrr,
            kc(),
            a() + m * kc(),
            kc() * sizeof(uint8_t),
            a_packed.data() + (m >> 2) * (k_blocks << 2) * mr());
      }
    }
    for (uint32_t m = 0; m < mc(); m += mr()) {
      const uint32_t mrr = std::min(mc() - m, mr());
      for (uint32_t n = 0, channel_offset = 0; n < nc();
          n += nr(), channel_offset += nr()) {
        const uint32_t nrr = std::min(nc() - n, nr());
        pytorch_q8gemm_dq_sparse_1x4_ukernel_4x8_packedA_w32__aarch32_neon(
            mrr,
            nrr,
            a_packed.data() + (m >> 2) * (k_blocks << 2) * mr(),
            bcsr_matrix_->values.data(),
            static_cast<const uint32_t*>(bcsr_matrix_->row_values_data_ptr()) +
                n,
            static_cast<const uint32_t*>(bcsr_matrix_->col_indices_data_ptr()),
            b() + n,
            c() + m * nc() + n,
            nc(),
            channel_offset,
            quantizationParams());
      }
    }
  }
}
BENCHMARK_REGISTER_F(Q8GEMMSparse_Op, 4x8c1x4_prepacked__aarch32_neon)
    ->Apply(SparseGEMMBenchGemmArguments);

BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMMSparse_Op, 4x8c8x1_prepacked__aarch32_neon, 4, 8, 1, 8, 1)
(benchmark::State& state) {
  for (auto _ : state) {
    auto m_blocks = (mc() + mr() - 1) / mr();
    // Still use a kr of 4 because we use the 4x4 packing kernel.
    auto k_blocks = (kc() + 4 - 1) / 4;
    std::vector<uint8_t> a_packed(m_blocks * k_blocks * mr() * 4 + 8);
    for (uint32_t m = 0; m < mc(); m += mr()) {
      const uint32_t mrr = std::min(mc() - m, mr());
      for (uint32_t n = 0, channel_offset = 0; n < nc();
          n += nr(), channel_offset += nr()) {
        const uint32_t nrr = std::min(nc() - n, nr());
        pytorch_q8gemm_sparse_packA_ukernel_4x4__aarch32_neon(
            mrr,
            kc(),
            a() + m * kc(),
            kc() * sizeof(uint8_t),
            a_packed.data() + (m >> 2) * (k_blocks << 2) * mr());
      }
    }
    for (uint32_t m = 0; m < mc(); m += mr()) {
      const uint32_t mrr = std::min(mc() - m, mr());
      for (uint32_t n = 0, channel_offset = 0; n < nc();
          n += nr(), channel_offset += nr()) {
        const uint32_t nrr = std::min(nc() - n, nr());
        pytorch_q8gemm_dq_sparse_8x1_ukernel_4x8_packedA_w32__aarch32_neon(
            mrr,
            nrr,
            a_packed.data() + (m >> 2) * (k_blocks << 2) * mr(),
            bcsr_matrix_->values.data(),
            static_cast<const uint32_t*>(bcsr_matrix_->row_values_data_ptr()) +
                (n >> 3),
            static_cast<const uint32_t*>(bcsr_matrix_->col_indices_data_ptr()),
            b() + n,
            c() + m * nc() + n,
            nc(),
            channel_offset,
            quantizationParams());
      }
    }
  }
}
BENCHMARK_REGISTER_F(Q8GEMMSparse_Op, 4x8c8x1_prepacked__aarch32_neon)
    ->Apply(SparseGEMMBenchGemmArguments);
#endif

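// AArch64 benchmarks, mirroring the AArch32 set with 8x8 microkernels.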
#if CPUINFO_ARCH_ARM64
BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMM_Op, 8x8__aarch64_neon, 8, 8, 8, 1)
(benchmark::State& state) {
  for (auto _ : state) {
    for (uint32_t m = 0; m < mc(); m += mr()) {
      const uint32_t mrr = std::min(mc() - m, mr());
      for (uint32_t n = 0, channel_offset = 0; n < nc();
          n += nr(), channel_offset += nr()) {
        const uint32_t nrr = std::min(nc() - n, nr());
        pytorch_q8gemm_ukernel_8x8__aarch64_neon(
            mrr,
            nrr,
            kc(),
            a() + m * kc(),
            kc() * sizeof(uint8_t),
            w() + n * (kcStride() * sizeof(uint8_t) + sizeof(int32_t)),
            c() + m * nc() + n,
            nc() * sizeof(uint8_t),
            channel_offset,
            quantizationParams());
      }
    }
  }
}

BENCHMARK_REGISTER_F(Q8GEMM_Op, 8x8__aarch64_neon)
    ->Apply(SparseGEMMBenchGemmArguments);

BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMMSparse_Op, 8x8c1x4_prepacked__aarch64_neon, 8, 8, 4, 1, 4)
(benchmark::State& state) {
  for (auto _ : state) {
    auto m_blocks = (mc() + mr() - 1) / mr();
    auto k_blocks = (kc() + 4 - 1) / 4;
    std::vector<uint8_t> a_packed(m_blocks * k_blocks * mr() * 4 + 8);
    // Pack the activation matrix A into the layout expected by the sparse
    // GEMM microkernel.
    for (uint32_t m = 0; m < mc(); m += mr()) {
      const uint32_t mrr = std::min(mc() - m, mr());
      for (uint32_t n = 0, channel_offset = 0; n < nc();
          n += nr(), channel_offset += nr()) {
        const uint32_t nrr = std::min(nc() - n, nr());
        pytorch_q8gemm_sparse_packA_ukernel_8x4__aarch64_neon(
            mrr,
            kc(),
            a() + m * kc(),
            kc() * sizeof(uint8_t),
            a_packed.data() + (m >> 3) * (k_blocks << 2) * mr());
      }
    }
    for (uint32_t m = 0; m < mc(); m += mr()) {
      const uint32_t mrr = std::min(mc() - m, mr());
      for (uint32_t n = 0, channel_offset = 0; n < nc();
          n += nr(), channel_offset += nr()) {
        const uint32_t nrr = std::min(nc() - n, nr());
        pytorch_q8gemm_dq_sparse_1x4_ukernel_8x8_packedA_w32__aarch64_neon(
            mrr,
            nrr,
            a_packed.data() + (m >> 3) * (k_blocks << 2) * mr(),
            bcsr_matrix_->values.data(),
            static_cast<const uint32_t*>(bcsr_matrix_->row_values_data_ptr()),
            static_cast<const uint32_t*>(bcsr_matrix_->col_indices_data_ptr()),
            b() + n,
            c() + m * nc() + n,
            nc(),
            channel_offset,
            quantizationParams());
      }
    }
  }
}
BENCHMARK_REGISTER_F(Q8GEMMSparse_Op, 8x8c1x4_prepacked__aarch64_neon)
    ->Apply(SparseGEMMBenchGemmArguments);

BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMMSparse_Op, 8x8c8x1_prepacked__aarch64_neon, 8, 8, 4, 8, 1)
(benchmark::State& state) {
  for (auto _ : state) {
    auto m_blocks = (mc() + mr() - 1) / mr();
    // Still use a kr of 4 because the 8x4 packing kernel groups K in blocks
    // of 4.
    auto k_blocks = (kc() + 4 - 1) / 4;
    std::vector<uint8_t> a_packed(m_blocks * k_blocks * mr() * 4 + 8);
    for (uint32_t m = 0; m < mc(); m += mr()) {
      const uint32_t mrr = std::min(mc() - m, mr());
      for (uint32_t n = 0, channel_offset = 0; n < nc();
          n += nr(), channel_offset += nr()) {
        const uint32_t nrr = std::min(nc() - n, nr());
        pytorch_q8gemm_sparse_packA_ukernel_8x4__aarch64_neon(
            mrr,
            kc(),
            a() + m * kc(),
            kc() * sizeof(uint8_t),
            a_packed.data() + (m >> 3) * (k_blocks << 2) * mr());
      }
    }
    for (uint32_t m = 0; m < mc(); m += mr()) {
      const uint32_t mrr = std::min(mc() - m, mr());
      for (uint32_t n = 0, channel_offset = 0; n < nc();
          n += nr(), channel_offset += nr()) {
        const uint32_t nrr = std::min(nc() - n, nr());
        pytorch_q8gemm_dq_sparse_8x1_ukernel_8x8_packedA_w32__aarch64_neon(
            mrr,
            nrr,
            a_packed.data() + (m >> 3) * (k_blocks << 2) * mr(),
            bcsr_matrix_->values.data(),
            static_cast<const uint32_t*>(bcsr_matrix_->row_values_data_ptr()),
            static_cast<const uint32_t*>(bcsr_matrix_->col_indices_data_ptr()),
            b() + n,
            c() + m * nc() + n,
            nc(),
            channel_offset,
            quantizationParams());
      }
    }
  }
}
BENCHMARK_REGISTER_F(Q8GEMMSparse_Op, 8x8c8x1_prepacked__aarch64_neon)
    ->Apply(SparseGEMMBenchGemmArguments);

#endif

#ifndef PYTORCH_QNNPACK_BENCHMARK_NO_MAIN
BENCHMARK_MAIN();
#endif