// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <random>
#include <vector>

#include <fp16.h>

#include <xnnpack.h>
#include <xnnpack/AlignedAllocator.h>
#include <xnnpack/pack.h>
#include <xnnpack/params-init.h>
#include <xnnpack/params.h>
#include <xnnpack/requantization.h>


class GemmMicrokernelTester {
 public:
  enum class Variant {
    Native,
    Scalar,
  };

  inline GemmMicrokernelTester& mr(size_t mr) {
    this->mr_ = mr;
    return *this;
  }

  inline size_t mr() const {
    return this->mr_;
  }

  inline GemmMicrokernelTester& nr(size_t nr) {
    this->nr_ = nr;
    return *this;
  }

  inline size_t nr() const {
    return this->nr_;
  }

  inline GemmMicrokernelTester& kr(size_t kr) {
    this->kr_ = kr;
    return *this;
  }

  inline size_t kr() const {
    return this->kr_;
  }

  inline GemmMicrokernelTester& sr(size_t sr) {
    this->sr_ = sr;
    return *this;
  }

  inline size_t sr() const {
    return this->sr_;
  }

  inline GemmMicrokernelTester& m(size_t m) {
    this->m_ = m;
    return *this;
  }

  inline size_t m() const {
    return this->m_;
  }

  inline GemmMicrokernelTester& n(size_t n) {
    this->n_ = n;
    return *this;
  }

  inline size_t n() const {
    return this->n_;
  }

  inline GemmMicrokernelTester& k(size_t k) {
    this->k_ = k;
    return *this;
  }

  inline size_t k() const {
    return this->k_;
  }

  inline GemmMicrokernelTester& ks(size_t ks) {
    this->ks_ = ks;
    return *this;
  }

  inline size_t ks() const {
    return this->ks_;
  }

  inline size_t packed_k() const {
    return k() % kr() == 0 ? k() : (k() / kr() + 1) * kr();
  }

  inline size_t packed_n() const {
    return n() % nr() == 0 ? n() : (n() / nr() + 1) * nr();
  }

  inline size_t bias_n() const {
    return n() % nr() == 0 ? n() : (n() / nr() + 1) * nr();
  }
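
  // packed_k() and packed_n() round K and N up to the micro-kernel tile
  // boundary (kr() and nr() respectively), e.g. k() == 7 with kr() == 4 packs
  // as packed_k() == 8. bias_n() pads the bias the same way; the tests
  // pre-fill packed_w so the padded lanes are well-defined.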

  inline GemmMicrokernelTester& a_stride(size_t a_stride) {
    this->a_stride_ = a_stride;
    return *this;
  }

  inline size_t a_stride() const {
    return this->a_stride_ == 0 ? k() : this->a_stride_;
  }

  inline GemmMicrokernelTester& cm_stride(size_t cm_stride) {
    this->cm_stride_ = cm_stride;
    return *this;
  }

  inline size_t cm_stride() const {
    return this->cm_stride_ == 0 ? cn_stride() * ((n() - 1) / nr()) + (n() - 1) % nr() + 1 : this->cm_stride_;
  }

  inline GemmMicrokernelTester& cn_stride(size_t cn_stride) {
    this->cn_stride_ = cn_stride;
    return *this;
  }

  inline size_t cn_stride() const {
    return this->cn_stride_ == 0 ? nr() : this->cn_stride_;
  }
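
  // Output validation throughout this file addresses C with
  //   c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]:
  // rows are cm_stride() elements apart, and columns come in tiles of nr()
  // whose starts are cn_stride() elements apart.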

  inline GemmMicrokernelTester& a_zero_point(uint8_t a_zero_point) {
    this->a_zero_point_ = a_zero_point;
    return *this;
  }

  inline uint8_t a_zero_point() const {
    return this->a_zero_point_;
  }

  inline GemmMicrokernelTester& b_zero_point(uint8_t b_zero_point) {
    this->b_zero_point_ = b_zero_point;
    return *this;
  }

  inline uint8_t b_zero_point() const {
    return this->b_zero_point_;
  }

  inline GemmMicrokernelTester& qmin(uint8_t qmin) {
    this->qmin_ = qmin;
    return *this;
  }

  inline uint8_t qmin() const {
    return this->qmin_;
  }

  inline GemmMicrokernelTester& qmax(uint8_t qmax) {
    this->qmax_ = qmax;
    return *this;
  }

  inline uint8_t qmax() const {
    return this->qmax_;
  }

  inline GemmMicrokernelTester& a_offset(size_t a_offset) {
    this->a_offset_ = a_offset;
    return *this;
  }

  inline size_t a_offset() const {
    return this->a_offset_;
  }

  inline GemmMicrokernelTester& zero_index(size_t zero_index) {
    this->zero_index_ = zero_index;
    return *this;
  }

  inline size_t zero_index() const {
    return this->zero_index_;
  }

  inline GemmMicrokernelTester& iterations(size_t iterations) {
    this->iterations_ = iterations;
    return *this;
  }

  inline size_t iterations() const {
    return this->iterations_;
  }

  void Test(xnn_q8_gemm_ukernel_function gemm, Variant variant = Variant::Native) const {
    ASSERT_LE(m(), mr());

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto s32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), rng);
    auto u8rng = std::bind(std::uniform_int_distribution<uint8_t>(), rng);

    std::vector<uint8_t> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint8_t));
    std::vector<uint8_t> b(n() * k());
    std::vector<int32_t> bias(n());
    std::vector<uint8_t, AlignedAllocator<uint8_t, 64>> packed_w(packed_n() * packed_k() + bias_n() * sizeof(uint32_t) / sizeof(uint8_t));
    std::vector<uint8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
    std::vector<int32_t> acc(m() * n());
    std::vector<uint8_t> c_ref(m() * n());

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      do {
        std::generate(a.begin(), a.end(), std::ref(u8rng));
      } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
      do {
        std::generate(b.begin(), b.end(), std::ref(u8rng));
      } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
      std::generate(bias.begin(), bias.end(), std::ref(s32rng));
      std::fill(c.begin(), c.end(), 0xA5);

      std::fill(packed_w.begin(), packed_w.end(), b_zero_point());
      xnn_pack_q8_gemm_goi_w(1, n(), k(), nr(), kr(),
        a_zero_point(), b_zero_point(),
        b.data(), bias.data(), packed_w.data());

      // Compute 32-bit results and output quantization arguments.
      std::fill(acc.begin(), acc.end(), 0);
      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          for (size_t k_index = 0; k_index < k(); k_index++) {
            acc[m_index * n() + n_index] +=
                (int32_t(a[m_index * a_stride() + k_index]) - int32_t(a_zero_point())) *
                (int32_t(b[n_index * k() + k_index]) - int32_t(b_zero_point()));
          }
          acc[m_index * n() + n_index] += bias[n_index];
        }
      }

      const int32_t accumulated_min = *std::min_element(acc.cbegin(), acc.cend());
      const int32_t accumulated_max = *std::max_element(acc.cbegin(), acc.cend());
      const double c_scale = uint32_t(accumulated_max - accumulated_min) >= 256 ? double(uint32_t(accumulated_max - accumulated_min)) / 255.0 : 1.00001;
      const uint8_t c_zero_point = uint8_t(std::max(std::min(
        lrint(127.5 - 0.5 * double(accumulated_min + accumulated_max) / c_scale),
        long(std::numeric_limits<uint8_t>::max())), long(std::numeric_limits<uint8_t>::min())));
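
      // Note on the derivation above: c_scale stretches the observed
      // accumulator range over the 8-bit output range, and the 1.00001
      // fallback keeps requantization_scale = 1/c_scale strictly below 1.0,
      // as the Q31 requantization used for the reference results expects;
      // c_zero_point then centers the scaled range within [0, 255].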

      const float requantization_scale = 1.0f / float(c_scale);
      union xnn_q8_gemm_params quantization_params = { };
      switch (variant) {
        case Variant::Native:
          quantization_params = xnn_init_q8_gemm_params(
            a_zero_point(), b_zero_point(),
            requantization_scale, c_zero_point, qmin(), qmax());
          break;
        case Variant::Scalar:
          quantization_params = xnn_init_scalar_q8_gemm_params(
            a_zero_point(), b_zero_point(),
            requantization_scale, c_zero_point, qmin(), qmax());
          break;
      }
      const union xnn_q31_requantization_params scalar_requantization_params =
        xnn_init_scalar_requantization_params(
          requantization_scale, c_zero_point, qmin(), qmax());

      gemm(
        m(), n(), k(),
        a.data(), a_stride() * sizeof(uint8_t),
        packed_w.data(),
        c.data(), cm_stride() * sizeof(uint8_t), cn_stride() * sizeof(uint8_t),
        &quantization_params);

      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          c_ref[m_index * n() + n_index] = xnn_q31_requantize(acc[m_index * n() + n_index], scalar_requantization_params);
        }
      }

      for (size_t i = 0; i < m(); i++) {
        for (size_t j = 0; j < n(); j++) {
          ASSERT_LE(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(qmax()));
          ASSERT_GE(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(qmin()));
          ASSERT_EQ(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(c_ref[i * n() + j]))
              << "at " << i << ", " << j << ": reference = " << (uint32_t) c_ref[i * n() + j]
              << " (accumulator = " << acc[i * n() + j]
              << "), optimized = " << (uint32_t) c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x "
              << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
              << ", requantization scale = " << requantization_scale << ", output zero point = " << int32_t(c_zero_point);
        }
      }
    }
  }

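  // The IGEMM ("indirect GEMM") tests below feed the micro-kernel an
  // indirection buffer (im2col) of row pointers instead of a dense A matrix,
  // mirroring how XNNPACK implements convolution: ks() groups of mr() row
  // pointers each, with a shared zero pointer standing in for padded pixels.
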
  void Test(xnn_q8_igemm_ukernel_function igemm, Variant variant = Variant::Native) const {
    ASSERT_LE(m(), mr());

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto s32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), rng);
    auto u8rng = std::bind(std::uniform_int_distribution<uint8_t>(), rng);

    std::vector<uint8_t> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint8_t));
    std::vector<uint8_t> b(n() * ks() * k());
    std::vector<uint8_t, AlignedAllocator<uint8_t, 64>> packed_w(ks() * packed_n() * packed_k() + bias_n() * sizeof(uint32_t) / sizeof(uint8_t));
    std::vector<int32_t> bias(n());
    std::vector<uint8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
    std::vector<int32_t> acc(m() * n());
    std::vector<uint8_t> c_ref(m() * n());
    std::vector<uint8_t> junk(k() + 8);
    std::vector<const uint8_t*> im2col(mr() * ks());

    std::fill(junk.begin(), junk.end(), 0xA5);

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      do {
        std::generate(a.begin(), a.end(), std::ref(u8rng));
      } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
      do {
        std::generate(b.begin(), b.end(), std::ref(u8rng));
      } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
      std::generate(bias.begin(), bias.end(), std::ref(s32rng));
      std::fill(c.begin(), c.end(), 0xA5);

      std::fill(packed_w.begin(), packed_w.end(), b_zero_point());
      xnn_pack_q8_conv_goki_w(
        1, n(), ks(), k(), nr(), kr(),
        a_zero_point(), b_zero_point(),
        b.data(), bias.data(), packed_w.data());

      for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
        for (size_t m_index = 0; m_index < mr(); m_index++) {
          im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
        }
      }
      std::shuffle(im2col.begin(), im2col.end(), rng);
      if (zero_index() != SIZE_MAX) {
        for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
          im2col[ks_index * mr() + zero_index()] = a.data();
        }
      }
      for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
        for (size_t m_index = m(); m_index < mr(); m_index++) {
          im2col[ks_index * mr() + m_index] = junk.data();
        }
      }
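
      // Indirection-buffer setup, as exercised here: entries equal to a.data()
      // play the role of the zero pointer and are read without a_offset();
      // all other entries are real row pointers pre-biased by -a_offset() so
      // that the kernel's a_offset addition lands back on valid data; rows
      // m()..mr()-1 point at junk whose values must not affect the output.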

      // Compute 32-bit results and output quantization arguments.
      std::fill(acc.begin(), acc.end(), 0);
      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
            for (size_t k_block_start = 0; k_block_start < k(); k_block_start += kr()) {
              for (size_t k_block_offset = 0; k_block_offset < std::min(k() - k_block_start, kr()); k_block_offset++) {
                if (im2col[ks_index * mr() + m_index] == a.data()) {
                  acc[m_index * n() + n_index] +=
                    (int32_t(im2col[ks_index * mr() + m_index][k_block_start + k_block_offset]) - int32_t(a_zero_point())) *
                    (int32_t(b[(n_index * ks() + ks_index) * k() + k_block_start + k_block_offset]) - int32_t(b_zero_point()));
                } else {
                  acc[m_index * n() + n_index] +=
                    (int32_t(im2col[ks_index * mr() + m_index][k_block_start + k_block_offset + a_offset()]) - int32_t(a_zero_point())) *
                    (int32_t(b[(n_index * ks() + ks_index) * k() + k_block_start + k_block_offset]) - int32_t(b_zero_point()));
                }
              }
            }
          }
          acc[m_index * n() + n_index] += bias[n_index];
        }
      }

      const int32_t accumulated_min = *std::min_element(acc.cbegin(), acc.cend());
      const int32_t accumulated_max = *std::max_element(acc.cbegin(), acc.cend());
      const double c_scale = uint32_t(accumulated_max - accumulated_min) >= 256 ? double(uint32_t(accumulated_max - accumulated_min)) / 255.0 : 1.00001;
      const uint8_t c_zero_point = uint8_t(std::max(std::min(
        lrint(127.5 - 0.5 * double(accumulated_min + accumulated_max) / c_scale),
        long(std::numeric_limits<uint8_t>::max())), long(std::numeric_limits<uint8_t>::min())));

      const float requantization_scale = 1.0f / float(c_scale);
      union xnn_q8_gemm_params quantization_params = { };
      switch (variant) {
        case Variant::Native:
          quantization_params = xnn_init_q8_gemm_params(
            a_zero_point(), b_zero_point(),
            requantization_scale, c_zero_point, qmin(), qmax());
          break;
        case Variant::Scalar:
          quantization_params = xnn_init_scalar_q8_gemm_params(
            a_zero_point(), b_zero_point(),
            requantization_scale, c_zero_point, qmin(), qmax());
          break;
      }
      const union xnn_q31_requantization_params scalar_requantization_params =
        xnn_init_scalar_requantization_params(
          requantization_scale, c_zero_point, qmin(), qmax());

      const uint8_t* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : NULL;

      igemm(
        m(), n(), k(), ks() * mr() * sizeof(void*),
        im2col.data(), packed_w.data(),
        c.data(), cm_stride() * sizeof(uint8_t), cn_stride() * sizeof(uint8_t),
        a_offset() * sizeof(uint8_t), zero_pointer,
        &quantization_params);

      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          c_ref[m_index * n() + n_index] = xnn_q31_requantize(acc[m_index * n() + n_index], scalar_requantization_params);
        }
      }

      for (size_t i = 0; i < m(); i++) {
        for (size_t j = 0; j < n(); j++) {
          ASSERT_LE(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(qmax()));
          ASSERT_GE(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(qmin()));
          ASSERT_EQ(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(c_ref[i * n() + j]))
              << "at " << i << ", " << j << ": reference = " << uint32_t(c_ref[i * n() + j])
              << " (accumulator = " << acc[i * n() + j]
              << "), optimized = " << (uint32_t) c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x "
              << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
              << ", requantization scale = " << requantization_scale << ", output zero point = " << int32_t(c_zero_point);
        }
      }
    }
  }

  void Test(xnn_f16_gemm_ukernel_function gemm, Variant variant = Variant::Native) const {
    ASSERT_LE(m(), mr());
    ASSERT_GE(a_stride(), k());
    ASSERT_GE(cm_stride(), n());

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);
    auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng);

    std::vector<uint16_t> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint16_t));
    std::vector<uint16_t> b(n() * k());
    std::vector<uint16_t, AlignedAllocator<uint16_t, 64>> packed_w(packed_n() * packed_k() + bias_n());
    std::vector<uint16_t> bias(n());
    std::vector<uint16_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
    std::vector<float> c_ref(m() * n());

    xnn_f16_output_params output_params;
    output_params.scale = UINT16_C(0x3C00) /* 1.0 */;

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(a.begin(), a.end(), std::ref(f16rng));
      std::generate(b.begin(), b.end(), std::ref(f16rng));
      std::generate(bias.begin(), bias.end(), std::ref(f16rng));
      std::fill(c.begin(), c.end(), UINT16_C(0x7E00) /* NaN */);
      std::fill(c_ref.begin(), c_ref.end(), 0.0f);

      std::fill(packed_w.begin(), packed_w.end(), 0);
      xnn_pack_f16_gemm_goi_w(1, n(), k(), nr(), kr(), b.data(), bias.data(), packed_w.data());

      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          for (size_t k_block_start = 0; k_block_start < k(); k_block_start += kr()) {
            for (size_t k_block_offset = 0; k_block_offset < std::min(k() - k_block_start, kr()); k_block_offset++) {
              ASSERT_LE(n(), packed_n());
              ASSERT_LT(m_index * n() + n_index, c_ref.size());
              ASSERT_LT(m_index * k() + k_block_start + k_block_offset, a.size());

              c_ref[m_index * n() + n_index] +=
                fp16_ieee_to_fp32_value(a[m_index * a_stride() + k_block_start + k_block_offset]) *
                fp16_ieee_to_fp32_value(b[n_index * k() + k_block_start + k_block_offset]);
            }
          }
          c_ref[m_index * n() + n_index] += fp16_ieee_to_fp32_value(bias[n_index]);
        }
      }

      const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
      const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
      const float c_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin())));
      const float c_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax())));
      output_params.max = fp16_ieee_from_fp32_value(c_max);
      output_params.min = fp16_ieee_from_fp32_value(c_min);

      for (float& c_value : c_ref) {
        c_value = std::max(std::min(c_value, c_max), c_min);
      }

      gemm(m(), n(), k() * sizeof(uint16_t),
        a.data(), a_stride() * sizeof(uint16_t),
        packed_w.data(),
        c.data(), cm_stride() * sizeof(uint16_t), cn_stride() * sizeof(uint16_t),
        &output_params);

      // Validate micro-kernel outputs.
      for (size_t i = 0; i < m(); i++) {
        for (size_t j = 0; j < n(); j++) {
          ASSERT_NEAR(
              fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]),
              c_ref[i * n() + j],
              std::abs(c_ref[i * n() + j]) * 1.0e-2f)
              << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
              << ", optimized = " << fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x " << nr()
              << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
        }
      }
    }
  }

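  // PPMM micro-kernels consume a pre-packed A panel: element (i, l) of A is
  // stored at a[l * mr() + i], so the test below packs A itself and duplicates
  // the last valid row into rows m()..mr()-1 before calling the kernel.
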
  void Test(xnn_f32_ppmm_ukernel_function ppmm, Variant variant = Variant::Native) const {
    ASSERT_LE(m(), mr());
    ASSERT_GE(cm_stride(), n());

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);

    std::vector<float> a(packed_k() * mr());
    std::vector<float> b(n() * k());
    std::vector<float> bias(n());
    std::vector<float, AlignedAllocator<float, 64>> packed_w(packed_n() * packed_k() + bias_n());
    std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
    std::vector<float> c_ref(m() * n());

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(a.begin(), a.end(), std::ref(f32rng));
      std::generate(b.begin(), b.end(), std::ref(f32rng));
      std::generate(bias.begin(), bias.end(), std::ref(f32rng));
      std::fill(c.begin(), c.end(), nanf(""));
      std::fill(c_ref.begin(), c_ref.end(), 0.0f);

      std::fill(packed_w.begin(), packed_w.end(), 0.0f);
      xnn_pack_f32_gemm_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), packed_w.data());

      for (size_t i = m(); i < mr(); i++) {
        for (size_t l = 0; l < k(); l++) {
          a[l * mr() + i] = a[l * mr() + m() - 1];
        }
      }

      for (size_t i = 0; i < m(); i++) {
        for (size_t j = 0; j < n(); j++) {
          for (size_t l = 0; l < k(); l++) {
            c_ref[i * n() + j] +=
              a[l * mr() + i] *
              b[j * k() + l];
          }
          c_ref[i * n() + j] += bias[j];
        }
      }

      const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
      const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
      const float c_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
      const float c_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());

      // Prepare output parameters.
      xnn_f32_output_params output_params = { };
      switch (variant) {
        case Variant::Native:
          output_params = xnn_init_f32_output_params(c_min, c_max);
          break;
        case Variant::Scalar:
          output_params = xnn_init_scalar_f32_output_params(c_min, c_max);
          break;
      }

      for (float& c_value : c_ref) {
        c_value = std::max(std::min(c_value, c_max), c_min);
      }

      ppmm(m(), n(), k() * sizeof(float),
        a.data(), packed_w.data(),
        c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
        &output_params);

      // Validate micro-kernel outputs.
      for (size_t i = 0; i < m(); i++) {
        for (size_t j = 0; j < n(); j++) {
          ASSERT_NEAR(
              c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
              c_ref[i * n() + j],
              std::abs(c_ref[i * n() + j]) * 1.0e-6f)
              << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
              << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
              << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
        }
      }
    }
  }

  void Test(xnn_f32_gemm_ukernel_function gemm, Variant variant = Variant::Native) const {
    ASSERT_LE(m(), mr());
    ASSERT_GE(a_stride(), k());
    ASSERT_GE(cm_stride(), n());

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);

    std::vector<float> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
    std::vector<float> b(n() * k());
    std::vector<float> bias(n());
    std::vector<float, AlignedAllocator<float, 64>> packed_w(packed_n() * packed_k() + bias_n());
    std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
    std::vector<float> c_ref(m() * n());

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(a.begin(), a.end(), std::ref(f32rng));
      std::generate(b.begin(), b.end(), std::ref(f32rng));
      std::generate(bias.begin(), bias.end(), std::ref(f32rng));
      std::fill(c.begin(), c.end(), nanf(""));
      std::fill(c_ref.begin(), c_ref.end(), 0.0f);

      std::fill(packed_w.begin(), packed_w.end(), 0.0f);
      xnn_pack_f32_gemm_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), packed_w.data());

      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          for (size_t k_index = 0; k_index < k(); k_index++) {
            ASSERT_LE(n(), packed_n());
            ASSERT_LT(m_index * n() + n_index, c_ref.size());
            c_ref[m_index * n() + n_index] +=
              a[m_index * a_stride() + k_index] *
              b[n_index * k() + k_index];
          }
          c_ref[m_index * n() + n_index] += bias[n_index];
        }
      }

      const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
      const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
      const float c_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
      const float c_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());

      // Prepare output parameters.
      xnn_f32_output_params output_params = { };
      switch (variant) {
        case Variant::Native:
          output_params = xnn_init_f32_output_params(c_min, c_max);
          break;
        case Variant::Scalar:
          output_params = xnn_init_scalar_f32_output_params(c_min, c_max);
          break;
      }

      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          c_ref[m_index * n() + n_index] = std::max(std::min(c_ref[m_index * n() + n_index], c_max), c_min);
        }
      }

      gemm(m(), n(), k() * sizeof(float),
        a.data(), a_stride() * sizeof(float),
        packed_w.data(),
        c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
        &output_params);

      // Validate micro-kernel outputs.
      for (size_t i = 0; i < m(); i++) {
        for (size_t j = 0; j < n(); j++) {
          ASSERT_LE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_max)
              << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
              << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
              << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
          ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_min)
              << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
              << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
              << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
          ASSERT_NEAR(
              c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
              c_ref[i * n() + j],
              std::abs(c_ref[i * n() + j]) * 1.0e-6f)
              << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
              << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
              << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
        }
      }
    }
  }

  void Test(xnn_f32_gemminc_ukernel_function gemminc, Variant variant = Variant::Native) const {
    ASSERT_LE(m(), mr());
    ASSERT_GE(a_stride(), k());
    ASSERT_GE(cm_stride(), n());

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);

    std::vector<float> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
    std::vector<float> b(n() * k());
    std::vector<float> bias(n());
    std::vector<float, AlignedAllocator<float, 64>> packed_w(packed_n() * packed_k());  // no bias_n()
    std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
    std::vector<float> c_ref(m() * n());
    std::vector<float, AlignedAllocator<float, 64>> acc(mr() * packed_n());

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(a.begin(), a.end(), std::ref(f32rng));
      std::generate(b.begin(), b.end(), std::ref(f32rng));
      std::fill(c.begin(), c.end(), nanf(""));
      std::fill(c_ref.begin(), c_ref.end(), 0.0f);
      std::generate(acc.begin(), acc.end(), std::ref(f32rng));

      std::fill(packed_w.begin(), packed_w.end(), 0.0f);
      xnn_pack_f32_gemminc_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), packed_w.data());

      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          for (size_t k_index = 0; k_index < k(); k_index++) {
            ASSERT_LE(n(), packed_n());
            ASSERT_LT(m_index * n() + n_index, c_ref.size());
            c_ref[m_index * n() + n_index] +=
              a[m_index * a_stride() + k_index] *
              b[n_index * k() + k_index];
          }
          c_ref[m_index * n() + n_index] += acc[n_index / nr() * nr() * mr() + m_index % mr() * nr() + n_index % nr()];
        }
      }
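
      // The acc index above reflects the layout GEMMINC expects for its
      // accumulator input: mr() x nr() tiles laid out tile-after-tile along N,
      // i.e. entry (m_index, n_index) at
      //   (n_index / nr()) * nr() * mr() + m_index * nr() + n_index % nr().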

      const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
      const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
      const float c_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
      const float c_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());

      // Prepare output parameters.
      xnn_f32_output_params output_params = { };
      switch (variant) {
        case Variant::Native:
          output_params = xnn_init_f32_output_params(c_min, c_max);
          break;
        case Variant::Scalar:
          output_params = xnn_init_scalar_f32_output_params(c_min, c_max);
          break;
      }

      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          c_ref[m_index * n() + n_index] = std::max(std::min(c_ref[m_index * n() + n_index], c_max), c_min);
        }
      }

      gemminc(m(), n(), k() * sizeof(float),
        a.data(), a_stride() * sizeof(float),
        packed_w.data(),
        c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
        acc.data(),
        &output_params);

      // Validate micro-kernel outputs.
      for (size_t i = 0; i < m(); i++) {
        for (size_t j = 0; j < n(); j++) {
          ASSERT_LE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_max)
              << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
              << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
              << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
          ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_min)
              << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
              << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
              << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
          ASSERT_NEAR(
              c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
              c_ref[i * n() + j],
              std::abs(c_ref[i * n() + j]) * 1.0e-6f)
              << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
              << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
              << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
        }
      }
    }
  }

  void Test(xnn_f32_igemm_ukernel_function igemm, Variant variant = Variant::Native) const {
    ASSERT_LE(m(), mr());

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);

    std::vector<float> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
    std::vector<float> b(n() * ks() * k());
    std::vector<float, AlignedAllocator<float, 64>> packed_w(ks() * packed_k() * packed_n() + bias_n());
    std::vector<float> bias(n());
    std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
    std::vector<float> c_ref(m() * n());
    std::vector<float> junk(k() + XNN_EXTRA_BYTES / sizeof(float));
    std::vector<const float*> im2col(mr() * ks());
    std::fill(junk.begin(), junk.end(), nanf(""));

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(a.begin(), a.end(), std::ref(f32rng));
      std::generate(b.begin(), b.end(), std::ref(f32rng));
      std::generate(bias.begin(), bias.end(), std::ref(f32rng));
      std::fill(c.begin(), c.end(), nanf(""));
      std::fill(c_ref.begin(), c_ref.end(), 0.0f);

      std::fill(packed_w.begin(), packed_w.end(), 0.0f);
      xnn_pack_f32_conv_goki_w(
        1, n(), ks(), k(), nr(), kr(), sr(),
        b.data(), bias.data(), packed_w.data());

      for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
        for (size_t m_index = 0; m_index < mr(); m_index++) {
          im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
        }
      }
      std::shuffle(im2col.begin(), im2col.end(), rng);
      if (zero_index() != SIZE_MAX) {
        for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
          im2col[ks_index * mr() + zero_index()] = a.data();
        }
      }
      for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
        for (size_t m_index = m(); m_index < mr(); m_index++) {
          im2col[ks_index * mr() + m_index] = junk.data();
        }
      }

      std::fill(c_ref.begin(), c_ref.end(), 0.0);
      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
            for (size_t k_block_start = 0; k_block_start < k(); k_block_start += kr()) {
              for (size_t k_block_offset = 0; k_block_offset < std::min(k() - k_block_start, kr()); k_block_offset++) {
                ASSERT_LT(ks_index * mr() + m_index, im2col.size());
                ASSERT_LT(k_block_start + k_block_offset, k());
                ASSERT_LT(k_block_start + k_block_offset, a_stride());
                if (im2col[ks_index * mr() + m_index] == a.data()) {
                  c_ref[m_index * n() + n_index] +=
                    double(im2col[ks_index * mr() + m_index][k_block_start + k_block_offset]) *
                    double(b[(n_index * ks() + ks_index) * k() + k_block_start + k_block_offset]);
                } else {
                  c_ref[m_index * n() + n_index] +=
                    double(im2col[ks_index * mr() + m_index][k_block_start + k_block_offset + a_offset()]) *
                    double(b[(n_index * ks() + ks_index) * k() + k_block_start + k_block_offset]);
                }
              }
            }
          }
          c_ref[m_index * n() + n_index] += bias[n_index];
        }
      }

      const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
      const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
      const float c_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
      const float c_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          c_ref[m_index * n() + n_index] = std::min(c_ref[m_index * n() + n_index], c_max);
          c_ref[m_index * n() + n_index] = std::max(c_ref[m_index * n() + n_index], c_min);
        }
      }

      // Prepare output parameters.
      xnn_f32_output_params output_params = { };
      switch (variant) {
        case Variant::Native:
          output_params = xnn_init_f32_output_params(c_min, c_max);
          break;
        case Variant::Scalar:
          output_params = xnn_init_scalar_f32_output_params(c_min, c_max);
          break;
      }

      const float* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : NULL;

      igemm(
        m(), n(), k() * sizeof(float), ks() * mr() * sizeof(void*),
        im2col.data(), packed_w.data(),
        c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
        a_offset() * sizeof(float), zero_pointer,
        &output_params);

      for (size_t i = 0; i < m(); i++) {
        for (size_t j = 0; j < n(); j++) {
          ASSERT_LE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_max)
              << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
              << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
              << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
          ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_min)
              << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
              << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
              << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
          ASSERT_NEAR(
              c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
              c_ref[i * n() + j],
              std::abs(c_ref[i * n() + j]) * 1.0e-6f)
              << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
              << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
              << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
        }
      }
    }
  }

 private:
  size_t mr_{1};
  size_t nr_{1};
  size_t kr_{1};
  size_t sr_{1};
  size_t m_{1};
  size_t n_{1};
  size_t k_{1};
  size_t ks_{1};
  size_t a_stride_{0};
  size_t cm_stride_{0};
  size_t cn_stride_{0};
  uint8_t a_zero_point_{127};
  uint8_t b_zero_point_{127};
  uint8_t qmin_{0};
  uint8_t qmax_{255};
  size_t a_offset_{0};
  size_t zero_index_{SIZE_MAX};
  size_t iterations_{15};
};
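
// Usage sketch (a minimal example; the kernel symbol below is hypothetical,
// and real tests pass whichever XNNPACK micro-kernel is under test):
//
//   TEST(F32_GEMM_4X8, k_eq_4) {
//     GemmMicrokernelTester()
//       .mr(4)
//       .nr(8)
//       .m(4)
//       .n(8)
//       .k(4)
//       .Test(xnn_f32_gemm_ukernel_4x8__scalar);  // hypothetical kernel name
//   }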