• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8 
9 #pragma once
10 
#include <cstddef>
#include <cstdint>

#include <xnnpack/math.h>
#include <xnnpack/params-init.h>
#include <xnnpack/params.h>
#include <xnnpack/requantization.h>
17 
18 
19 class GemmMicrokernelTester {
20  public:
mr(size_t mr)21   inline GemmMicrokernelTester& mr(size_t mr) {
22     this->mr_ = mr;
23     return *this;
24   }
25 
mr()26   inline size_t mr() const {
27     return this->mr_;
28   }
29 
nr(size_t nr)30   inline GemmMicrokernelTester& nr(size_t nr) {
31     this->nr_ = nr;
32     return *this;
33   }
34 
nr()35   inline size_t nr() const {
36     return this->nr_;
37   }
38 
39 
kr(size_t kr)40   inline GemmMicrokernelTester& kr(size_t kr) {
41     this->kr_ = kr;
42     return *this;
43   }
44 
kr()45   inline size_t kr() const {
46     return this->kr_;
47   }
48 
sr(size_t sr)49   inline GemmMicrokernelTester& sr(size_t sr) {
50     this->sr_ = sr;
51     return *this;
52   }
53 
sr()54   inline size_t sr() const {
55     return this->sr_;
56   }
57 
m(size_t m)58   inline GemmMicrokernelTester& m(size_t m) {
59     this->m_ = m;
60     return *this;
61   }
62 
m()63   inline size_t m() const {
64     return this->m_;
65   }
66 
n(size_t n)67   inline GemmMicrokernelTester& n(size_t n) {
68     this->n_ = n;
69     return *this;
70   }
71 
n()72   inline size_t n() const {
73     return this->n_;
74   }
75 
k(size_t k)76   inline GemmMicrokernelTester& k(size_t k) {
77     this->k_ = k;
78     return *this;
79   }
80 
k()81   inline size_t k() const {
82     return this->k_;
83   }
84 
ks(size_t ks)85   inline GemmMicrokernelTester& ks(size_t ks) {
86     this->ks_ = ks;
87     return *this;
88   }
89 
ks()90   inline size_t ks() const {
91     return this->ks_;
92   }
93 
packed_k()94   inline size_t packed_k() const {
95     return round_up_po2(k(), kr() * sr());
96   }
97 
packed_n()98   inline size_t packed_n() const {
99     return round_up(n(), nr());
100   }
101 
a_stride(size_t a_stride)102   inline GemmMicrokernelTester& a_stride(size_t a_stride) {
103     this->a_stride_ = a_stride;
104     return *this;
105   }
106 
a_stride()107   inline size_t a_stride() const {
108     return this->a_stride_ == 0 ? k() : this->a_stride_;
109   }
110 
cm_stride(size_t cm_stride)111   inline GemmMicrokernelTester& cm_stride(size_t cm_stride) {
112     this->cm_stride_ = cm_stride;
113     return *this;
114   }
115 
cm_stride()116   inline size_t cm_stride() const {
117     return this->cm_stride_ == 0 ? cn_stride() * ((n() - 1) / nr()) + (n() - 1) % nr() + 1 : this->cm_stride_;
118   }
119 
cn_stride(size_t cn_stride)120   inline GemmMicrokernelTester& cn_stride(size_t cn_stride) {
121     this->cn_stride_ = cn_stride;
122     return *this;
123   }
124 
cn_stride()125   inline size_t cn_stride() const {
126     return this->cn_stride_ == 0 ? nr() : this->cn_stride_;
127   }
128 
a_zero_point(uint8_t a_zero_point)129   inline GemmMicrokernelTester& a_zero_point(uint8_t a_zero_point) {
130     this->a_zero_point_ = a_zero_point;
131     return *this;
132   }
133 
a_zero_point()134   inline uint8_t a_zero_point() const {
135     return this->a_zero_point_;
136   }
137 
b_zero_point(uint8_t b_zero_point)138   inline GemmMicrokernelTester& b_zero_point(uint8_t b_zero_point) {
139     this->b_zero_point_ = b_zero_point;
140     return *this;
141   }
142 
b_zero_point()143   inline uint8_t b_zero_point() const {
144     return this->b_zero_point_;
145   }
146 
qmin(uint8_t qmin)147   inline GemmMicrokernelTester& qmin(uint8_t qmin) {
148     this->qmin_ = qmin;
149     return *this;
150   }
151 
qmin()152   inline uint8_t qmin() const {
153     return this->qmin_;
154   }
155 
qmax(uint8_t qmax)156   inline GemmMicrokernelTester& qmax(uint8_t qmax) {
157     this->qmax_ = qmax;
158     return *this;
159   }
160 
qmax()161   inline uint8_t qmax() const {
162     return this->qmax_;
163   }
164 
a_offset(size_t a_offset)165   inline GemmMicrokernelTester& a_offset(size_t a_offset) {
166     this->a_offset_ = a_offset;
167     return *this;
168   }
169 
a_offset()170   inline size_t a_offset() const {
171     return this->a_offset_;
172   }
173 
zero_index(size_t zero_index)174   inline GemmMicrokernelTester& zero_index(size_t zero_index) {
175     this->zero_index_ = zero_index;
176     return *this;
177   }
178 
zero_index()179   inline size_t zero_index() const {
180     return this->zero_index_;
181   }
182 
extended_weights(bool extended_weights)183   inline GemmMicrokernelTester& extended_weights(bool extended_weights) {
184     this->extended_weights_ = extended_weights;
185     return *this;
186   }
187 
extended_weights()188   inline bool extended_weights() const {
189     return this->extended_weights_;
190   }
191 
iterations(size_t iterations)192   inline GemmMicrokernelTester& iterations(size_t iterations) {
193     this->iterations_ = iterations;
194     return *this;
195   }
196 
iterations()197   inline size_t iterations() const {
198     return this->iterations_;
199   }
200 
201   void Test(
202     xnn_qu8_gemm_minmax_ukernel_function gemm,
203     xnn_init_qu8_conv_minmax_params_fn init_params,
204     xnn_qu8_requantize_fn requantize) const;
205 
206   void Test(
207     xnn_qu8_igemm_minmax_ukernel_function igemm,
208     xnn_init_qu8_conv_minmax_params_fn init_params,
209     xnn_qu8_requantize_fn requantize);
210 
211   void Test(
212     xnn_qc8_gemm_minmax_ukernel_function gemm,
213     xnn_init_qs8_minmax_params_fn init_params,
214     xnn_qs8_requantize_fn requantize) const;
215 
216   void Test(
217     xnn_qc8_igemm_minmax_ukernel_function igemm,
218     xnn_init_qs8_minmax_params_fn init_params,
219     xnn_qs8_requantize_fn requantize) const;
220 
221   void Test(
222     xnn_qs8_gemm_minmax_ukernel_function gemm,
223     xnn_init_qs8_conv_minmax_params_fn init_params,
224     xnn_qs8_requantize_fn requantize) const;
225 
226   void Test(
227     xnn_qs8_igemm_minmax_ukernel_function igemm,
228     xnn_init_qs8_conv_minmax_params_fn init_params,
229     xnn_qs8_requantize_fn requantize) const;
230 
231   void Test(xnn_f16_gemm_minmax_ukernel_function gemm_minmax, xnn_init_f16_scaleminmax_params_fn init_params) const;
232 
233   void Test(xnn_f16_igemm_minmax_ukernel_function igemm_minmax, xnn_init_f16_scaleminmax_params_fn init_params) const;
234 
235   void Test(xnn_f32_ppmm_minmax_ukernel_function ppmm_minmax, xnn_init_f32_minmax_params_fn init_params) const;
236 
237   void Test(xnn_f32_gemm_ukernel_function gemm) const;
238 
239   void Test(xnn_f32_gemm_relu_ukernel_function gemm_relu) const;
240 
241   void Test(xnn_f32_gemm_minmax_ukernel_function gemm_minmax, xnn_init_f32_minmax_params_fn init_params) const;
242 
243   void Test(xnn_f32_gemminc_minmax_ukernel_function gemminc, xnn_init_f32_minmax_params_fn init_params) const;
244 
245   void Test(xnn_f32_igemm_ukernel_function igemm) const;
246 
247   void Test(xnn_f32_igemm_relu_ukernel_function igemm_relu) const;
248 
249   void Test(xnn_f32_igemm_minmax_ukernel_function igemm_minmax, xnn_init_f32_minmax_params_fn init_params) const;
250 
251 #if XNN_PLATFORM_JIT
252   void Test(
253     xnn_jit_gemm_code_generator_function gemm_generator,
254     xnn_init_f32_minmax_params_fn init_params) const;
255   void Test(
256     xnn_jit_igemm_code_generator_function igemm_generator,
257     xnn_init_f32_minmax_params_fn init_params) const;
258   void Test(
259     xnn_jit_gemm_code_generator_function gemm_generator,
260     xnn_init_qs8_minmax_params_fn init_params,
261     xnn_qs8_requantize_fn requantize) const;
262   void Test(
263     xnn_jit_igemm_code_generator_function igemm_generator,
264     xnn_init_qs8_minmax_params_fn init_params,
265     xnn_qs8_requantize_fn requantize) const;
266   void Test(
267     xnn_jit_gemm_code_generator_function gemm_generator,
268     xnn_init_qs8_conv_minmax_params_fn init_params,
269     xnn_qs8_requantize_fn requantize) const;
270   void Test(
271     xnn_jit_igemm_code_generator_function igemm_generator,
272     xnn_init_qs8_conv_minmax_params_fn init_params,
273     xnn_qs8_requantize_fn requantize) const;
274 #endif  // XNN_PLATFORM_JIT
275 
276  private:
277   size_t mr_{1};
278   size_t nr_{1};
279   size_t kr_{1};
280   size_t sr_{1};
281   size_t m_{1};
282   size_t n_{1};
283   size_t k_{1};
284   size_t ks_{1};
285   size_t a_stride_{0};
286   size_t cm_stride_{0};
287   size_t cn_stride_{0};
288   uint8_t a_zero_point_{127};
289   uint8_t b_zero_point_{127};
290   uint8_t qmin_{0};
291   uint8_t qmax_{255};
292   size_t a_offset_{0};
293   size_t zero_index_{SIZE_MAX};
294   bool extended_weights_{false};
295   size_t iterations_{15};
296 };
297