1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8 
9 #pragma once
10 
11 #include <gtest/gtest.h>
12 
13 #include <algorithm>
14 #include <cassert>
15 #include <cmath>
16 #include <cstddef>
17 #include <cstdlib>
18 #include <functional>
19 #include <limits>
20 #include <random>
21 #include <vector>
22 
23 #include <fp16.h>
24 
25 #include <xnnpack.h>
26 #include <xnnpack/AlignedAllocator.h>
27 #include <xnnpack/pack.h>
28 #include <xnnpack/params-init.h>
29 #include <xnnpack/params.h>
30 #include <xnnpack/requantization.h>
31 
32 
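// GemmMicrokernelTester drives XNNPACK GEMM/IGEMM/PPMM micro-kernels against a
// scalar reference computation. A test configures the register tile (mr/nr/kr/sr),
// the problem size (m/n/k, plus ks for IGEMM), and optional strides and
// quantization knobs through the chained setters below, then calls the Test()
// overload matching the micro-kernel signature.
//
// A minimal usage sketch; the kernel name is illustrative, not defined in this file:
//
//   GemmMicrokernelTester()
//     .mr(4).nr(4).kr(1).sr(1)
//     .m(4).n(4).k(8)
//     .Test(xnn_f32_gemm_minmax_ukernel_4x4__scalar);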
33 class GemmMicrokernelTester {
34  public:
35   enum class Variant {
36     Native,
37     Scalar,
38   };
39 
40   inline GemmMicrokernelTester& mr(size_t mr) {
41     this->mr_ = mr;
42     return *this;
43   }
44 
45   inline size_t mr() const {
46     return this->mr_;
47   }
48 
49   inline GemmMicrokernelTester& nr(size_t nr) {
50     this->nr_ = nr;
51     return *this;
52   }
53 
54   inline size_t nr() const {
55     return this->nr_;
56   }
57 
58 
59   inline GemmMicrokernelTester& kr(size_t kr) {
60     this->kr_ = kr;
61     return *this;
62   }
63 
64   inline size_t kr() const {
65     return this->kr_;
66   }
67 
68   inline GemmMicrokernelTester& sr(size_t sr) {
69     this->sr_ = sr;
70     return *this;
71   }
72 
73   inline size_t sr() const {
74     return this->sr_;
75   }
76 
77   inline GemmMicrokernelTester& m(size_t m) {
78     this->m_ = m;
79     return *this;
80   }
81 
82   inline size_t m() const {
83     return this->m_;
84   }
85 
86   inline GemmMicrokernelTester& n(size_t n) {
87     this->n_ = n;
88     return *this;
89   }
90 
91   inline size_t n() const {
92     return this->n_;
93   }
94 
95   inline GemmMicrokernelTester& k(size_t k) {
96     this->k_ = k;
97     return *this;
98   }
99 
100   inline size_t k() const {
101     return this->k_;
102   }
103 
104   inline GemmMicrokernelTester& ks(size_t ks) {
105     this->ks_ = ks;
106     return *this;
107   }
108 
109   inline size_t ks() const {
110     return this->ks_;
111   }
112 
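  // packed_k() and packed_n() round k() and n() up to the next multiple of
  // kr() and nr(), respectively: the packing routines below pad the weight
  // matrix so every tile is fully populated. bias_n() is the padded bias count
  // reserved alongside the packed weights.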
113   inline size_t packed_k() const {
114     return k() % kr() == 0 ? k() : (k() / kr() + 1) * kr();
115   }
116 
117   inline size_t packed_n() const {
118     return n() % nr() == 0 ? n() : (n() / nr() + 1) * nr();
119   }
120 
121   inline size_t bias_n() const {
122     return n() % nr() == 0 ? n() : (n() / nr() + 1) * nr();
123   }
124 
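  // Strides default to a tight layout: a_stride() falls back to k(),
  // cn_stride() to nr(), and cm_stride() to the length of one output row laid
  // out as nr()-wide tiles spaced cn_stride() apart. Tests override them only
  // when probing non-contiguous layouts.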
125   inline GemmMicrokernelTester& a_stride(size_t a_stride) {
126     this->a_stride_ = a_stride;
127     return *this;
128   }
129 
130   inline size_t a_stride() const {
131     return this->a_stride_ == 0 ? k() : this->a_stride_;
132   }
133 
134   inline GemmMicrokernelTester& cm_stride(size_t cm_stride) {
135     this->cm_stride_ = cm_stride;
136     return *this;
137   }
138 
139   inline size_t cm_stride() const {
140     return this->cm_stride_ == 0 ? cn_stride() * ((n() - 1) / nr()) + (n() - 1) % nr() + 1 : this->cm_stride_;
141   }
142 
143   inline GemmMicrokernelTester& cn_stride(size_t cn_stride) {
144     this->cn_stride_ = cn_stride;
145     return *this;
146   }
147 
148   inline size_t cn_stride() const {
149     return this->cn_stride_ == 0 ? nr() : this->cn_stride_;
150   }
151 
152   inline GemmMicrokernelTester& a_zero_point(uint8_t a_zero_point) {
153     this->a_zero_point_ = a_zero_point;
154     return *this;
155   }
156 
157   inline uint8_t a_zero_point() const {
158     return this->a_zero_point_;
159   }
160 
161   inline GemmMicrokernelTester& b_zero_point(uint8_t b_zero_point) {
162     this->b_zero_point_ = b_zero_point;
163     return *this;
164   }
165 
166   inline uint8_t b_zero_point() const {
167     return this->b_zero_point_;
168   }
169 
170   inline GemmMicrokernelTester& qmin(uint8_t qmin) {
171     this->qmin_ = qmin;
172     return *this;
173   }
174 
175   inline uint8_t qmin() const {
176     return this->qmin_;
177   }
178 
179   inline GemmMicrokernelTester& qmax(uint8_t qmax) {
180     this->qmax_ = qmax;
181     return *this;
182   }
183 
184   inline uint8_t qmax() const {
185     return this->qmax_;
186   }
187 
188   inline GemmMicrokernelTester& a_offset(size_t a_offset) {
189     this->a_offset_ = a_offset;
190     return *this;
191   }
192 
193   inline size_t a_offset() const {
194     return this->a_offset_;
195   }
196 
197   inline GemmMicrokernelTester& zero_index(size_t zero_index) {
198     this->zero_index_ = zero_index;
199     return *this;
200   }
201 
202   inline size_t zero_index() const {
203     return this->zero_index_;
204   }
205 
206   inline GemmMicrokernelTester& iterations(size_t iterations) {
207     this->iterations_ = iterations;
208     return *this;
209   }
210 
211   inline size_t iterations() const {
212     return this->iterations_;
213   }
214 
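  // QU8 GEMM test: packs random weights and biases with xnn_pack_qu8_gemm_goi_w,
  // computes int32 reference accumulators, derives a requantization scale and
  // output zero point that map the accumulated range onto [qmin, qmax], runs the
  // micro-kernel, and requires bit-exact agreement with the scalar requantized
  // reference.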
215   void Test(xnn_qu8_gemm_ukernel_function gemm, Variant variant = Variant::Native) const {
216     ASSERT_LE(m(), mr());
217 
218     std::random_device random_device;
219     auto rng = std::mt19937(random_device());
220     auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), rng);
221     auto u8rng = std::bind(std::uniform_int_distribution<uint32_t>(0, std::numeric_limits<uint8_t>::max()), rng);
222 
223     std::vector<uint8_t> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint8_t));
224     std::vector<uint8_t> b(n() * k());
225     std::vector<int32_t> bias(n());
226     std::vector<uint8_t, AlignedAllocator<uint8_t, 64>> packed_w(packed_n() * packed_k() + bias_n() * sizeof(int32_t) / sizeof(uint8_t));
227     std::vector<uint8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
228     std::vector<int32_t> acc(m() * n());
229     std::vector<uint8_t> c_ref(m() * n());
230 
231     for (size_t iteration = 0; iteration < iterations(); iteration++) {
232       do {
233         std::generate(a.begin(), a.end(), std::ref(u8rng));
234       } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
235       do {
236         std::generate(b.begin(), b.end(), std::ref(u8rng));
237       } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
238       std::generate(bias.begin(), bias.end(), std::ref(i32rng));
239       std::fill(c.begin(), c.end(), 0xA5);
240 
241       std::fill(packed_w.begin(), packed_w.end(), b_zero_point());
242       const xnn_qu8_packing_params packing_params = { a_zero_point(), b_zero_point() };
243       xnn_pack_qu8_gemm_goi_w(1, n(), k(), nr(), kr(), sr(),
244         b.data(), bias.data(), packed_w.data(), &packing_params);
245 
246       // Compute 32-bit results and output quantization arguments.
247       std::fill(acc.begin(), acc.end(), 0);
248       for (size_t m_index = 0; m_index < m(); m_index++) {
249         for (size_t n_index = 0; n_index < n(); n_index++) {
250           for (size_t k_index = 0; k_index < k(); k_index++) {
251             acc[m_index * n() + n_index] +=
252                 (int32_t(a[m_index * a_stride() + k_index]) - int32_t(a_zero_point())) *
253                 (int32_t(b[n_index * k() + k_index]) - int32_t(b_zero_point()));
254           }
255           acc[m_index * n() + n_index] += bias[n_index];
256         }
257       }
258 
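      // Derive output quantization parameters from the reference accumulators:
      // c_scale spreads the observed accumulator range over ~255 output steps
      // (or is ~1 when the range is already narrow), and c_zero_point centers
      // the requantized results within [0, 255]. The micro-kernel consumes the
      // reciprocal, requantization_scale.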
259       const int32_t accumulated_min = *std::min_element(acc.cbegin(), acc.cend());
260       const int32_t accumulated_max = *std::max_element(acc.cbegin(), acc.cend());
261       const double c_scale = uint32_t(accumulated_max - accumulated_min) >= 256 ? double(uint32_t(accumulated_max - accumulated_min)) / 255.0 : 1.00001;
262       const uint8_t c_zero_point = uint8_t(std::max(std::min(
263         lrint(127.5 - 0.5 * double(accumulated_min + accumulated_max) / c_scale),
264         long(std::numeric_limits<uint8_t>::max())), long(std::numeric_limits<uint8_t>::min())));
265 
266       const float requantization_scale = 1.0f / float(c_scale);
267       union xnn_qu8_gemm_params quantization_params = { };
268       switch (variant) {
269         case Variant::Native:
270           quantization_params = xnn_init_qu8_gemm_params(
271             b_zero_point(), requantization_scale, c_zero_point, qmin(), qmax());
272           break;
273         case Variant::Scalar:
274           quantization_params = xnn_init_scalar_qu8_gemm_params(
275             b_zero_point(), requantization_scale, c_zero_point, qmin(), qmax());
276           break;
277       }
278       const union xnn_qu8_requantization_params scalar_requantization_params =
279         xnn_init_scalar_qu8_requantization_params(requantization_scale, c_zero_point, qmin(), qmax());
280 
281       gemm(
282         m(), n(), k(),
283         a.data(), a_stride() * sizeof(uint8_t),
284         packed_w.data(),
285         c.data(), cm_stride() * sizeof(uint8_t), cn_stride() * sizeof(uint8_t),
286         &quantization_params);
287 
288       for (size_t m_index = 0; m_index < m(); m_index++) {
289         for (size_t n_index = 0; n_index < n(); n_index++) {
290           c_ref[m_index * n() + n_index] = xnn_qu8_requantize_q31(acc[m_index * n() + n_index], scalar_requantization_params);
291         }
292       }
293 
294       for (size_t i = 0; i < m(); i++) {
295         for (size_t j = 0; j < n(); j++) {
296           ASSERT_LE(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(qmax()));
297           ASSERT_GE(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(qmin()));
298           ASSERT_EQ(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(c_ref[i * n() + j]))
299               << "at " << i << ", " << j << ": reference = " << (uint32_t) c_ref[i * n() + j]
300               << " (accumulator = " << acc[i * n() + j]
301               << "), optimized = " << (uint32_t) c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x "
302               << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
303               << ", requantization scale = " << requantization_scale << ", output zero point = " << int32_t(c_zero_point);
304         }
305       }
306     }
307   }
308 
309   void Test(xnn_qu8_igemm_ukernel_function igemm, Variant variant = Variant::Native) const {
310     ASSERT_LE(m(), mr());
311 
312     std::random_device random_device;
313     auto rng = std::mt19937(random_device());
314     auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), rng);
315     auto u8rng = std::bind(std::uniform_int_distribution<uint32_t>(0, std::numeric_limits<uint8_t>::max()), rng);
316 
317     std::vector<uint8_t> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint8_t));
318     std::vector<uint8_t> b(n() * ks() * k());
319     std::vector<uint8_t, AlignedAllocator<uint8_t, 64>> packed_w(ks() * packed_n() * packed_k() + bias_n() * sizeof(int32_t) / sizeof(uint8_t));
320     std::vector<int32_t> bias(n());
321     std::vector<uint8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
322     std::vector<int32_t> acc(m() * n());
323     std::vector<uint8_t> c_ref(m() * n());
324     std::vector<uint8_t> junk(k() + 8);
325     std::vector<const uint8_t*> im2col(mr() * ks());
326 
327     std::fill(junk.begin(), junk.end(), 0xA5);
328 
329     for (size_t iteration = 0; iteration < iterations(); iteration++) {
330       do {
331         std::generate(a.begin(), a.end(), std::ref(u8rng));
332       } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
333       do {
334         std::generate(b.begin(), b.end(), std::ref(u8rng));
335       } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
336       std::generate(bias.begin(), bias.end(), std::ref(i32rng));
337       std::fill(c.begin(), c.end(), 0xA5);
338 
339       std::fill(packed_w.begin(), packed_w.end(), b_zero_point());
340       const xnn_qu8_packing_params packing_params = { a_zero_point(), b_zero_point() };
341       xnn_pack_qu8_conv_goki_w(
342         1, n(), ks(), k(), nr(), kr(), sr(),
343         b.data(), bias.data(), packed_w.data(), &packing_params);
344 
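      // Build the indirection buffer: each of the ks() segments holds mr()
      // pointers into the input, pre-biased by -a_offset() so the micro-kernel
      // can add a_offset back. The list is shuffled, entries selected by
      // zero_index() are redirected to a.data() (also passed as zero_pointer),
      // and rows beyond m() point at junk to verify they are never read.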
345       for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
346         for (size_t m_index = 0; m_index < mr(); m_index++) {
347           im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
348         }
349 
350       }
351       std::shuffle(im2col.begin(), im2col.end(), rng);
352       if (zero_index() != SIZE_MAX) {
353         for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
354           im2col[ks_index * mr() + zero_index()] = a.data();
355         }
356       }
357       for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
358         for (size_t m_index = m(); m_index < mr(); m_index++) {
359           im2col[ks_index * mr() + m_index] = junk.data();
360         }
361       }
362 
363       // Compute 32-bit results and output quantization arguments.
364       std::fill(acc.begin(), acc.end(), 0);
365       for (size_t m_index = 0; m_index < m(); m_index++) {
366         for (size_t n_index = 0; n_index < n(); n_index++) {
367           for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
368             for (size_t k_index = 0; k_index < k(); k_index++) {
369               if (im2col[ks_index * mr() + m_index] == a.data()) {
370                 acc[m_index * n() + n_index] +=
371                   (int32_t(im2col[ks_index * mr() + m_index][k_index]) - int32_t(a_zero_point())) *
372                   (int32_t(b[(n_index * ks() + ks_index) * k() + k_index]) - int32_t(b_zero_point()));
373               } else {
374                 acc[m_index * n() + n_index] +=
375                   (int32_t(im2col[ks_index * mr() + m_index][k_index + a_offset()]) - int32_t(a_zero_point())) *
376                   (int32_t(b[(n_index * ks() + ks_index) * k() + k_index]) - int32_t(b_zero_point()));
377               }
378             }
379           }
380           acc[m_index * n() + n_index] += bias[n_index];
381         }
382       }
383 
384       const int32_t accumulated_min = *std::min_element(acc.cbegin(), acc.cend());
385       const int32_t accumulated_max = *std::max_element(acc.cbegin(), acc.cend());
386       const double c_scale = uint32_t(accumulated_max - accumulated_min) >= 256 ? double(uint32_t(accumulated_max - accumulated_min)) / 255.0 : 1.00001;
387       const uint8_t c_zero_point = uint8_t(std::max(std::min(
388         lrint(127.5 - 0.5 * double(accumulated_min + accumulated_max) / c_scale),
389         long(std::numeric_limits<uint8_t>::max())), long(std::numeric_limits<uint8_t>::min())));
390 
391       const float requantization_scale = 1.0f / float(c_scale);
392       union xnn_qu8_gemm_params quantization_params = { };
393       switch (variant) {
394         case Variant::Native:
395           quantization_params = xnn_init_qu8_gemm_params(
396             b_zero_point(), requantization_scale, c_zero_point, qmin(), qmax());
397           break;
398         case Variant::Scalar:
399           quantization_params = xnn_init_scalar_qu8_gemm_params(
400             b_zero_point(), requantization_scale, c_zero_point, qmin(), qmax());
401           break;
402       }
403       const union xnn_qu8_requantization_params scalar_requantization_params =
404         xnn_init_scalar_qu8_requantization_params(requantization_scale, c_zero_point, qmin(), qmax());
405 
406       const uint8_t* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : NULL;
407 
408       igemm(
409         m(), n(), k(), ks() * mr() * sizeof(void*),
410         im2col.data(), packed_w.data(),
411         c.data(), cm_stride() * sizeof(uint8_t), cn_stride() * sizeof(uint8_t),
412         a_offset() * sizeof(uint8_t), zero_pointer,
413         &quantization_params);
414 
415       for (size_t m_index = 0; m_index < m(); m_index++) {
416         for (size_t n_index = 0; n_index < n(); n_index++) {
417           c_ref[m_index * n() + n_index] = xnn_qu8_requantize_q31(acc[m_index * n() + n_index], scalar_requantization_params);
418         }
419       }
420 
421       for (size_t i = 0; i < m(); i++) {
422         for (size_t j = 0; j < n(); j++) {
423           ASSERT_LE(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(qmax()));
424           ASSERT_GE(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(qmin()));
425           ASSERT_EQ(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(c_ref[i * n() + j]))
426               << "at " << i << ", " << j << ": reference = " << uint32_t(c_ref[i * n() + j])
427               << " (accumulator = " << acc[i * n() + j]
428               << "), optimized = " << (uint32_t) c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x "
429               << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
430               << ", requantization scale = " << requantization_scale << ", output zero point = " << int32_t(c_zero_point);
431         }
432       }
433     }
434   }
435 
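  // QS8 tests mirror the QU8 ones with signed data: the unsigned qmin()/qmax()/
  // a_zero_point() knobs are shifted down by 0x80 into int8 range, and the
  // reference subtracts only the activation zero point, since signed weights
  // carry no zero point of their own.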
436   void Test(xnn_qs8_gemm_ukernel_function gemm, Variant variant = Variant::Native) const {
437     ASSERT_LE(m(), mr());
438 
439     std::random_device random_device;
440     auto rng = std::mt19937(random_device());
441     auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), rng);
442     auto i8rng = std::bind(
443       std::uniform_int_distribution<int32_t>(-127, std::numeric_limits<int8_t>::max()), rng);
444 
445     std::vector<int8_t> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(int8_t));
446     std::vector<int8_t> b(n() * k());
447     std::vector<int32_t> bias(n());
448     std::vector<int8_t, AlignedAllocator<int8_t, 64>> packed_w(packed_n() * packed_k() + bias_n() * sizeof(int32_t) / sizeof(int8_t));
449     std::vector<int8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
450     std::vector<int32_t> acc(m() * n());
451     std::vector<int8_t> c_ref(m() * n());
452 
453     for (size_t iteration = 0; iteration < iterations(); iteration++) {
454       do {
455         std::generate(a.begin(), a.end(), std::ref(i8rng));
456       } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
457       do {
458         std::generate(b.begin(), b.end(), std::ref(i8rng));
459       } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
460       std::generate(bias.begin(), bias.end(), std::ref(i32rng));
461       std::fill(c.begin(), c.end(), 0xA5);
462 
463       std::fill(packed_w.begin(), packed_w.end(), 0);
464       const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) };
465       xnn_pack_qs8_gemm_goi_w(1, n(), k(), nr(), kr(), sr(),
466         b.data(), bias.data(), packed_w.data(), &packing_params);
467 
468       // Compute 32-bit results and output quantization arguments.
469       std::fill(acc.begin(), acc.end(), 0);
470       for (size_t m_index = 0; m_index < m(); m_index++) {
471         for (size_t n_index = 0; n_index < n(); n_index++) {
472           for (size_t k_index = 0; k_index < k(); k_index++) {
473             acc[m_index * n() + n_index] +=
474                 (int32_t(a[m_index * a_stride() + k_index]) - int32_t(a_zero_point() - 0x80)) *
475                 int32_t(b[n_index * k() + k_index]);
476           }
477           acc[m_index * n() + n_index] += bias[n_index];
478         }
479       }
480 
481       const int32_t accumulated_min = *std::min_element(acc.cbegin(), acc.cend());
482       const int32_t accumulated_max = *std::max_element(acc.cbegin(), acc.cend());
483       const double c_scale = uint32_t(accumulated_max - accumulated_min) >= 256 ? double(uint32_t(accumulated_max - accumulated_min)) / 255.0 : 1.00001;
484       const int8_t c_zero_point = int8_t(std::max(std::min(
485         lrint(-0.5 - 0.5 * double(accumulated_min + accumulated_max) / c_scale),
486         long(std::numeric_limits<int8_t>::max())), long(std::numeric_limits<int8_t>::min())));
487 
488       const float requantization_scale = 1.0f / float(c_scale);
489       union xnn_qs8_gemm_params quantization_params = { };
490       switch (variant) {
491         case Variant::Native:
492           quantization_params = xnn_init_qs8_gemm_params(requantization_scale, c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
493           break;
494         case Variant::Scalar:
495           quantization_params = xnn_init_scalar_qs8_gemm_params(requantization_scale, c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
496           break;
497       }
498       const union xnn_qs8_requantization_params scalar_requantization_params =
499         xnn_init_scalar_qs8_requantization_params(requantization_scale, c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
500 
501       gemm(
502         m(), n(), k(),
503         a.data(), a_stride() * sizeof(int8_t),
504         packed_w.data(),
505         c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t),
506         &quantization_params);
507 
508       for (size_t m_index = 0; m_index < m(); m_index++) {
509         for (size_t n_index = 0; n_index < n(); n_index++) {
510           c_ref[m_index * n() + n_index] = xnn_qs8_requantize_q31(acc[m_index * n() + n_index], scalar_requantization_params);
511         }
512       }
513 
514       for (size_t i = 0; i < m(); i++) {
515         for (size_t j = 0; j < n(); j++) {
516           ASSERT_LE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmax()) - 0x80);
517           ASSERT_GE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmin()) - 0x80);
518           ASSERT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(c_ref[i * n() + j]))
519               << "at " << i << ", " << j << ": reference = " << int32_t(c_ref[i * n() + j])
520               << " (accumulator = " << acc[i * n() + j]
521               << "), optimized = " << int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x "
522               << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
523               << ", requantization scale = " << requantization_scale << ", output zero point = " << int32_t(c_zero_point);
524         }
525       }
526     }
527   }
528 
529   void Test(xnn_qs8_gemm_xw_ukernel_function gemm, Variant variant = Variant::Native) const {
530     ASSERT_LE(m(), mr());
531 
532     std::random_device random_device;
533     auto rng = std::mt19937(random_device());
534     auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), rng);
535     auto i8rng = std::bind(
536       std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()), rng);
537 
538     std::vector<int8_t> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(int8_t));
539     std::vector<int8_t> b(n() * k());
540     std::vector<int32_t> bias(n());
541     std::vector<int16_t, AlignedAllocator<int16_t, 64>> packed_w(packed_n() * packed_k() + bias_n() * sizeof(int32_t) / sizeof(int16_t));
542     std::vector<int8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
543     std::vector<int32_t> acc(m() * n());
544     std::vector<int8_t> c_ref(m() * n());
545 
546     for (size_t iteration = 0; iteration < iterations(); iteration++) {
547       do {
548         std::generate(a.begin(), a.end(), std::ref(i8rng));
549       } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
550       do {
551         std::generate(b.begin(), b.end(), std::ref(i8rng));
552       } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
553       std::generate(bias.begin(), bias.end(), std::ref(i32rng));
554       std::fill(c.begin(), c.end(), 0xA5);
555 
556       std::fill(packed_w.begin(), packed_w.end(), 0);
557       const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) };
558       xnn_pack_qs8_gemm_xw_goi_w(1, n(), k(), nr(), kr(), sr(),
559         b.data(), bias.data(), packed_w.data(), &packing_params);
560 
561       // Compute 32-bit results and output quantization arguments.
562       std::fill(acc.begin(), acc.end(), 0);
563       for (size_t m_index = 0; m_index < m(); m_index++) {
564         for (size_t n_index = 0; n_index < n(); n_index++) {
565           for (size_t k_index = 0; k_index < k(); k_index++) {
566             acc[m_index * n() + n_index] +=
567                 (int32_t(a[m_index * a_stride() + k_index]) - int32_t(a_zero_point() - 0x80)) *
568                 int32_t(b[n_index * k() + k_index]);
569           }
570           acc[m_index * n() + n_index] += bias[n_index];
571         }
572       }
573 
574       const int32_t accumulated_min = *std::min_element(acc.cbegin(), acc.cend());
575       const int32_t accumulated_max = *std::max_element(acc.cbegin(), acc.cend());
576       const double c_scale = uint32_t(accumulated_max - accumulated_min) >= 256 ? double(uint32_t(accumulated_max - accumulated_min)) / 255.0 : 1.00001;
577       const int8_t c_zero_point = int8_t(std::max(std::min(
578         lrint(-0.5 - 0.5 * double(accumulated_min + accumulated_max) / c_scale),
579         long(std::numeric_limits<int8_t>::max())), long(std::numeric_limits<int8_t>::min())));
580 
581       const float requantization_scale = 1.0f / float(c_scale);
582       union xnn_qs8_gemm_xw_params quantization_params = { };
583       switch (variant) {
584         case Variant::Native:
585           quantization_params = xnn_init_qs8_gemm_xw_params(requantization_scale, c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
586           break;
587         case Variant::Scalar:
588           quantization_params = xnn_init_scalar_qs8_gemm_xw_params(requantization_scale, c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
589           break;
590       }
591       const union xnn_qs8_requantization_params scalar_requantization_params =
592         xnn_init_scalar_qs8_requantization_params(requantization_scale, c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
593 
594       gemm(
595         m(), n(), k(),
596         a.data(), a_stride() * sizeof(int8_t),
597         packed_w.data(),
598         c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t),
599         &quantization_params);
600 
601       for (size_t m_index = 0; m_index < m(); m_index++) {
602         for (size_t n_index = 0; n_index < n(); n_index++) {
603           c_ref[m_index * n() + n_index] = xnn_qs8_requantize_q31(acc[m_index * n() + n_index], scalar_requantization_params);
604         }
605       }
606 
607       for (size_t i = 0; i < m(); i++) {
608         for (size_t j = 0; j < n(); j++) {
609           ASSERT_LE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmax()) - 0x80);
610           ASSERT_GE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmin()) - 0x80);
611           ASSERT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(c_ref[i * n() + j]))
612               << "at " << i << ", " << j << ": reference = " << int32_t(c_ref[i * n() + j])
613               << " (accumulator = " << acc[i * n() + j]
614               << "), optimized = " << int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x "
615               << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
616               << ", requantization scale = " << requantization_scale << ", output zero point = " << int32_t(c_zero_point);
617         }
618       }
619     }
620   }
621 
622   void Test(xnn_qs8_igemm_ukernel_function igemm, Variant variant = Variant::Native) const {
623     ASSERT_LE(m(), mr());
624 
625     std::random_device random_device;
626     auto rng = std::mt19937(random_device());
627     auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), rng);
628     auto i8rng = std::bind(
629       std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()), rng);
630 
631     std::vector<int8_t> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint8_t));
632     std::vector<int8_t> b(n() * ks() * k());
633     std::vector<int8_t, AlignedAllocator<int8_t, 64>> packed_w(ks() * packed_n() * packed_k() + bias_n() * sizeof(int32_t) / sizeof(int8_t));
634     std::vector<int32_t> bias(n());
635     std::vector<int8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
636     std::vector<int32_t> acc(m() * n());
637     std::vector<int8_t> c_ref(m() * n());
638     std::vector<int8_t> junk(k() + 8);
639     std::vector<const int8_t*> im2col(mr() * ks());
640 
641     std::fill(junk.begin(), junk.end(), 0xA5);
642 
643     for (size_t iteration = 0; iteration < iterations(); iteration++) {
644       do {
645         std::generate(a.begin(), a.end(), std::ref(i8rng));
646       } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
647       do {
648         std::generate(b.begin(), b.end(), std::ref(i8rng));
649       } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
650       std::generate(bias.begin(), bias.end(), std::ref(i32rng));
651       std::fill(c.begin(), c.end(), 0xA5);
652 
653       std::fill(packed_w.begin(), packed_w.end(), 0);
654       const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) };
655       xnn_pack_qs8_conv_goki_w(
656         1, n(), ks(), k(), nr(), kr(), sr(),
657         b.data(), bias.data(), packed_w.data(), &packing_params);
658 
659       for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
660         for (size_t m_index = 0; m_index < mr(); m_index++) {
661           im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
662         }
663 
664       }
665       std::shuffle(im2col.begin(), im2col.end(), rng);
666       if (zero_index() != SIZE_MAX) {
667         for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
668           im2col[ks_index * mr() + zero_index()] = a.data();
669         }
670       }
671       for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
672         for (size_t m_index = m(); m_index < mr(); m_index++) {
673           im2col[ks_index * mr() + m_index] = junk.data();
674         }
675       }
676 
677       // Compute 32-bit results and output quantization arguments.
678       std::fill(acc.begin(), acc.end(), 0);
679       for (size_t m_index = 0; m_index < m(); m_index++) {
680         for (size_t n_index = 0; n_index < n(); n_index++) {
681           for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
682             for (size_t k_index = 0; k_index < k(); k_index++) {
683               if (im2col[ks_index * mr() + m_index] == a.data()) {
684                 acc[m_index * n() + n_index] +=
685                   (int32_t(im2col[ks_index * mr() + m_index][k_index]) - int32_t(a_zero_point() - 0x80)) *
686                   int32_t(b[(n_index * ks() + ks_index) * k() + k_index]);
687               } else {
688                 acc[m_index * n() + n_index] +=
689                   (int32_t(im2col[ks_index * mr() + m_index][k_index + a_offset()]) - int32_t(a_zero_point() - 0x80)) *
690                   int32_t(b[(n_index * ks() + ks_index) * k() + k_index]);
691               }
692             }
693           }
694           acc[m_index * n() + n_index] += bias[n_index];
695         }
696       }
697 
698       const int32_t accumulated_min = *std::min_element(acc.cbegin(), acc.cend());
699       const int32_t accumulated_max = *std::max_element(acc.cbegin(), acc.cend());
700       const double c_scale = uint32_t(accumulated_max - accumulated_min) >= 256 ? double(uint32_t(accumulated_max - accumulated_min)) / 255.0 : 1.00001;
701       const int8_t c_zero_point = int8_t(std::max(std::min(
702         lrint(-0.5 - 0.5 * double(accumulated_min + accumulated_max) / c_scale),
703         long(std::numeric_limits<int8_t>::max())), long(std::numeric_limits<int8_t>::min())));
704 
705       const float requantization_scale = 1.0f / float(c_scale);
706       union xnn_qs8_gemm_params quantization_params = { };
707       switch (variant) {
708         case Variant::Native:
709           quantization_params = xnn_init_qs8_gemm_params(requantization_scale, c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
710           break;
711         case Variant::Scalar:
712           quantization_params = xnn_init_scalar_qs8_gemm_params(requantization_scale, c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
713           break;
714       }
715       const union xnn_qs8_requantization_params scalar_requantization_params =
716         xnn_init_scalar_qs8_requantization_params(requantization_scale, c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
717 
718       const int8_t* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : NULL;
719 
720       igemm(
721         m(), n(), k(), ks() * mr() * sizeof(void*),
722         im2col.data(), packed_w.data(),
723         c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t),
724         a_offset() * sizeof(uint8_t), zero_pointer,
725         &quantization_params);
726 
727       for (size_t m_index = 0; m_index < m(); m_index++) {
728         for (size_t n_index = 0; n_index < n(); n_index++) {
729           c_ref[m_index * n() + n_index] = xnn_qs8_requantize_q31(acc[m_index * n() + n_index], scalar_requantization_params);
730         }
731       }
732 
733       for (size_t i = 0; i < m(); i++) {
734         for (size_t j = 0; j < n(); j++) {
735           ASSERT_LE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmax()) - 0x80);
736           ASSERT_GE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmin()) - 0x80);
737           ASSERT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(c_ref[i * n() + j]))
738               << "at " << i << ", " << j << ": reference = " << int32_t(c_ref[i * n() + j])
739               << " (accumulator = " << acc[i * n() + j]
740               << "), optimized = " << int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x "
741               << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
742               << ", requantization scale = " << requantization_scale << ", output zero point = " << int32_t(c_zero_point);
743         }
744       }
745     }
746   }
747 
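  // F16 tests generate half-precision inputs with fp16_ieee_from_fp32_value,
  // accumulate the reference in fp32, and clamp it to a [c_min, c_max] window
  // derived from qmin()/qmax(). Outputs are primed with NaN to catch elements
  // the kernel never writes, and comparisons allow max(1e-4, 1%) of slack for
  // half-precision rounding.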
748   void Test(xnn_f16_gemm_minmax_ukernel_function gemm_minmax, Variant variant = Variant::Native) const
749   {
750     ASSERT_LE(m(), mr());
751     ASSERT_GE(a_stride(), k());
752     ASSERT_GE(cm_stride(), n());
753 
754     std::random_device random_device;
755     auto rng = std::mt19937(random_device());
756     auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);
757     auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng);
758 
759     std::vector<uint16_t> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint16_t));
760     std::vector<uint16_t> b(n() * k());
761     std::vector<uint16_t, AlignedAllocator<uint16_t, 64>> packed_w(packed_n() * packed_k() + bias_n());
762     std::vector<uint16_t> bias(n());
763     std::vector<uint16_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
764     std::vector<float> c_ref(m() * n());
765 
766     for (size_t iteration = 0; iteration < iterations(); iteration++) {
767       std::generate(a.begin(), a.end(), std::ref(f16rng));
768       std::generate(b.begin(), b.end(), std::ref(f16rng));
769       std::generate(bias.begin(), bias.end(), std::ref(f16rng));
770       std::fill(c.begin(), c.end(), UINT16_C(0x7E00) /* NaN */);
771       std::fill(c_ref.begin(), c_ref.end(), 0.0f);
772 
773       std::fill(packed_w.begin(), packed_w.end(), 0);
774       xnn_pack_f16_gemm_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), packed_w.data(), nullptr);
775 
776       for (size_t m_index = 0; m_index < m(); m_index++) {
777         for (size_t n_index = 0; n_index < n(); n_index++) {
778           for (size_t k_index = 0; k_index < k(); k_index++) {
779             ASSERT_LE(n(), packed_n());
780             ASSERT_LT(m_index * n() + n_index, c_ref.size());
781             ASSERT_LT(m_index * k() + k_index, a.size());
782             c_ref[m_index * n() + n_index] +=
783               fp16_ieee_to_fp32_value(a[m_index * a_stride() + k_index]) *
784               fp16_ieee_to_fp32_value(b[n_index * k() + k_index]);
785           }
786           c_ref[m_index * n() + n_index] += fp16_ieee_to_fp32_value(bias[n_index]);
787         }
788       }
789 
790       const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
791       const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
792       const float c_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin())));
793       const float c_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax())));
794 
795       // Prepare parameters.
796       xnn_f16_scaleminmax_params params = xnn_init_f16_scaleminmax_params(
797         UINT16_C(0x3C00) /* 1.0 */,
798         fp16_ieee_from_fp32_value(c_min),
799         fp16_ieee_from_fp32_value(c_max));
800 
801       for (float& c_value : c_ref) {
802         c_value = std::max(std::min(c_value, c_max), c_min);
803       }
804 
805       gemm_minmax(m(), n(), k() * sizeof(uint16_t),
806         a.data(), a_stride() * sizeof(uint16_t),
807         packed_w.data(),
808         c.data(), cm_stride() * sizeof(uint16_t), cn_stride() * sizeof(uint16_t),
809         &params);
810 
811       // Validate micro-kernel outputs.
812       for (size_t i = 0; i < m(); i++) {
813         for (size_t j = 0; j < n(); j++) {
814           ASSERT_NEAR(fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), c_ref[i * n() + j], std::max(1.0e-4f, std::abs(c_ref[i * n() + j]) * 1.0e-2f))
815               << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
816               << ", optimized = " << fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x " << nr()
817               << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
818         }
819       }
820     }
821   }
822 
823   void Test(xnn_f16_igemm_minmax_ukernel_function igemm_minmax) const {
824     ASSERT_LE(m(), mr());
825 
826     std::random_device random_device;
827     auto rng = std::mt19937(random_device());
828     auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);
829     auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng);
830 
831     std::vector<uint16_t> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint16_t));
832     std::vector<uint16_t> b(n() * ks() * k());
833     std::vector<uint16_t, AlignedAllocator<uint16_t, 64>> packed_w(ks() * packed_k() * packed_n() + bias_n());
834     std::vector<uint16_t> bias(n());
835     std::vector<uint16_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
836     std::vector<float> c_ref(m() * n());
837     std::vector<uint16_t> junk(k() + XNN_EXTRA_BYTES / sizeof(uint16_t));
838     std::vector<const uint16_t*> im2col(mr() * ks());
839     std::fill(junk.begin(), junk.end(), UINT16_C(0x7E00) /* NaN */);
840 
841     for (size_t iteration = 0; iteration < iterations(); iteration++) {
842       std::generate(a.begin(), a.end(), std::ref(f16rng));
843       std::generate(b.begin(), b.end(), std::ref(f16rng));
844       std::generate(bias.begin(), bias.end(), std::ref(f16rng));
845       std::fill(c.begin(), c.end(), UINT16_C(0x7E00) /* NaN */);
846       std::fill(c_ref.begin(), c_ref.end(), 0);
847 
848       std::fill(packed_w.begin(), packed_w.end(), 0);
849       xnn_pack_f16_conv_goki_w(
850         1, n(), ks(), k(), nr(), kr(), sr(),
851         b.data(), bias.data(), packed_w.data(), nullptr);
852 
853       for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
854         for (size_t m_index = 0; m_index < mr(); m_index++) {
855           im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
856         }
857       }
858       std::shuffle(im2col.begin(), im2col.end(), rng);
859       if (zero_index() != SIZE_MAX) {
860         for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
861           im2col[ks_index * mr() + zero_index()] = a.data();
862         }
863       }
864       for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
865         for (size_t m_index = m(); m_index < mr(); m_index++) {
866           im2col[ks_index * mr() + m_index] = junk.data();
867         }
868       }
869 
870       std::fill(c_ref.begin(), c_ref.end(), 0.0);
871       for (size_t m_index = 0; m_index < m(); m_index++) {
872         for (size_t n_index = 0; n_index < n(); n_index++) {
873           for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
874             for (size_t k_index = 0; k_index < k(); k_index++) {
875               ASSERT_LT(ks_index * mr() + m_index, im2col.size());
876               ASSERT_LT(k_index, k());
877               ASSERT_LT(k_index, a_stride());
878               if (im2col[ks_index * mr() + m_index] == a.data()) {
879                 c_ref[m_index * n() + n_index] +=
880                   fp16_ieee_to_fp32_value(im2col[ks_index * mr() + m_index][k_index]) *
881                   fp16_ieee_to_fp32_value(b[(n_index * ks() + ks_index) * k() + k_index]);
882               } else {
883                 c_ref[m_index * n() + n_index] +=
884                   fp16_ieee_to_fp32_value(im2col[ks_index * mr() + m_index][k_index + a_offset()]) *
885                   fp16_ieee_to_fp32_value(b[(n_index * ks() + ks_index) * k() + k_index]);
886               }
887             }
888           }
889           c_ref[m_index * n() + n_index] += fp16_ieee_to_fp32_value(bias[n_index]);
890         }
891       }
892 
893       const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
894       const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
895       const float c_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_min + (accumulated_max - accumulated_min) / 255.0f * uint16_t(qmin())));
896       const float c_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_max - (accumulated_max - accumulated_min) / 255.0f * uint16_t(255 - qmax())));
897       for (size_t m_index = 0; m_index < m(); m_index++) {
898         for (size_t n_index = 0; n_index < n(); n_index++) {
899           c_ref[m_index * n() + n_index] = std::min(c_ref[m_index * n() + n_index], c_max);
900           c_ref[m_index * n() + n_index] = std::max(c_ref[m_index * n() + n_index], c_min);
901         }
902       }
903 
904       // Prepare parameters.
905       xnn_f16_scaleminmax_params params = xnn_init_f16_scaleminmax_params(
906         UINT16_C(0x3C00) /* 1.0 */,
907         fp16_ieee_from_fp32_value(c_min),
908         fp16_ieee_from_fp32_value(c_max));
909 
910       for (float& c_value : c_ref) {
911         c_value = std::max(std::min(c_value, c_max), c_min);
912       }
913 
914       const uint16_t* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : NULL;
915 
916       igemm_minmax(
917         m(), n(), k() * sizeof(uint16_t), ks() * mr() * sizeof(void*),
918         reinterpret_cast<const void**>(im2col.data()), packed_w.data(),
919         c.data(), cm_stride() * sizeof(uint16_t), cn_stride() * sizeof(uint16_t),
920         a_offset() * sizeof(uint16_t), zero_pointer,
921         &params);
922 
923       for (size_t i = 0; i < m(); i++) {
924         for (size_t j = 0; j < n(); j++) {
925           ASSERT_LE(fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), c_max)
926               << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
927               << ", optimized = " << fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x " << nr()
928               << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
929           ASSERT_GE(fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), c_min)
930               << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
931               << ", optimized = " << fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x " << nr()
932               << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
933           ASSERT_NEAR(fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), c_ref[i * n() + j], std::max(1.0e-4f, std::abs(c_ref[i * n() + j]) * 1.0e-2f))
934               << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
935               << ", optimized = " << fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x " << nr()
936               << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
937         }
938       }
939     }
940   }
941 
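  // PPMM operates on a pre-packed activation matrix: `a` is laid out k-major
  // as packed_k() groups of mr() values, and rows beyond m() are filled with a
  // copy of the last valid row before the reference result is computed.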
942   void Test(xnn_f32_ppmm_minmax_ukernel_function ppmm, Variant variant = Variant::Native) const {
943     ASSERT_LE(m(), mr());
944     ASSERT_GE(cm_stride(), n());
945 
946     std::random_device random_device;
947     auto rng = std::mt19937(random_device());
948     auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);
949 
950     std::vector<float> a(packed_k() * mr());
951     std::vector<float> b(n() * k());
952     std::vector<float> bias(n());
953     std::vector<float, AlignedAllocator<float, 64>> packed_w(packed_n() * packed_k() + bias_n());
954     std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
955     std::vector<float> c_ref(m() * n());
956 
957     for (size_t iteration = 0; iteration < iterations(); iteration++) {
958       std::generate(a.begin(), a.end(), std::ref(f32rng));
959       std::generate(b.begin(), b.end(), std::ref(f32rng));
960       std::generate(bias.begin(), bias.end(), std::ref(f32rng));
961       std::fill(c.begin(), c.end(), nanf(""));
962       std::fill(c_ref.begin(), c_ref.end(), 0.0f);
963 
964       std::fill(packed_w.begin(), packed_w.end(), 0.0f);
965       xnn_pack_f32_gemm_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), packed_w.data(), nullptr);
966 
967       for (size_t i = m(); i < mr(); i++) {
968         for (size_t l = 0; l < k(); l++) {
969           a[l * mr() + i] = a[l * mr() + m() - 1];
970         }
971       }
972 
973       for (size_t i = 0; i < m(); i++) {
974         for (size_t j = 0; j < n(); j++) {
975           for (size_t l = 0; l < k(); l++) {
976             c_ref[i * n() + j] +=
977               a[l * mr() + i] *
978               b[j * k() + l];
979           }
980           c_ref[i * n() + j] += bias[j];
981         }
982       }
983 
984       const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
985       const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
986       const float c_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
987       const float c_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
988 
989       // Prepare parameters.
990       xnn_f32_minmax_params params = { };
991       switch (variant) {
992         case Variant::Native:
993           params = xnn_init_f32_minmax_params(c_min, c_max);
994           break;
995         case Variant::Scalar:
996           params = xnn_init_scalar_f32_minmax_params(c_min, c_max);
997           break;
998       }
999 
1000       for (float& c_value : c_ref) {
1001         c_value = std::max(std::min(c_value, c_max), c_min);
1002       }
1003 
1004       ppmm(m(), n(), k() * sizeof(float),
1005         a.data(), packed_w.data(),
1006         c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
1007         &params);
1008 
1009       // Validate micro-kernel outputs.
1010       for (size_t i = 0; i < m(); i++) {
1011         for (size_t j = 0; j < n(); j++) {
1012           ASSERT_NEAR(
1013               c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
1014               c_ref[i * n() + j],
1015               std::abs(c_ref[i * n() + j]) * 1.0e-6f)
1016               << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1017               << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1018               << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1019         }
1020       }
1021     }
1022   }
1023 
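  // The plain F32 GEMM and RELU tests below pass a null params pointer since no
  // min/max clamping is configured; the RELU reference simply clamps the biased
  // accumulator at zero.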
1024   void Test(xnn_f32_gemm_ukernel_function gemm) const {
1025     ASSERT_LE(m(), mr());
1026     ASSERT_GE(a_stride(), k());
1027     ASSERT_GE(cm_stride(), n());
1028 
1029     std::random_device random_device;
1030     auto rng = std::mt19937(random_device());
1031     auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);
1032 
1033     std::vector<float> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
1034     std::vector<float> b(n() * k());
1035     std::vector<float> bias(n());
1036     std::vector<float, AlignedAllocator<float, 64>> packed_w(packed_n() * packed_k() + bias_n());
1037     std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
1038     std::vector<float> c_ref(m() * n());
1039 
1040     for (size_t iteration = 0; iteration < iterations(); iteration++) {
1041       std::generate(a.begin(), a.end(), std::ref(f32rng));
1042       std::generate(b.begin(), b.end(), std::ref(f32rng));
1043       std::generate(bias.begin(), bias.end(), std::ref(f32rng));
1044       std::fill(c.begin(), c.end(), nanf(""));
1045       std::fill(c_ref.begin(), c_ref.end(), 0.0f);
1046 
1047       std::fill(packed_w.begin(), packed_w.end(), 0.0f);
1048       xnn_pack_f32_gemm_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), packed_w.data(), nullptr);
1049 
1050       for (size_t m_index = 0; m_index < m(); m_index++) {
1051         for (size_t n_index = 0; n_index < n(); n_index++) {
1052           for (size_t k_index = 0; k_index < k(); k_index++) {
1053             ASSERT_LE(n(), packed_n());
1054             ASSERT_LT(m_index * n() + n_index, c_ref.size());
1055             c_ref[m_index * n() + n_index] +=
1056               a[m_index * a_stride() + k_index] *
1057               b[n_index * k() + k_index];
1058           }
1059           c_ref[m_index * n() + n_index] += bias[n_index];
1060         }
1061       }
1062 
1063       gemm(m(), n(), k() * sizeof(float),
1064         a.data(), a_stride() * sizeof(float),
1065         packed_w.data(),
1066         c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
1067         nullptr);
1068 
1069       // Validate micro-kernel outputs.
1070       for (size_t i = 0; i < m(); i++) {
1071         for (size_t j = 0; j < n(); j++) {
1072           ASSERT_NEAR(
1073               c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
1074               c_ref[i * n() + j],
1075               std::abs(c_ref[i * n() + j]) * 1.0e-6f)
1076               << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1077               << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1078               << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1079         }
1080       }
1081     }
1082   }
1083 
1084   void Test(xnn_f32_gemm_relu_ukernel_function gemm_relu) const {
1085     ASSERT_LE(m(), mr());
1086     ASSERT_GE(a_stride(), k());
1087     ASSERT_GE(cm_stride(), n());
1088 
1089     std::random_device random_device;
1090     auto rng = std::mt19937(random_device());
1091     auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);
1092 
1093     std::vector<float> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
1094     std::vector<float> b(n() * k());
1095     std::vector<float> bias(n());
1096     std::vector<float, AlignedAllocator<float, 64>> packed_w(packed_n() * packed_k() + bias_n());
1097     std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
1098     std::vector<float> c_ref(m() * n());
1099 
1100     for (size_t iteration = 0; iteration < iterations(); iteration++) {
1101       std::generate(a.begin(), a.end(), std::ref(f32rng));
1102       std::generate(b.begin(), b.end(), std::ref(f32rng));
1103       std::generate(bias.begin(), bias.end(), std::ref(f32rng));
1104       std::fill(c.begin(), c.end(), nanf(""));
1105       std::fill(c_ref.begin(), c_ref.end(), 0.0f);
1106 
1107       std::fill(packed_w.begin(), packed_w.end(), 0.0f);
1108       xnn_pack_f32_gemm_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), packed_w.data(), nullptr);
1109 
1110       for (size_t m_index = 0; m_index < m(); m_index++) {
1111         for (size_t n_index = 0; n_index < n(); n_index++) {
1112           for (size_t k_index = 0; k_index < k(); k_index++) {
1113             ASSERT_LE(n(), packed_n());
1114             ASSERT_LT(m_index * n() + n_index, c_ref.size());
1115             c_ref[m_index * n() + n_index] +=
1116               a[m_index * a_stride() + k_index] *
1117               b[n_index * k() + k_index];
1118           }
1119           c_ref[m_index * n() + n_index] = std::max(0.0f, c_ref[m_index * n() + n_index] + bias[n_index]);
1120         }
1121       }
1122 
1123       gemm_relu(m(), n(), k() * sizeof(float),
1124         a.data(), a_stride() * sizeof(float),
1125         packed_w.data(),
1126         c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
1127         nullptr);
1128 
1129       // Validate micro-kernel outputs.
1130       for (size_t i = 0; i < m(); i++) {
1131         for (size_t j = 0; j < n(); j++) {
1132           ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], 0.0f)
1133               << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1134               << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1135               << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1136           ASSERT_NEAR(
1137               c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
1138               c_ref[i * n() + j],
1139               std::abs(c_ref[i * n() + j]) * 1.0e-6f)
1140               << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1141               << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1142               << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1143         }
1144       }
1145     }
1146   }
1147 
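  // Checks an f32 GEMM microkernel with a fused min/max (clamping) activation.
  // The output range [c_min, c_max] is derived from the reference results by
  // linearly interpolating between the accumulated minimum and maximum using
  // qmin()/qmax(), so the clamp is always exercised inside the observed value
  // range; roughly:
  //
  //   c_min = accumulated_min + (accumulated_max - accumulated_min) * qmin() / 255
  //   c_max = accumulated_max - (accumulated_max - accumulated_min) * (255 - qmax()) / 255
  //
  // The Variant parameter selects between the native and the scalar (portable)
  // initialization of xnn_f32_minmax_params.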
  void Test(xnn_f32_gemm_minmax_ukernel_function gemm_minmax, Variant variant = Variant::Native) const {
    ASSERT_LE(m(), mr());
    ASSERT_GE(a_stride(), k());
    ASSERT_GE(cm_stride(), n());

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);

    std::vector<float> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
    std::vector<float> b(n() * k());
    std::vector<float> bias(n());
    std::vector<float, AlignedAllocator<float, 64>> packed_w(packed_n() * packed_k() + bias_n());
    std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
    std::vector<float> c_ref(m() * n());

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(a.begin(), a.end(), std::ref(f32rng));
      std::generate(b.begin(), b.end(), std::ref(f32rng));
      std::generate(bias.begin(), bias.end(), std::ref(f32rng));
      std::fill(c.begin(), c.end(), nanf(""));
      std::fill(c_ref.begin(), c_ref.end(), 0.0f);

      std::fill(packed_w.begin(), packed_w.end(), 0.0f);
      xnn_pack_f32_gemm_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), packed_w.data(), nullptr);

      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          for (size_t k_index = 0; k_index < k(); k_index++) {
            ASSERT_LE(n(), packed_n());
            ASSERT_LT(m_index * n() + n_index, c_ref.size());
            c_ref[m_index * n() + n_index] +=
              a[m_index * a_stride() + k_index] *
              b[n_index * k() + k_index];
          }
          c_ref[m_index * n() + n_index] += bias[n_index];
        }
      }

      const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
      const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
      const float c_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
      const float c_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());

      // Prepare parameters.
      xnn_f32_minmax_params params = { };
      switch (variant) {
        case Variant::Native:
          params = xnn_init_f32_minmax_params(c_min, c_max);
          break;
        case Variant::Scalar:
          params = xnn_init_scalar_f32_minmax_params(c_min, c_max);
          break;
      }

      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          c_ref[m_index * n() + n_index] = std::max(std::min(c_ref[m_index * n() + n_index], c_max), c_min);
        }
      }

      gemm_minmax(m(), n(), k() * sizeof(float),
        a.data(), a_stride() * sizeof(float),
        packed_w.data(),
        c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
        &params);

      // Validate micro-kernel outputs.
      for (size_t i = 0; i < m(); i++) {
        for (size_t j = 0; j < n(); j++) {
          ASSERT_LE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_max)
              << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
              << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
              << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
          ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_min)
              << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
              << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
              << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
          ASSERT_NEAR(
              c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
              c_ref[i * n() + j],
              std::abs(c_ref[i * n() + j]) * 1.0e-6f)
              << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
              << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
              << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
        }
      }
    }
  }

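  // Checks an f32 GEMMINC microkernel, i.e. a min/max GEMM that starts from an
  // externally supplied accumulator buffer instead of packed bias values: the
  // weights are packed with xnn_pack_f32_gemminc_goi_w (no bias appended), and
  // acc[] holds an mr() x nr() tile of initial accumulators per nr()-wide
  // column group, which the reference computation adds in before clamping.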
  void Test(xnn_f32_gemminc_minmax_ukernel_function gemminc, Variant variant = Variant::Native) const {
    ASSERT_LE(m(), mr());
    ASSERT_GE(a_stride(), k());
    ASSERT_GE(cm_stride(), n());

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);

    std::vector<float> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
    std::vector<float> b(n() * k());
    std::vector<float> bias(n());
    std::vector<float, AlignedAllocator<float, 64>> packed_w(packed_n() * packed_k());  // no bias_n()
    std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
    std::vector<float> c_ref(m() * n());
    std::vector<float, AlignedAllocator<float, 64>> acc(mr() * packed_n());

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(a.begin(), a.end(), std::ref(f32rng));
      std::generate(b.begin(), b.end(), std::ref(f32rng));
      std::fill(c.begin(), c.end(), nanf(""));
      std::fill(c_ref.begin(), c_ref.end(), 0.0f);
      std::generate(acc.begin(), acc.end(), std::ref(f32rng));

      std::fill(packed_w.begin(), packed_w.end(), 0.0f);
      xnn_pack_f32_gemminc_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), packed_w.data(), nullptr);

      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          for (size_t k_index = 0; k_index < k(); k_index++) {
            ASSERT_LE(n(), packed_n());
            ASSERT_LT(m_index * n() + n_index, c_ref.size());
            c_ref[m_index * n() + n_index] +=
              a[m_index * a_stride() + k_index] *
              b[n_index * k() + k_index];
          }
          c_ref[m_index * n() + n_index] += acc[n_index / nr() * nr() * mr() + m_index % mr() * nr() + n_index % nr()];
        }
      }

      const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
      const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
      const float c_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
      const float c_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());

      // Prepare parameters.
      xnn_f32_minmax_params params = { };
      switch (variant) {
        case Variant::Native:
          params = xnn_init_f32_minmax_params(c_min, c_max);
          break;
        case Variant::Scalar:
          params = xnn_init_scalar_f32_minmax_params(c_min, c_max);
          break;
      }

      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          c_ref[m_index * n() + n_index] = std::max(std::min(c_ref[m_index * n() + n_index], c_max), c_min);
        }
      }

      gemminc(m(), n(), k() * sizeof(float),
        a.data(), a_stride() * sizeof(float),
        packed_w.data(),
        c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
        acc.data(),
        &params);

      // Validate micro-kernel outputs.
      for (size_t i = 0; i < m(); i++) {
        for (size_t j = 0; j < n(); j++) {
          ASSERT_LE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_max)
              << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
              << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
              << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
          ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_min)
              << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
              << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
              << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
          ASSERT_NEAR(
              c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
              c_ref[i * n() + j],
              std::abs(c_ref[i * n() + j]) * 1.0e-6f)
              << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
              << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
              << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
        }
      }
    }
  }

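  // Checks an f32 IGEMM (indirect GEMM) microkernel. Instead of a dense A
  // matrix, the kernel reads its rows through an indirection buffer (im2col)
  // of ks() * mr() pointers; a_offset() (in elements, converted to bytes at
  // the call site) is added by the kernel to every pointer except the
  // designated zero pointer, and rows beyond m() point at a NaN-filled junk
  // buffer to verify they are never read. The reference result mirrors this
  // pointer chasing element by element before adding the bias.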
  void Test(xnn_f32_igemm_ukernel_function igemm) const {
    ASSERT_LE(m(), mr());

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);

    std::vector<float> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
    std::vector<float> b(n() * ks() * k());
    std::vector<float, AlignedAllocator<float, 64>> packed_w(ks() * packed_k() * packed_n() + bias_n());
    std::vector<float> bias(n());
    std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
    std::vector<float> c_ref(m() * n());
    std::vector<float> junk(k() + XNN_EXTRA_BYTES / sizeof(float));
    std::vector<const float*> im2col(mr() * ks());
    std::fill(junk.begin(), junk.end(), nanf(""));

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(a.begin(), a.end(), std::ref(f32rng));
      std::generate(b.begin(), b.end(), std::ref(f32rng));
      std::generate(bias.begin(), bias.end(), std::ref(f32rng));
      std::fill(c.begin(), c.end(), nanf(""));
      std::fill(c_ref.begin(), c_ref.end(), 0.0f);

      std::fill(packed_w.begin(), packed_w.end(), 0.0f);
      xnn_pack_f32_conv_goki_w(
        1, n(), ks(), k(), nr(), kr(), sr(),
        b.data(), bias.data(), packed_w.data(), nullptr);

      for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
        for (size_t m_index = 0; m_index < mr(); m_index++) {
          im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
        }
      }
      std::shuffle(im2col.begin(), im2col.end(), rng);
      if (zero_index() != SIZE_MAX) {
        for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
          im2col[ks_index * mr() + zero_index()] = a.data();
        }
      }
      for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
        for (size_t m_index = m(); m_index < mr(); m_index++) {
          im2col[ks_index * mr() + m_index] = junk.data();
        }
      }

      std::fill(c_ref.begin(), c_ref.end(), 0.0f);
      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
            for (size_t k_index = 0; k_index < k(); k_index++) {
              ASSERT_LT(ks_index * mr() + m_index, im2col.size());
              ASSERT_LT(k_index, k());
              ASSERT_LT(k_index, a_stride());
              if (im2col[ks_index * mr() + m_index] == a.data()) {
                c_ref[m_index * n() + n_index] +=
                  (im2col[ks_index * mr() + m_index][k_index]) *
                  (b[(n_index * ks() + ks_index) * k() + k_index]);
              } else {
                c_ref[m_index * n() + n_index] +=
                  (im2col[ks_index * mr() + m_index][k_index + a_offset()]) *
                  (b[(n_index * ks() + ks_index) * k() + k_index]);
              }
            }
          }
          c_ref[m_index * n() + n_index] += bias[n_index];
        }
      }

      const float* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : nullptr;

      igemm(
        m(), n(), k() * sizeof(float), ks() * mr() * sizeof(void*),
        im2col.data(), packed_w.data(),
        c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
        a_offset() * sizeof(float), zero_pointer,
        nullptr);

      for (size_t i = 0; i < m(); i++) {
        for (size_t j = 0; j < n(); j++) {
          ASSERT_NEAR(
              c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
              c_ref[i * n() + j],
              std::abs(c_ref[i * n() + j]) * 1.0e-6f)
              << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
              << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
              << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
        }
      }
    }
  }

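  // Same indirect-GEMM check as above, but for a kernel with a fused ReLU
  // activation: the reference result is clamped at zero before comparison,
  // and every kernel output is additionally asserted to be non-negative.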
  void Test(xnn_f32_igemm_relu_ukernel_function igemm_relu) const {
    ASSERT_LE(m(), mr());

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);

    std::vector<float> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
    std::vector<float> b(n() * ks() * k());
    std::vector<float, AlignedAllocator<float, 64>> packed_w(ks() * packed_k() * packed_n() + bias_n());
    std::vector<float> bias(n());
    std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
    std::vector<float> c_ref(m() * n());
    std::vector<float> junk(k() + XNN_EXTRA_BYTES / sizeof(float));
    std::vector<const float*> im2col(mr() * ks());
    std::fill(junk.begin(), junk.end(), nanf(""));

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(a.begin(), a.end(), std::ref(f32rng));
      std::generate(b.begin(), b.end(), std::ref(f32rng));
      std::generate(bias.begin(), bias.end(), std::ref(f32rng));
      std::fill(c.begin(), c.end(), nanf(""));
      std::fill(c_ref.begin(), c_ref.end(), 0.0f);

      std::fill(packed_w.begin(), packed_w.end(), 0.0f);
      xnn_pack_f32_conv_goki_w(
        1, n(), ks(), k(), nr(), kr(), sr(),
        b.data(), bias.data(), packed_w.data(), nullptr);

      for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
        for (size_t m_index = 0; m_index < mr(); m_index++) {
          im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
        }
      }
      std::shuffle(im2col.begin(), im2col.end(), rng);
      if (zero_index() != SIZE_MAX) {
        for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
          im2col[ks_index * mr() + zero_index()] = a.data();
        }
      }
      for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
        for (size_t m_index = m(); m_index < mr(); m_index++) {
          im2col[ks_index * mr() + m_index] = junk.data();
        }
      }

      std::fill(c_ref.begin(), c_ref.end(), 0.0f);
      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
            for (size_t k_index = 0; k_index < k(); k_index++) {
              ASSERT_LT(ks_index * mr() + m_index, im2col.size());
              ASSERT_LT(k_index, k());
              ASSERT_LT(k_index, a_stride());
              if (im2col[ks_index * mr() + m_index] == a.data()) {
                c_ref[m_index * n() + n_index] +=
                  (im2col[ks_index * mr() + m_index][k_index]) *
                  (b[(n_index * ks() + ks_index) * k() + k_index]);
              } else {
                c_ref[m_index * n() + n_index] +=
                  (im2col[ks_index * mr() + m_index][k_index + a_offset()]) *
                  (b[(n_index * ks() + ks_index) * k() + k_index]);
              }
            }
          }
          c_ref[m_index * n() + n_index] = std::max(0.0f, bias[n_index] + c_ref[m_index * n() + n_index]);
        }
      }

      const float* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : nullptr;

      igemm_relu(
        m(), n(), k() * sizeof(float), ks() * mr() * sizeof(void*),
        im2col.data(), packed_w.data(),
        c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
        a_offset() * sizeof(float), zero_pointer,
        nullptr);

      for (size_t i = 0; i < m(); i++) {
        for (size_t j = 0; j < n(); j++) {
          ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], 0.0f)
              << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
              << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
              << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
          ASSERT_NEAR(
              c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
              c_ref[i * n() + j],
              std::abs(c_ref[i * n() + j]) * 1.0e-6f)
              << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
              << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
              << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
        }
      }
    }
  }

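  // Same indirect-GEMM check, but for a kernel with a fused min/max clamp:
  // [c_min, c_max] is derived from the reference results via qmin()/qmax()
  // exactly as in the dense min/max GEMM test above, and the Variant selects
  // native vs. scalar initialization of xnn_f32_minmax_params.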
  void Test(xnn_f32_igemm_minmax_ukernel_function igemm_minmax, Variant variant = Variant::Native) const {
    ASSERT_LE(m(), mr());

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);

    std::vector<float> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
    std::vector<float> b(n() * ks() * k());
    std::vector<float, AlignedAllocator<float, 64>> packed_w(ks() * packed_k() * packed_n() + bias_n());
    std::vector<float> bias(n());
    std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
    std::vector<float> c_ref(m() * n());
    std::vector<float> junk(k() + XNN_EXTRA_BYTES / sizeof(float));
    std::vector<const float*> im2col(mr() * ks());
    std::fill(junk.begin(), junk.end(), nanf(""));

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(a.begin(), a.end(), std::ref(f32rng));
      std::generate(b.begin(), b.end(), std::ref(f32rng));
      std::generate(bias.begin(), bias.end(), std::ref(f32rng));
      std::fill(c.begin(), c.end(), nanf(""));
      std::fill(c_ref.begin(), c_ref.end(), 0.0f);

      std::fill(packed_w.begin(), packed_w.end(), 0.0f);
      xnn_pack_f32_conv_goki_w(
        1, n(), ks(), k(), nr(), kr(), sr(),
        b.data(), bias.data(), packed_w.data(), nullptr);

      for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
        for (size_t m_index = 0; m_index < mr(); m_index++) {
          im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
        }
      }
      std::shuffle(im2col.begin(), im2col.end(), rng);
      if (zero_index() != SIZE_MAX) {
        for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
          im2col[ks_index * mr() + zero_index()] = a.data();
        }
      }
      for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
        for (size_t m_index = m(); m_index < mr(); m_index++) {
          im2col[ks_index * mr() + m_index] = junk.data();
        }
      }

      std::fill(c_ref.begin(), c_ref.end(), 0.0f);
      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
            for (size_t k_index = 0; k_index < k(); k_index++) {
              ASSERT_LT(ks_index * mr() + m_index, im2col.size());
              ASSERT_LT(k_index, k());
              ASSERT_LT(k_index, a_stride());
              if (im2col[ks_index * mr() + m_index] == a.data()) {
                c_ref[m_index * n() + n_index] +=
                  (im2col[ks_index * mr() + m_index][k_index]) *
                  (b[(n_index * ks() + ks_index) * k() + k_index]);
              } else {
                c_ref[m_index * n() + n_index] +=
                  (im2col[ks_index * mr() + m_index][k_index + a_offset()]) *
                  (b[(n_index * ks() + ks_index) * k() + k_index]);
              }
            }
          }
          c_ref[m_index * n() + n_index] += bias[n_index];
        }
      }

      const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
      const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
      const float c_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
      const float c_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          c_ref[m_index * n() + n_index] = std::min(c_ref[m_index * n() + n_index], c_max);
          c_ref[m_index * n() + n_index] = std::max(c_ref[m_index * n() + n_index], c_min);
        }
      }

      // Prepare parameters.
      xnn_f32_minmax_params params = { };
      switch (variant) {
        case Variant::Native:
          params = xnn_init_f32_minmax_params(c_min, c_max);
          break;
        case Variant::Scalar:
          params = xnn_init_scalar_f32_minmax_params(c_min, c_max);
          break;
      }

      const float* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : nullptr;

      igemm_minmax(
        m(), n(), k() * sizeof(float), ks() * mr() * sizeof(void*),
        im2col.data(), packed_w.data(),
        c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
        a_offset() * sizeof(float), zero_pointer,
        &params);

      for (size_t i = 0; i < m(); i++) {
        for (size_t j = 0; j < n(); j++) {
          ASSERT_LE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_max)
              << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
              << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
              << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
          ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_min)
              << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
              << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
              << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
          ASSERT_NEAR(
              c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
              c_ref[i * n() + j],
              std::abs(c_ref[i * n() + j]) * 1.0e-6f)
              << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
              << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
              << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
        }
      }
    }
  }

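  // Tester configuration and its defaults: all tile and problem dimensions
  // start at 1, the strides at 0 (meaning "derive from the problem size"),
  // the clamping range at the full [0, 255], zero_index_ at SIZE_MAX (no zero
  // pointer in the indirect tests), and 15 randomized iterations per test.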
 private:
  size_t mr_{1};
  size_t nr_{1};
  size_t kr_{1};
  size_t sr_{1};
  size_t m_{1};
  size_t n_{1};
  size_t k_{1};
  size_t ks_{1};
  size_t a_stride_{0};
  size_t cm_stride_{0};
  size_t cn_stride_{0};
  uint8_t a_zero_point_{127};
  uint8_t b_zero_point_{127};
  uint8_t qmin_{0};
  uint8_t qmax_{255};
  size_t a_offset_{0};
  size_t zero_index_{SIZE_MAX};
  size_t iterations_{15};
};