1 // Copyright (c) Facebook, Inc. and its affiliates. 2 // All rights reserved. 3 // 4 // Copyright 2019 Google LLC 5 // 6 // This source code is licensed under the BSD-style license found in the 7 // LICENSE file in the root directory of this source tree. 8 9 #pragma once 10 11 #include <cstddef> 12 13 #include <xnnpack/math.h> 14 #include <xnnpack/params-init.h> 15 #include <xnnpack/params.h> 16 #include <xnnpack/requantization.h> 17 18 19 class GemmMicrokernelTester { 20 public: mr(size_t mr)21 inline GemmMicrokernelTester& mr(size_t mr) { 22 this->mr_ = mr; 23 return *this; 24 } 25 mr()26 inline size_t mr() const { 27 return this->mr_; 28 } 29 nr(size_t nr)30 inline GemmMicrokernelTester& nr(size_t nr) { 31 this->nr_ = nr; 32 return *this; 33 } 34 nr()35 inline size_t nr() const { 36 return this->nr_; 37 } 38 39 kr(size_t kr)40 inline GemmMicrokernelTester& kr(size_t kr) { 41 this->kr_ = kr; 42 return *this; 43 } 44 kr()45 inline size_t kr() const { 46 return this->kr_; 47 } 48 sr(size_t sr)49 inline GemmMicrokernelTester& sr(size_t sr) { 50 this->sr_ = sr; 51 return *this; 52 } 53 sr()54 inline size_t sr() const { 55 return this->sr_; 56 } 57 m(size_t m)58 inline GemmMicrokernelTester& m(size_t m) { 59 this->m_ = m; 60 return *this; 61 } 62 m()63 inline size_t m() const { 64 return this->m_; 65 } 66 n(size_t n)67 inline GemmMicrokernelTester& n(size_t n) { 68 this->n_ = n; 69 return *this; 70 } 71 n()72 inline size_t n() const { 73 return this->n_; 74 } 75 k(size_t k)76 inline GemmMicrokernelTester& k(size_t k) { 77 this->k_ = k; 78 return *this; 79 } 80 k()81 inline size_t k() const { 82 return this->k_; 83 } 84 ks(size_t ks)85 inline GemmMicrokernelTester& ks(size_t ks) { 86 this->ks_ = ks; 87 return *this; 88 } 89 ks()90 inline size_t ks() const { 91 return this->ks_; 92 } 93 packed_k()94 inline size_t packed_k() const { 95 return round_up_po2(k(), kr() * sr()); 96 } 97 packed_n()98 inline size_t packed_n() const { 99 return round_up(n(), nr()); 100 } 101 a_stride(size_t a_stride)102 inline GemmMicrokernelTester& a_stride(size_t a_stride) { 103 this->a_stride_ = a_stride; 104 return *this; 105 } 106 a_stride()107 inline size_t a_stride() const { 108 return this->a_stride_ == 0 ? k() : this->a_stride_; 109 } 110 cm_stride(size_t cm_stride)111 inline GemmMicrokernelTester& cm_stride(size_t cm_stride) { 112 this->cm_stride_ = cm_stride; 113 return *this; 114 } 115 cm_stride()116 inline size_t cm_stride() const { 117 return this->cm_stride_ == 0 ? cn_stride() * ((n() - 1) / nr()) + (n() - 1) % nr() + 1 : this->cm_stride_; 118 } 119 cn_stride(size_t cn_stride)120 inline GemmMicrokernelTester& cn_stride(size_t cn_stride) { 121 this->cn_stride_ = cn_stride; 122 return *this; 123 } 124 cn_stride()125 inline size_t cn_stride() const { 126 return this->cn_stride_ == 0 ? nr() : this->cn_stride_; 127 } 128 a_zero_point(uint8_t a_zero_point)129 inline GemmMicrokernelTester& a_zero_point(uint8_t a_zero_point) { 130 this->a_zero_point_ = a_zero_point; 131 return *this; 132 } 133 a_zero_point()134 inline uint8_t a_zero_point() const { 135 return this->a_zero_point_; 136 } 137 b_zero_point(uint8_t b_zero_point)138 inline GemmMicrokernelTester& b_zero_point(uint8_t b_zero_point) { 139 this->b_zero_point_ = b_zero_point; 140 return *this; 141 } 142 b_zero_point()143 inline uint8_t b_zero_point() const { 144 return this->b_zero_point_; 145 } 146 qmin(uint8_t qmin)147 inline GemmMicrokernelTester& qmin(uint8_t qmin) { 148 this->qmin_ = qmin; 149 return *this; 150 } 151 qmin()152 inline uint8_t qmin() const { 153 return this->qmin_; 154 } 155 qmax(uint8_t qmax)156 inline GemmMicrokernelTester& qmax(uint8_t qmax) { 157 this->qmax_ = qmax; 158 return *this; 159 } 160 qmax()161 inline uint8_t qmax() const { 162 return this->qmax_; 163 } 164 a_offset(size_t a_offset)165 inline GemmMicrokernelTester& a_offset(size_t a_offset) { 166 this->a_offset_ = a_offset; 167 return *this; 168 } 169 a_offset()170 inline size_t a_offset() const { 171 return this->a_offset_; 172 } 173 zero_index(size_t zero_index)174 inline GemmMicrokernelTester& zero_index(size_t zero_index) { 175 this->zero_index_ = zero_index; 176 return *this; 177 } 178 zero_index()179 inline size_t zero_index() const { 180 return this->zero_index_; 181 } 182 extended_weights(bool extended_weights)183 inline GemmMicrokernelTester& extended_weights(bool extended_weights) { 184 this->extended_weights_ = extended_weights; 185 return *this; 186 } 187 extended_weights()188 inline bool extended_weights() const { 189 return this->extended_weights_; 190 } 191 iterations(size_t iterations)192 inline GemmMicrokernelTester& iterations(size_t iterations) { 193 this->iterations_ = iterations; 194 return *this; 195 } 196 iterations()197 inline size_t iterations() const { 198 return this->iterations_; 199 } 200 201 void Test( 202 xnn_qu8_gemm_minmax_ukernel_function gemm, 203 xnn_init_qu8_conv_minmax_params_fn init_params, 204 xnn_qu8_requantize_fn requantize) const; 205 206 void Test( 207 xnn_qu8_igemm_minmax_ukernel_function igemm, 208 xnn_init_qu8_conv_minmax_params_fn init_params, 209 xnn_qu8_requantize_fn requantize); 210 211 void Test( 212 xnn_qc8_gemm_minmax_ukernel_function gemm, 213 xnn_init_qs8_minmax_params_fn init_params, 214 xnn_qs8_requantize_fn requantize) const; 215 216 void Test( 217 xnn_qc8_igemm_minmax_ukernel_function igemm, 218 xnn_init_qs8_minmax_params_fn init_params, 219 xnn_qs8_requantize_fn requantize) const; 220 221 void Test( 222 xnn_qs8_gemm_minmax_ukernel_function gemm, 223 xnn_init_qs8_conv_minmax_params_fn init_params, 224 xnn_qs8_requantize_fn requantize) const; 225 226 void Test( 227 xnn_qs8_igemm_minmax_ukernel_function igemm, 228 xnn_init_qs8_conv_minmax_params_fn init_params, 229 xnn_qs8_requantize_fn requantize) const; 230 231 void Test(xnn_f16_gemm_minmax_ukernel_function gemm_minmax, xnn_init_f16_scaleminmax_params_fn init_params) const; 232 233 void Test(xnn_f16_igemm_minmax_ukernel_function igemm_minmax, xnn_init_f16_scaleminmax_params_fn init_params) const; 234 235 void Test(xnn_f32_ppmm_minmax_ukernel_function ppmm_minmax, xnn_init_f32_minmax_params_fn init_params) const; 236 237 void Test(xnn_f32_gemm_ukernel_function gemm) const; 238 239 void Test(xnn_f32_gemm_relu_ukernel_function gemm_relu) const; 240 241 void Test(xnn_f32_gemm_minmax_ukernel_function gemm_minmax, xnn_init_f32_minmax_params_fn init_params) const; 242 243 void Test(xnn_f32_gemminc_minmax_ukernel_function gemminc, xnn_init_f32_minmax_params_fn init_params) const; 244 245 void Test(xnn_f32_igemm_ukernel_function igemm) const; 246 247 void Test(xnn_f32_igemm_relu_ukernel_function igemm_relu) const; 248 249 void Test(xnn_f32_igemm_minmax_ukernel_function igemm_minmax, xnn_init_f32_minmax_params_fn init_params) const; 250 251 #if XNN_PLATFORM_JIT 252 void Test( 253 xnn_jit_gemm_code_generator_function gemm_generator, 254 xnn_init_f32_minmax_params_fn init_params) const; 255 void Test( 256 xnn_jit_igemm_code_generator_function igemm_generator, 257 xnn_init_f32_minmax_params_fn init_params) const; 258 void Test( 259 xnn_jit_gemm_code_generator_function gemm_generator, 260 xnn_init_qs8_minmax_params_fn init_params, 261 xnn_qs8_requantize_fn requantize) const; 262 void Test( 263 xnn_jit_igemm_code_generator_function igemm_generator, 264 xnn_init_qs8_minmax_params_fn init_params, 265 xnn_qs8_requantize_fn requantize) const; 266 void Test( 267 xnn_jit_gemm_code_generator_function gemm_generator, 268 xnn_init_qs8_conv_minmax_params_fn init_params, 269 xnn_qs8_requantize_fn requantize) const; 270 void Test( 271 xnn_jit_igemm_code_generator_function igemm_generator, 272 xnn_init_qs8_conv_minmax_params_fn init_params, 273 xnn_qs8_requantize_fn requantize) const; 274 #endif // XNN_PLATFORM_JIT 275 276 private: 277 size_t mr_{1}; 278 size_t nr_{1}; 279 size_t kr_{1}; 280 size_t sr_{1}; 281 size_t m_{1}; 282 size_t n_{1}; 283 size_t k_{1}; 284 size_t ks_{1}; 285 size_t a_stride_{0}; 286 size_t cm_stride_{0}; 287 size_t cn_stride_{0}; 288 uint8_t a_zero_point_{127}; 289 uint8_t b_zero_point_{127}; 290 uint8_t qmin_{0}; 291 uint8_t qmax_{255}; 292 size_t a_offset_{0}; 293 size_t zero_index_{SIZE_MAX}; 294 bool extended_weights_{false}; 295 size_t iterations_{15}; 296 }; 297