1 #include "gemm-microkernel-tester.h"
2 
3 #include <gtest/gtest.h>
4 
5 #include <algorithm>
6 #include <cassert>
7 #include <cmath>
8 #include <cstddef>
9 #include <cstdlib>
10 #include <functional>
11 #include <limits>
12 #include <numeric>
13 #include <random>
14 #include <vector>
15 
16 #include <fp16.h>
17 
18 #include <xnnpack.h>
19 #include <xnnpack/allocator.h>
20 #include <xnnpack/AlignedAllocator.h>
21 #include <xnnpack/pack.h>
22 #include <xnnpack/params-init.h>
23 #include <xnnpack/params.h>
24 #include <xnnpack/requantization.h>
25 
26 void GemmMicrokernelTester::Test(
27   xnn_qu8_gemm_minmax_ukernel_function gemm,
28   xnn_init_qu8_conv_minmax_params_fn init_params,
29   xnn_qu8_requantize_fn requantize) const
30 {
31   ASSERT_LE(m(), mr());
32 
33   std::random_device random_device;
34   auto rng = std::mt19937(random_device());
35   auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), std::ref(rng));
36   auto u8rng = std::bind(
37     std::uniform_int_distribution<uint32_t>(0, std::numeric_limits<uint8_t>::max()), std::ref(rng));
38 
39   std::vector<uint8_t> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint8_t));
40   std::vector<uint8_t> b(n() * k());
41   std::vector<int32_t> bias(n());
42   std::vector<uint8_t, AlignedAllocator<uint8_t, 64>> packed_w(packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(uint8_t));
43   std::vector<uint8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
44   std::vector<int32_t> acc(m() * n());
45   std::vector<uint8_t> c_ref(m() * n());
46 
47   for (size_t iteration = 0; iteration < iterations(); iteration++) {
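    // Draw random inputs and weights, rejecting degenerate draws where every element is identical.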
48     do {
49       std::generate(a.begin(), a.end(), std::ref(u8rng));
50     } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
51     do {
52       std::generate(b.begin(), b.end(), std::ref(u8rng));
53     } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
54     std::generate(bias.begin(), bias.end(), std::ref(i32rng));
55     std::fill(c.begin(), c.end(), 0xA5);
56 
57     std::fill(packed_w.begin(), packed_w.end(), b_zero_point());
58     const xnn_qu8_packing_params packing_params = { a_zero_point(), b_zero_point() };
59     xnn_pack_qu8_gemm_goi_w(1, n(), k(), nr(), kr(), sr(),
60       b.data(), bias.data(), packed_w.data(), 0, &packing_params);
61 
62     // Compute 32-bit results and output quantization arguments.
63     std::fill(acc.begin(), acc.end(), 0);
64     for (size_t m_index = 0; m_index < m(); m_index++) {
65       for (size_t n_index = 0; n_index < n(); n_index++) {
66         for (size_t k_index = 0; k_index < k(); k_index++) {
67           acc[m_index * n() + n_index] +=
68               (int32_t(a[m_index * a_stride() + k_index]) - int32_t(a_zero_point())) *
69               (int32_t(b[n_index * k() + k_index]) - int32_t(b_zero_point()));
70         }
71         acc[m_index * n() + n_index] += bias[n_index];
72       }
73     }
74 
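    // Derive the output scale and zero point so that the accumulated range maps onto the full uint8 output range.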
75     const int32_t accumulated_min = *std::min_element(acc.cbegin(), acc.cend());
76     const int32_t accumulated_max = *std::max_element(acc.cbegin(), acc.cend());
77     const double c_scale = uint32_t(accumulated_max - accumulated_min) >= 256 ? double(uint32_t(accumulated_max - accumulated_min)) / 255.0 : 1.00001;
78     const uint8_t c_zero_point = uint8_t(std::max(std::min(
79       lrint(127.5 - 0.5 * double(accumulated_min + accumulated_max) / c_scale),
80       long(std::numeric_limits<uint8_t>::max())), long(std::numeric_limits<uint8_t>::min())));
81 
82     const float requantization_scale = 1.0f / float(c_scale);
83     union xnn_qu8_conv_minmax_params quantization_params;
84     init_params(&quantization_params,
85       b_zero_point(), requantization_scale, c_zero_point, qmin(), qmax());
86 
87     gemm(
88       m(), n(), k(),
89       a.data(), a_stride() * sizeof(uint8_t),
90       packed_w.data(),
91       c.data(), cm_stride() * sizeof(uint8_t), cn_stride() * sizeof(uint8_t),
92       &quantization_params);
93 
94     for (size_t m_index = 0; m_index < m(); m_index++) {
95       for (size_t n_index = 0; n_index < n(); n_index++) {
96         c_ref[m_index * n() + n_index] = requantize(
97           acc[m_index * n() + n_index], requantization_scale, c_zero_point, qmin(), qmax());
98       }
99     }
100 
101     for (size_t i = 0; i < m(); i++) {
102       for (size_t j = 0; j < n(); j++) {
103         ASSERT_LE(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(qmax()));
104         ASSERT_GE(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(qmin()));
105         ASSERT_EQ(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(c_ref[i * n() + j]))
106             << "at " << i << ", " << j << ": reference = " << (uint32_t) c_ref[i * n() + j]
107             << " (accumulator = " << acc[i * n() + j]
108             << "), optimized = " << (uint32_t) c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x "
109             << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
110             << ", requantization scale = " << requantization_scale << ", output zero point = " << int32_t(c_zero_point);
111       }
112     }
113   }
114 }
115 
116 void GemmMicrokernelTester::Test(
117   xnn_qu8_igemm_minmax_ukernel_function igemm,
118   xnn_init_qu8_conv_minmax_params_fn init_params,
119   xnn_qu8_requantize_fn requantize)
120 {
121   ASSERT_LE(m(), mr());
122 
123   std::random_device random_device;
124   auto rng = std::mt19937(random_device());
125   auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), std::ref(rng));
126   auto u8rng = std::bind(
127     std::uniform_int_distribution<uint32_t>(0, std::numeric_limits<uint8_t>::max()), std::ref(rng));
128 
129   std::vector<uint8_t> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint8_t));
130   std::vector<uint8_t> b(n() * ks() * k());
131   std::vector<uint8_t, AlignedAllocator<uint8_t, 64>> packed_w(ks() * packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(uint8_t));
132   std::vector<int32_t> bias(n());
133   std::vector<uint8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
134   std::vector<int32_t> acc(m() * n());
135   std::vector<uint8_t> c_ref(m() * n());
136   std::vector<uint8_t> junk(k() + 8);
137   std::vector<const uint8_t*> im2col(mr() * ks());
138 
139   std::fill(junk.begin(), junk.end(), 0xA5);
140 
141   for (size_t iteration = 0; iteration < iterations(); iteration++) {
142     do {
143       std::generate(a.begin(), a.end(), std::ref(u8rng));
144     } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
145     do {
146       std::generate(b.begin(), b.end(), std::ref(u8rng));
147     } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
148     std::generate(bias.begin(), bias.end(), std::ref(i32rng));
149     std::fill(c.begin(), c.end(), 0xA5);
150 
151     std::fill(packed_w.begin(), packed_w.end(), b_zero_point());
152     const xnn_qu8_packing_params packing_params = { a_zero_point(), b_zero_point() };
153     xnn_pack_qu8_conv_goki_w(
154       1, n(), ks(), k(), nr(), kr(), sr(),
155       b.data(), bias.data(), packed_w.data(), 0 /* extra bytes */, &packing_params);
156 
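    // Build the indirection (im2col) pointer array: ks() pointers per row of A, pre-biased by -a_offset();
    // entries equal to a.data() act as the zero pointer, and rows beyond m() point at junk.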
157     for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
158       for (size_t m_index = 0; m_index < mr(); m_index++) {
159         im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
160       }
161     }
162     std::shuffle(im2col.begin(), im2col.end(), rng);
163     if (zero_index() != SIZE_MAX) {
164       for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
165         im2col[ks_index * mr() + zero_index()] = a.data();
166       }
167     }
168     for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
169       for (size_t m_index = m(); m_index < mr(); m_index++) {
170         im2col[ks_index * mr() + m_index] = junk.data();
171       }
172     }
173 
174     // Compute 32-bit results and output quantization arguments.
175     std::fill(acc.begin(), acc.end(), 0);
176     for (size_t m_index = 0; m_index < m(); m_index++) {
177       for (size_t n_index = 0; n_index < n(); n_index++) {
178         for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
179           for (size_t k_index = 0; k_index < k(); k_index++) {
180             if (im2col[ks_index * mr() + m_index] == a.data()) {
181               acc[m_index * n() + n_index] +=
182                 (int32_t(im2col[ks_index * mr() + m_index][k_index]) - int32_t(a_zero_point())) *
183                 (int32_t(b[(n_index * ks() + ks_index) * k() + k_index]) - int32_t(b_zero_point()));
184             } else {
185               acc[m_index * n() + n_index] +=
186                 (int32_t(im2col[ks_index * mr() + m_index][k_index + a_offset()]) - int32_t(a_zero_point())) *
187                 (int32_t(b[(n_index * ks() + ks_index) * k() + k_index]) - int32_t(b_zero_point()));
188             }
189           }
190         }
191         acc[m_index * n() + n_index] += bias[n_index];
192       }
193     }
194 
195     const int32_t accumulated_min = *std::min_element(acc.cbegin(), acc.cend());
196     const int32_t accumulated_max = *std::max_element(acc.cbegin(), acc.cend());
197     const double c_scale = uint32_t(accumulated_max - accumulated_min) >= 256 ? double(uint32_t(accumulated_max - accumulated_min)) / 255.0 : 1.00001;
198     const uint8_t c_zero_point = uint8_t(std::max(std::min(
199       lrint(127.5 - 0.5 * double(accumulated_min + accumulated_max) / c_scale),
200       long(std::numeric_limits<uint8_t>::max())), long(std::numeric_limits<uint8_t>::min())));
201 
202     const float requantization_scale = 1.0f / float(c_scale);
203     union xnn_qu8_conv_minmax_params quantization_params;
204     init_params(&quantization_params,
205       b_zero_point(), requantization_scale, c_zero_point, qmin(), qmax());
206 
207     const uint8_t* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : NULL;
208 
209     igemm(
210       m(), n(), k(), ks() * mr() * sizeof(void*),
211       im2col.data(), packed_w.data(),
212       c.data(), cm_stride() * sizeof(uint8_t), cn_stride() * sizeof(uint8_t),
213       a_offset() * sizeof(uint8_t), zero_pointer,
214       &quantization_params);
215 
216     for (size_t m_index = 0; m_index < m(); m_index++) {
217       for (size_t n_index = 0; n_index < n(); n_index++) {
218         c_ref[m_index * n() + n_index] = requantize(
219           acc[m_index * n() + n_index], requantization_scale, c_zero_point, qmin(), qmax());
220       }
221     }
222 
223     for (size_t i = 0; i < m(); i++) {
224       for (size_t j = 0; j < n(); j++) {
225         ASSERT_LE(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(qmax()));
226         ASSERT_GE(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(qmin()));
227         ASSERT_EQ(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(c_ref[i * n() + j]))
228             << "at " << i << ", " << j << ": reference = " << uint32_t(c_ref[i * n() + j])
229             << " (accumulator = " << acc[i * n() + j]
230             << "), optimized = " << (uint32_t) c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x "
231             << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
232             << ", requantization scale = " << requantization_scale << ", output zero point = " << int32_t(c_zero_point);
233       }
234     }
235   }
236 }
237 
238 void GemmMicrokernelTester::Test(
239   xnn_qc8_gemm_minmax_ukernel_function gemm,
240   xnn_init_qs8_minmax_params_fn init_params,
241   xnn_qs8_requantize_fn requantize) const
242 {
243   ASSERT_LE(m(), mr());
244 
245   std::random_device random_device;
246   auto rng = std::mt19937(random_device());
247   auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), std::ref(rng));
248   auto i8rng = std::bind(
249     std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()),
250     std::ref(rng));
251   auto w8rng = std::bind(
252     std::uniform_int_distribution<int32_t>(-std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max()),
253     std::ref(rng));
254 
255   std::vector<int8_t> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(int8_t));
256   std::vector<int8_t> b(n() * k());
257   std::vector<int32_t> bias(n());
258   std::vector<int8_t, AlignedAllocator<int8_t, 64>> packed_w(packed_n() * packed_k() + packed_n() * (sizeof(int32_t) + sizeof(float)) / sizeof(int8_t));
259   std::vector<int16_t, AlignedAllocator<int16_t, 64>> packed_xw(packed_n() * packed_k() + packed_n() * (sizeof(int32_t) + sizeof(float)) / sizeof(int16_t));
260   std::vector<int8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
261   std::vector<int32_t> acc(m() * n());
262   std::vector<float> scale(n());
263   std::vector<int8_t> c_ref(m() * n());
264 
265   for (size_t iteration = 0; iteration < iterations(); iteration++) {
266     do {
267       std::generate(a.begin(), a.end(), std::ref(i8rng));
268     } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
269     do {
270       std::generate(b.begin(), b.end(), std::ref(w8rng));
271     } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
272     std::generate(bias.begin(), bias.end(), std::ref(i32rng));
273     std::fill(c.begin(), c.end(), 0xA5);
274 
275     std::fill(packed_w.begin(), packed_w.end(), 0);
276     const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) };
277     if (extended_weights()) {
278       xnn_pack_qs8_gemm_xw_goi_w(1, n(), k(), nr(), kr(), sr(),
279         b.data(), bias.data(), packed_xw.data(), nr() * sizeof(float), &packing_params);
280     } else {
281       xnn_pack_qs8_gemm_goi_w(1, n(), k(), nr(), kr(), sr(),
282         b.data(), bias.data(), packed_w.data(), nr() * sizeof(float), &packing_params);
283     }
284 
285     // Compute 32-bit results and output quantization arguments.
286     std::fill(acc.begin(), acc.end(), 0);
287     for (size_t m_index = 0; m_index < m(); m_index++) {
288       for (size_t n_index = 0; n_index < n(); n_index++) {
289         for (size_t k_index = 0; k_index < k(); k_index++) {
290           acc[m_index * n() + n_index] +=
291               (int32_t(a[m_index * a_stride() + k_index]) - int32_t(a_zero_point() - 0x80)) *
292               int32_t(b[n_index * k() + k_index]);
293         }
294         acc[m_index * n() + n_index] += bias[n_index];
295       }
296     }
297 
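    // QC8 uses per-channel quantization: compute a separate requantization scale for every output
    // channel from that channel's accumulated range.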
298     const int8_t c_zero_point = -1;
299     for (size_t n_index = 0; n_index < n(); n_index++) {
300       int32_t accumulated_min = acc[n_index];
301       int32_t accumulated_max = acc[n_index];
302       for (size_t m_index = 0; m_index < m(); m_index++) {
303         accumulated_min = std::min(accumulated_min, acc[m_index * n() + n_index]);
304         accumulated_max = std::max(accumulated_max, acc[m_index * n() + n_index]);
305       }
306       const uint32_t accumulated_range = uint32_t(accumulated_max - accumulated_min);
307       const float c_scale = accumulated_range >= 256 ? double(accumulated_range) / 255.0 : 1.00001;
308       scale[n_index] = 1.0f / c_scale;
309     }
310 
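    // Scatter the per-channel scales into the space reserved after each channel group in the packed weights.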
311     if (extended_weights()) {
312       xnn_init_qc8_scale_fp32_params(
313         n(), nr(),
314         nr() * (packed_k() * sizeof(int16_t) + (sizeof(int32_t) + sizeof(float))), scale.data(),
315         (void*) ((uintptr_t) packed_xw.data() + nr() * (packed_k() * sizeof(int16_t) + sizeof(int32_t))));
316     } else {
317       xnn_init_qc8_scale_fp32_params(
318         n(), nr(),
319         nr() * (packed_k() * sizeof(int8_t) + (sizeof(int32_t) + sizeof(float))), scale.data(),
320         (void*) ((uintptr_t) packed_w.data() + nr() * (packed_k() * sizeof(int8_t) + sizeof(int32_t))));
321     }
322 
323     union xnn_qs8_minmax_params minmax_params;
324     init_params(&minmax_params,
325       c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
326 
327     gemm(
328       m(), n(), k(),
329       a.data(), a_stride() * sizeof(int8_t),
330       extended_weights() ? static_cast<const void*>(packed_xw.data()) : static_cast<const void*>(packed_w.data()),
331       c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t),
332       &minmax_params);
333 
334     for (size_t m_index = 0; m_index < m(); m_index++) {
335       for (size_t n_index = 0; n_index < n(); n_index++) {
336         c_ref[m_index * n() + n_index] = requantize(
337           acc[m_index * n() + n_index], scale[n_index], c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
338       }
339     }
340 
341     for (size_t i = 0; i < m(); i++) {
342       for (size_t j = 0; j < n(); j++) {
343         ASSERT_LE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmax()) - 0x80);
344         ASSERT_GE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmin()) - 0x80);
345         ASSERT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(c_ref[i * n() + j]))
346             << "at " << i << ", " << j << ": reference = " << int32_t(c_ref[i * n() + j])
347             << " (accumulator = " << acc[i * n() + j]
348             << "), optimized = " << int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x "
349             << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
350             << ", requantization scale = " << scale[j] << ", output zero point = " << int32_t(c_zero_point);
351       }
352     }
353   }
354 }
355 
356 void GemmMicrokernelTester::Test(
357   xnn_qc8_igemm_minmax_ukernel_function igemm,
358   xnn_init_qs8_minmax_params_fn init_params,
359   xnn_qs8_requantize_fn requantize) const
360 {
361   ASSERT_LE(m(), mr());
362 
363   std::random_device random_device;
364   auto rng = std::mt19937(random_device());
365   auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), std::ref(rng));
366   auto i8rng = std::bind(
367     std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()),
368     std::ref(rng));
369   auto w8rng = std::bind(
370     std::uniform_int_distribution<int32_t>(-std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max()),
371     std::ref(rng));
372 
373   std::vector<int8_t> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint8_t));
374   std::vector<int8_t> b(n() * ks() * k());
375   std::vector<int8_t, AlignedAllocator<int8_t, 64>> packed_w(ks() * packed_n() * packed_k() + packed_n() * (sizeof(int32_t) + sizeof(float)) / sizeof(int8_t));
376   std::vector<int32_t> bias(n());
377   std::vector<int8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
378   std::vector<int32_t> acc(m() * n());
379   std::vector<float> scale(n());
380   std::vector<int8_t> c_ref(m() * n());
381   std::vector<int8_t> junk(k() + 8);
382   std::vector<const int8_t*> im2col(mr() * ks());
383 
384   std::fill(junk.begin(), junk.end(), 0xA5);
385 
386   for (size_t iteration = 0; iteration < iterations(); iteration++) {
387     do {
388       std::generate(a.begin(), a.end(), std::ref(i8rng));
389     } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
390     do {
391       std::generate(b.begin(), b.end(), std::ref(w8rng));
392     } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
393     std::generate(bias.begin(), bias.end(), std::ref(i32rng));
394     std::fill(c.begin(), c.end(), 0xA5);
395 
396     std::fill(packed_w.begin(), packed_w.end(), 0);
397     const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) };
398     xnn_pack_qs8_conv_goki_w(
399       1, n(), ks(), k(), nr(), kr(), sr(),
400       b.data(), bias.data(), packed_w.data(), nr() * sizeof(float), &packing_params);
401 
402     for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
403       for (size_t m_index = 0; m_index < mr(); m_index++) {
404         im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
405       }
406     }
407     std::shuffle(im2col.begin(), im2col.end(), rng);
408     if (zero_index() != SIZE_MAX) {
409       for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
410         im2col[ks_index * mr() + zero_index()] = a.data();
411       }
412     }
413     for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
414       for (size_t m_index = m(); m_index < mr(); m_index++) {
415         im2col[ks_index * mr() + m_index] = junk.data();
416       }
417     }
418 
419     // Compute 32-bit results and output quantization arguments.
420     std::fill(acc.begin(), acc.end(), 0);
421     for (size_t m_index = 0; m_index < m(); m_index++) {
422       for (size_t n_index = 0; n_index < n(); n_index++) {
423         for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
424           for (size_t k_index = 0; k_index < k(); k_index++) {
425             if (im2col[ks_index * mr() + m_index] == a.data()) {
426               acc[m_index * n() + n_index] +=
427                 (int32_t(im2col[ks_index * mr() + m_index][k_index]) - int32_t(a_zero_point() - 0x80)) *
428                 int32_t(b[(n_index * ks() + ks_index) * k() + k_index]);
429             } else {
430               acc[m_index * n() + n_index] +=
431                 (int32_t(im2col[ks_index * mr() + m_index][k_index + a_offset()]) - int32_t(a_zero_point() - 0x80)) *
432                 int32_t(b[(n_index * ks() + ks_index) * k() + k_index]);
433             }
434           }
435         }
436         acc[m_index * n() + n_index] += bias[n_index];
437       }
438     }
439 
440     const int8_t c_zero_point = -1;
441     for (size_t n_index = 0; n_index < n(); n_index++) {
442       int32_t accumulated_min = acc[n_index];
443       int32_t accumulated_max = acc[n_index];
444       for (size_t m_index = 0; m_index < m(); m_index++) {
445         accumulated_min = std::min(accumulated_min, acc[m_index * n() + n_index]);
446         accumulated_max = std::max(accumulated_max, acc[m_index * n() + n_index]);
447       }
448       const uint32_t accumulated_range = uint32_t(accumulated_max - accumulated_min);
449       const float c_scale = accumulated_range >= 256 ? double(accumulated_range) / 255.0 : 1.00001;
450       scale[n_index] = 1.0f / c_scale;
451     }
452 
453     xnn_init_qc8_scale_fp32_params(
454       n(), nr(),
455       nr() * (ks() * packed_k() * sizeof(int8_t) + (sizeof(int32_t) + sizeof(float))), scale.data(),
456       (void*) ((uintptr_t) packed_w.data() + nr() * (ks() * packed_k() * sizeof(int8_t) + sizeof(int32_t))));
457 
458     union xnn_qs8_minmax_params minmax_params;
459     init_params(&minmax_params,
460       c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
461 
462     const int8_t* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : NULL;
463 
464     igemm(
465       m(), n(), k(), ks() * mr() * sizeof(void*),
466       im2col.data(), packed_w.data(),
467       c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t),
468       a_offset() * sizeof(uint8_t), zero_pointer,
469       &minmax_params);
470 
471     for (size_t m_index = 0; m_index < m(); m_index++) {
472       for (size_t n_index = 0; n_index < n(); n_index++) {
473         c_ref[m_index * n() + n_index] = requantize(
474           acc[m_index * n() + n_index], scale[n_index], c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
475       }
476     }
477 
478     for (size_t i = 0; i < m(); i++) {
479       for (size_t j = 0; j < n(); j++) {
480         ASSERT_LE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmax()) - 0x80);
481         ASSERT_GE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmin()) - 0x80);
482         ASSERT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(c_ref[i * n() + j]))
483             << "at " << i << ", " << j << ": reference = " << int32_t(c_ref[i * n() + j])
484             << " (accumulator = " << acc[i * n() + j]
485             << "), optimized = " << int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x "
486             << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
487             << ", requantization scale = " << scale[j] << ", output zero point = " << int32_t(c_zero_point);
488       }
489     }
490   }
491 }
492 
493 void GemmMicrokernelTester::Test(
494   xnn_qs8_gemm_minmax_ukernel_function gemm,
495   xnn_init_qs8_conv_minmax_params_fn init_params,
496   xnn_qs8_requantize_fn requantize) const
497 {
498   ASSERT_LE(m(), mr());
499 
500   std::random_device random_device;
501   auto rng = std::mt19937(random_device());
502   auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), std::ref(rng));
503   auto i8rng = std::bind(
504     std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()),
505     std::ref(rng));
506   auto w8rng = std::bind(
507     std::uniform_int_distribution<int32_t>(-std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max()),
508     std::ref(rng));
509 
510   std::vector<int8_t> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(int8_t));
511   std::vector<int8_t> b(n() * k());
512   std::vector<int32_t> bias(n());
513   std::vector<int8_t, AlignedAllocator<int8_t, 64>> packed_w(packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(int8_t));
514   std::vector<int16_t, AlignedAllocator<int16_t, 64>> packed_xw(packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(int16_t));
515   std::vector<int8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
516   std::vector<int32_t> acc(m() * n());
517   std::vector<int8_t> c_ref(m() * n());
518 
519   for (size_t iteration = 0; iteration < iterations(); iteration++) {
520     do {
521       std::generate(a.begin(), a.end(), std::ref(i8rng));
522     } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
523     do {
524       std::generate(b.begin(), b.end(), std::ref(w8rng));
525     } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
526     std::generate(bias.begin(), bias.end(), std::ref(i32rng));
527     std::fill(c.begin(), c.end(), 0xA5);
528 
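    // Pack the weights either as int8 or, for extended-weights (XW) kernels, pre-widened to int16.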
529     std::fill(packed_w.begin(), packed_w.end(), 0);
530     const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) };
531     if (extended_weights()) {
532       xnn_pack_qs8_gemm_xw_goi_w(1, n(), k(), nr(), kr(), sr(),
533         b.data(), bias.data(), packed_xw.data(), 0, &packing_params);
534     } else {
535       xnn_pack_qs8_gemm_goi_w(1, n(), k(), nr(), kr(), sr(),
536         b.data(), bias.data(), packed_w.data(), 0, &packing_params);
537     }
538 
539     // Compute 32-bit results and output quantization arguments.
540     std::fill(acc.begin(), acc.end(), 0);
541     for (size_t m_index = 0; m_index < m(); m_index++) {
542       for (size_t n_index = 0; n_index < n(); n_index++) {
543         for (size_t k_index = 0; k_index < k(); k_index++) {
544           acc[m_index * n() + n_index] +=
545               (int32_t(a[m_index * a_stride() + k_index]) - int32_t(a_zero_point() - 0x80)) *
546               int32_t(b[n_index * k() + k_index]);
547         }
548         acc[m_index * n() + n_index] += bias[n_index];
549       }
550     }
551 
552     const int32_t accumulated_min = *std::min_element(acc.cbegin(), acc.cend());
553     const int32_t accumulated_max = *std::max_element(acc.cbegin(), acc.cend());
554     const double c_scale = uint32_t(accumulated_max - accumulated_min) >= 256 ? double(uint32_t(accumulated_max - accumulated_min)) / 255.0 : 1.00001;
555     const int8_t c_zero_point = int8_t(std::max(std::min(
556       lrint(-0.5 - 0.5 * double(accumulated_min + accumulated_max) / c_scale),
557       long(std::numeric_limits<int8_t>::max())), long(std::numeric_limits<int8_t>::min())));
558 
559     const float requantization_scale = 1.0f / float(c_scale);
560     union xnn_qs8_conv_minmax_params quantization_params;
561     init_params(&quantization_params,
562       requantization_scale, c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
563 
564     gemm(
565       m(), n(), k(),
566       a.data(), a_stride() * sizeof(int8_t),
567       extended_weights() ? static_cast<const void*>(packed_xw.data()) : static_cast<const void*>(packed_w.data()),
568       c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t),
569       &quantization_params);
570 
571     for (size_t m_index = 0; m_index < m(); m_index++) {
572       for (size_t n_index = 0; n_index < n(); n_index++) {
573         c_ref[m_index * n() + n_index] = requantize(
574           acc[m_index * n() + n_index], requantization_scale, c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
575       }
576     }
577 
578     for (size_t i = 0; i < m(); i++) {
579       for (size_t j = 0; j < n(); j++) {
580         ASSERT_LE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmax()) - 0x80);
581         ASSERT_GE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmin()) - 0x80);
582         ASSERT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(c_ref[i * n() + j]))
583             << "at " << i << ", " << j << ": reference = " << int32_t(c_ref[i * n() + j])
584             << " (accumulator = " << acc[i * n() + j]
585             << "), optimized = " << int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x "
586             << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
587             << ", requantization scale = " << requantization_scale << ", output zero point = " << int32_t(c_zero_point);
588       }
589     }
590   }
591 }
592 
593 void GemmMicrokernelTester::Test(
594   xnn_qs8_igemm_minmax_ukernel_function igemm,
595   xnn_init_qs8_conv_minmax_params_fn init_params,
596   xnn_qs8_requantize_fn requantize) const
597 {
598   ASSERT_LE(m(), mr());
599 
600   std::random_device random_device;
601   auto rng = std::mt19937(random_device());
602   auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), std::ref(rng));
603   auto i8rng = std::bind(
604     std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()),
605     std::ref(rng));
606   auto w8rng = std::bind(
607     std::uniform_int_distribution<int32_t>(-std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max()),
608     std::ref(rng));
609 
610   std::vector<int8_t> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint8_t));
611   std::vector<int8_t> b(n() * ks() * k());
612   std::vector<int8_t, AlignedAllocator<int8_t, 64>> packed_w(ks() * packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(int8_t));
613   std::vector<int32_t> bias(n());
614   std::vector<int8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
615   std::vector<int32_t> acc(m() * n());
616   std::vector<int8_t> c_ref(m() * n());
617   std::vector<int8_t> junk(k() + 8);
618   std::vector<const int8_t*> im2col(mr() * ks());
619 
620   std::fill(junk.begin(), junk.end(), 0xA5);
621 
622   for (size_t iteration = 0; iteration < iterations(); iteration++) {
623     do {
624       std::generate(a.begin(), a.end(), std::ref(i8rng));
625     } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
626     do {
627       std::generate(b.begin(), b.end(), std::ref(w8rng));
628     } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
629     std::generate(bias.begin(), bias.end(), std::ref(i32rng));
630     std::fill(c.begin(), c.end(), 0xA5);
631 
632     std::fill(packed_w.begin(), packed_w.end(), 0);
633     const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) };
634     xnn_pack_qs8_conv_goki_w(
635       1, n(), ks(), k(), nr(), kr(), sr(),
636       b.data(), bias.data(), packed_w.data(), 0 /* extra bytes */, &packing_params);
637 
638     for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
639       for (size_t m_index = 0; m_index < mr(); m_index++) {
640         im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
641       }
642     }
643     std::shuffle(im2col.begin(), im2col.end(), rng);
644     if (zero_index() != SIZE_MAX) {
645       for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
646         im2col[ks_index * mr() + zero_index()] = a.data();
647       }
648     }
649     for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
650       for (size_t m_index = m(); m_index < mr(); m_index++) {
651         im2col[ks_index * mr() + m_index] = junk.data();
652       }
653     }
654 
655     // Compute 32-bit results and output quantization arguments.
656     std::fill(acc.begin(), acc.end(), 0);
657     for (size_t m_index = 0; m_index < m(); m_index++) {
658       for (size_t n_index = 0; n_index < n(); n_index++) {
659         for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
660           for (size_t k_index = 0; k_index < k(); k_index++) {
661             if (im2col[ks_index * mr() + m_index] == a.data()) {
662               acc[m_index * n() + n_index] +=
663                 (int32_t(im2col[ks_index * mr() + m_index][k_index]) - int32_t(a_zero_point() - 0x80)) *
664                 int32_t(b[(n_index * ks() + ks_index) * k() + k_index]);
665             } else {
666               acc[m_index * n() + n_index] +=
667                 (int32_t(im2col[ks_index * mr() + m_index][k_index + a_offset()]) - int32_t(a_zero_point() - 0x80)) *
668                 int32_t(b[(n_index * ks() + ks_index) * k() + k_index]);
669             }
670           }
671         }
672         acc[m_index * n() + n_index] += bias[n_index];
673       }
674     }
675 
676     const int32_t accumulated_min = *std::min_element(acc.cbegin(), acc.cend());
677     const int32_t accumulated_max = *std::max_element(acc.cbegin(), acc.cend());
678     const double c_scale = uint32_t(accumulated_max - accumulated_min) >= 256 ? double(uint32_t(accumulated_max - accumulated_min)) / 255.0 : 1.00001;
679     const int8_t c_zero_point = int8_t(std::max(std::min(
680       lrint(-0.5 - 0.5 * double(accumulated_min + accumulated_max) / c_scale),
681       long(std::numeric_limits<int8_t>::max())), long(std::numeric_limits<int8_t>::min())));
682 
683     const float requantization_scale = 1.0f / float(c_scale);
684     union xnn_qs8_conv_minmax_params quantization_params;
685     init_params(&quantization_params,
686       requantization_scale, c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
687 
688     const int8_t* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : NULL;
689 
690     igemm(
691       m(), n(), k(), ks() * mr() * sizeof(void*),
692       im2col.data(), packed_w.data(),
693       c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t),
694       a_offset() * sizeof(uint8_t), zero_pointer,
695       &quantization_params);
696 
697     for (size_t m_index = 0; m_index < m(); m_index++) {
698       for (size_t n_index = 0; n_index < n(); n_index++) {
699         c_ref[m_index * n() + n_index] = requantize(
700           acc[m_index * n() + n_index], requantization_scale, c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
701       }
702     }
703 
704     for (size_t i = 0; i < m(); i++) {
705       for (size_t j = 0; j < n(); j++) {
706         ASSERT_LE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmax()) - 0x80);
707         ASSERT_GE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmin()) - 0x80);
708         ASSERT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(c_ref[i * n() + j]))
709             << "at " << i << ", " << j << ": reference = " << int32_t(c_ref[i * n() + j])
710             << " (accumulator = " << acc[i * n() + j]
711             << "), optimized = " << int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x "
712             << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
713             << ", requantization scale = " << requantization_scale << ", output zero point = " << int32_t(c_zero_point);
714       }
715     }
716   }
717 }
718 
719 void GemmMicrokernelTester::Test(xnn_f16_gemm_minmax_ukernel_function gemm_minmax, xnn_init_f16_scaleminmax_params_fn init_params) const
720 {
721   ASSERT_LE(m(), mr());
722   ASSERT_GE(a_stride(), k());
723   ASSERT_GE(cm_stride(), n());
724 
725   std::random_device random_device;
726   auto rng = std::mt19937(random_device());
727   auto f32rng = std::bind(std::uniform_real_distribution<float>(), std::ref(rng));
728   auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng);
729 
730   std::vector<uint16_t> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint16_t));
731   std::vector<uint16_t> b(n() * k());
732   std::vector<uint16_t, AlignedAllocator<uint16_t, 64>> packed_w(packed_n() * packed_k() + packed_n());
733   std::vector<uint16_t> bias(n());
734   std::vector<uint16_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
735   std::vector<float> c_ref(m() * n());
736 
737   for (size_t iteration = 0; iteration < iterations(); iteration++) {
738     std::generate(a.begin(), a.end(), std::ref(f16rng));
739     std::generate(b.begin(), b.end(), std::ref(f16rng));
740     std::generate(bias.begin(), bias.end(), std::ref(f16rng));
741     std::fill(c.begin(), c.end(), UINT16_C(0x7E00) /* NaN */);
742     std::fill(c_ref.begin(), c_ref.end(), 0.0f);
743 
744     std::fill(packed_w.begin(), packed_w.end(), 0);
745     xnn_pack_f16_gemm_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), packed_w.data(), 0, nullptr);
746 
747     for (size_t m_index = 0; m_index < m(); m_index++) {
748       for (size_t n_index = 0; n_index < n(); n_index++) {
749         for (size_t k_index = 0; k_index < k(); k_index++) {
750           ASSERT_LE(n(), packed_n());
751           ASSERT_LT(m_index * n() + n_index, c_ref.size());
752           ASSERT_LT(m_index * k() + k_index, a.size());
753           c_ref[m_index * n() + n_index] +=
754             fp16_ieee_to_fp32_value(a[m_index * a_stride() + k_index]) *
755             fp16_ieee_to_fp32_value(b[n_index * k() + k_index]);
756         }
757         c_ref[m_index * n() + n_index] += fp16_ieee_to_fp32_value(bias[n_index]);
758       }
759     }
760 
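    // Clamp bounds are derived from the reference output range and rounded through fp16 so the
    // reference results are clipped exactly as the micro-kernel will clip them.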
761     const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
762     const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
763     const float c_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin())));
764     const float c_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax())));
765 
766     // Prepare parameters.
767     xnn_f16_scaleminmax_params params;
768     init_params(&params,
769       UINT16_C(0x3C00) /* 1.0 */,
770       fp16_ieee_from_fp32_value(c_min),
771       fp16_ieee_from_fp32_value(c_max));
772 
773     for (float& c_value : c_ref) {
774       c_value = std::max(std::min(c_value, c_max), c_min);
775     }
776 
777     gemm_minmax(m(), n(), k() * sizeof(uint16_t),
778       a.data(), a_stride() * sizeof(uint16_t),
779       packed_w.data(),
780       c.data(), cm_stride() * sizeof(uint16_t), cn_stride() * sizeof(uint16_t),
781       &params);
782 
783     // Validate micro-kernel outputs.
784     for (size_t i = 0; i < m(); i++) {
785       for (size_t j = 0; j < n(); j++) {
786         ASSERT_NEAR(fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), c_ref[i * n() + j], std::max(1.0e-4f, std::abs(c_ref[i * n() + j]) * 1.0e-2f))
787             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
788             << ", optimized = " << fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x " << nr()
789             << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
790       }
791     }
792   }
793 }
794 
795 void GemmMicrokernelTester::Test(xnn_f16_igemm_minmax_ukernel_function igemm_minmax, xnn_init_f16_scaleminmax_params_fn init_params) const {
796   ASSERT_LE(m(), mr());
797 
798   std::random_device random_device;
799   auto rng = std::mt19937(random_device());
800   auto f32rng = std::bind(std::uniform_real_distribution<float>(), std::ref(rng));
801   auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng);
802 
803   std::vector<uint16_t> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint16_t));
804   std::vector<uint16_t> b(n() * ks() * k());
805   std::vector<uint16_t, AlignedAllocator<uint16_t, 64>> packed_w(ks() * packed_k() * packed_n() + packed_n());
806   std::vector<uint16_t> bias(n());
807   std::vector<uint16_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
808   std::vector<float> c_ref(m() * n());
809   std::vector<uint16_t> junk(k() + XNN_EXTRA_BYTES / sizeof(uint16_t));
810   std::vector<const uint16_t*> im2col(mr() * ks());
811   std::fill(junk.begin(), junk.end(), UINT16_C(0x7E00) /* NaN */);
812 
813   for (size_t iteration = 0; iteration < iterations(); iteration++) {
814     std::generate(a.begin(), a.end(), std::ref(f16rng));
815     std::generate(b.begin(), b.end(), std::ref(f16rng));
816     std::generate(bias.begin(), bias.end(), std::ref(f16rng));
817     std::fill(c.begin(), c.end(), UINT16_C(0x7E00) /* NaN */);
818     std::fill(c_ref.begin(), c_ref.end(), 0);
819 
820     std::fill(packed_w.begin(), packed_w.end(), 0);
821     xnn_pack_f16_conv_goki_w(
822       1, n(), ks(), k(), nr(), kr(), sr(),
823       b.data(), bias.data(), packed_w.data(), 0 /* extra bytes */, nullptr);
824 
825     for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
826       for (size_t m_index = 0; m_index < mr(); m_index++) {
827         im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
828       }
829     }
830     std::shuffle(im2col.begin(), im2col.end(), rng);
831     if (zero_index() != SIZE_MAX) {
832       for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
833         im2col[ks_index * mr() + zero_index()] = a.data();
834       }
835     }
836     for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
837       for (size_t m_index = m(); m_index < mr(); m_index++) {
838         im2col[ks_index * mr() + m_index] = junk.data();
839       }
840     }
841 
842     std::fill(c_ref.begin(), c_ref.end(), 0.0);
843     for (size_t m_index = 0; m_index < m(); m_index++) {
844       for (size_t n_index = 0; n_index < n(); n_index++) {
845         for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
846           for (size_t k_index = 0; k_index < k(); k_index++) {
847             ASSERT_LT(ks_index * mr() + m_index, im2col.size());
848             ASSERT_LT(k_index, k());
849             ASSERT_LT(k_index, a_stride());
850             if (im2col[ks_index * mr() + m_index] == a.data()) {
851               c_ref[m_index * n() + n_index] +=
852                 fp16_ieee_to_fp32_value(im2col[ks_index * mr() + m_index][k_index]) *
853                 fp16_ieee_to_fp32_value(b[(n_index * ks() + ks_index) * k() + k_index]);
854             } else {
855               c_ref[m_index * n() + n_index] +=
856                 fp16_ieee_to_fp32_value(im2col[ks_index * mr() + m_index][k_index + a_offset()]) *
857                 fp16_ieee_to_fp32_value(b[(n_index * ks() + ks_index) * k() + k_index]);
858             }
859           }
860         }
861         c_ref[m_index * n() + n_index] += fp16_ieee_to_fp32_value(bias[n_index]);
862       }
863     }
864 
865     const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
866     const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
867     const float c_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_min + (accumulated_max - accumulated_min) / 255.0f * uint16_t(qmin())));
868     const float c_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_max - (accumulated_max - accumulated_min) / 255.0f * uint16_t(255 - qmax())));
869     for (size_t m_index = 0; m_index < m(); m_index++) {
870       for (size_t n_index = 0; n_index < n(); n_index++) {
871         c_ref[m_index * n() + n_index] = std::min(c_ref[m_index * n() + n_index], c_max);
872         c_ref[m_index * n() + n_index] = std::max(c_ref[m_index * n() + n_index], c_min);
873       }
874     }
875 
876     // Prepare parameters.
877     xnn_f16_scaleminmax_params params;
878     init_params(&params,
879       UINT16_C(0x3C00) /* 1.0 */,
880       fp16_ieee_from_fp32_value(c_min),
881       fp16_ieee_from_fp32_value(c_max));
882 
883     for (float& c_value : c_ref) {
884       c_value = std::max(std::min(c_value, c_max), c_min);
885     }
886 
887     const uint16_t* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : NULL;
888 
889     igemm_minmax(
890       m(), n(), k() * sizeof(uint16_t), ks() * mr() * sizeof(void*),
891       reinterpret_cast<const void**>(im2col.data()), packed_w.data(),
892       c.data(), cm_stride() * sizeof(uint16_t), cn_stride() * sizeof(uint16_t),
893       a_offset() * sizeof(uint16_t), zero_pointer,
894       &params);
895 
896     for (size_t i = 0; i < m(); i++) {
897       for (size_t j = 0; j < n(); j++) {
898         ASSERT_LE(fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), c_max)
899             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
900             << ", optimized = " << fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x " << nr()
901             << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
902         ASSERT_GE(fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), c_min)
903             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
904             << ", optimized = " << fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x " << nr()
905             << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
906         ASSERT_NEAR(fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), c_ref[i * n() + j], std::max(1.0e-4f, std::abs(c_ref[i * n() + j]) * 1.0e-2f))
907             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
908             << ", optimized = " << fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x " << nr()
909             << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
910       }
911     }
912   }
913 }
914 
915 void GemmMicrokernelTester::Test(xnn_f32_ppmm_minmax_ukernel_function ppmm_minmax, xnn_init_f32_minmax_params_fn init_params) const {
916   ASSERT_LE(m(), mr());
917   ASSERT_GE(cm_stride(), n());
918 
919   std::random_device random_device;
920   auto rng = std::mt19937(random_device());
921   auto f32rng = std::bind(std::uniform_real_distribution<float>(), std::ref(rng));
922 
923   std::vector<float> a(packed_k() * mr());
924   std::vector<float> b(n() * k());
925   std::vector<float> bias(n());
926   std::vector<float, AlignedAllocator<float, 64>> packed_w(packed_n() * packed_k() + packed_n());
927   std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
928   std::vector<float> c_ref(m() * n());
929 
930   for (size_t iteration = 0; iteration < iterations(); iteration++) {
931     std::generate(a.begin(), a.end(), std::ref(f32rng));
932     std::generate(b.begin(), b.end(), std::ref(f32rng));
933     std::generate(bias.begin(), bias.end(), std::ref(f32rng));
934     std::fill(c.begin(), c.end(), nanf(""));
935     std::fill(c_ref.begin(), c_ref.end(), 0.0f);
936 
937     std::fill(packed_w.begin(), packed_w.end(), 0.0f);
938     xnn_pack_f32_gemm_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), packed_w.data(), 0, nullptr);
939 
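    // PPMM consumes a pre-packed A panel of mr() rows; replicate the last valid row into the padding rows.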
940     for (size_t i = m(); i < mr(); i++) {
941       for (size_t l = 0; l < k(); l++) {
942         a[l * mr() + i] = a[l * mr() + m() - 1];
943       }
944     }
945 
946     for (size_t i = 0; i < m(); i++) {
947       for (size_t j = 0; j < n(); j++) {
948         for (size_t l = 0; l < k(); l++) {
949           c_ref[i * n() + j] +=
950             a[l * mr() + i] *
951             b[j * k() + l];
952         }
953         c_ref[i * n() + j] += bias[j];
954       }
955     }
956 
957     const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
958     const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
959     const float c_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
960     const float c_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
961 
962     // Prepare parameters.
963     xnn_f32_minmax_params params;
964     init_params(&params, c_min, c_max);
965 
966     for (float& c_value : c_ref) {
967       c_value = std::max(std::min(c_value, c_max), c_min);
968     }
969 
970     ppmm_minmax(m(), n(), k() * sizeof(float),
971       a.data(), packed_w.data(),
972       c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
973       &params);
974 
975     // Validate micro-kernel outputs.
976     for (size_t i = 0; i < m(); i++) {
977       for (size_t j = 0; j < n(); j++) {
978         ASSERT_NEAR(
979             c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
980             c_ref[i * n() + j],
981             std::abs(c_ref[i * n() + j]) * 1.0e-6f)
982             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
983             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
984             << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
985       }
986     }
987   }
988 }
989 
990 void GemmMicrokernelTester::Test(xnn_f32_gemm_ukernel_function gemm) const {
991   ASSERT_LE(m(), mr());
992   ASSERT_GE(a_stride(), k());
993   ASSERT_GE(cm_stride(), n());
994 
995   std::random_device random_device;
996   auto rng = std::mt19937(random_device());
997   auto f32rng = std::bind(std::uniform_real_distribution<float>(), std::ref(rng));
998 
999   std::vector<float> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
1000   std::vector<float> b(n() * k());
1001   std::vector<float> bias(n());
1002   std::vector<float, AlignedAllocator<float, 64>> packed_w(packed_n() * packed_k() + packed_n());
1003   std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
1004   std::vector<float> c_ref(m() * n());
1005 
1006   for (size_t iteration = 0; iteration < iterations(); iteration++) {
1007     std::generate(a.begin(), a.end(), std::ref(f32rng));
1008     std::generate(b.begin(), b.end(), std::ref(f32rng));
1009     std::generate(bias.begin(), bias.end(), std::ref(f32rng));
1010     std::fill(c.begin(), c.end(), nanf(""));
1011     std::fill(c_ref.begin(), c_ref.end(), 0.0f);
1012 
1013     std::fill(packed_w.begin(), packed_w.end(), 0.0f);
1014     xnn_pack_f32_gemm_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), packed_w.data(), 0, nullptr);
1015 
1016     for (size_t m_index = 0; m_index < m(); m_index++) {
1017       for (size_t n_index = 0; n_index < n(); n_index++) {
1018         for (size_t k_index = 0; k_index < k(); k_index++) {
1019           ASSERT_LE(n(), packed_n());
1020           ASSERT_LT(m_index * n() + n_index, c_ref.size());
1021           c_ref[m_index * n() + n_index] +=
1022             a[m_index * a_stride() + k_index] *
1023             b[n_index * k() + k_index];
1024         }
1025         c_ref[m_index * n() + n_index] += bias[n_index];
1026       }
1027     }
1028 
1029     gemm(m(), n(), k() * sizeof(float),
1030       a.data(), a_stride() * sizeof(float),
1031       packed_w.data(),
1032       c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
1033       nullptr);
1034 
1035     // Validate micro-kernel outputs.
1036     for (size_t i = 0; i < m(); i++) {
1037       for (size_t j = 0; j < n(); j++) {
1038         ASSERT_NEAR(
1039             c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
1040             c_ref[i * n() + j],
1041             std::abs(c_ref[i * n() + j]) * 1.0e-6f)
1042             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1043             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1044             << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1045       }
1046     }
1047   }
1048 }
1049 
1050 void GemmMicrokernelTester::Test(xnn_f32_gemm_relu_ukernel_function gemm_relu) const {
1051   ASSERT_LE(m(), mr());
1052   ASSERT_GE(a_stride(), k());
1053   ASSERT_GE(cm_stride(), n());
1054 
1055   std::random_device random_device;
1056   auto rng = std::mt19937(random_device());
1057   auto f32rng = std::bind(std::uniform_real_distribution<float>(), std::ref(rng));
1058 
1059   std::vector<float> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
1060   std::vector<float> b(n() * k());
1061   std::vector<float> bias(n());
1062   std::vector<float, AlignedAllocator<float, 64>> packed_w(packed_n() * packed_k() + packed_n());
1063   std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
1064   std::vector<float> c_ref(m() * n());
1065 
1066   for (size_t iteration = 0; iteration < iterations(); iteration++) {
1067     std::generate(a.begin(), a.end(), std::ref(f32rng));
1068     std::generate(b.begin(), b.end(), std::ref(f32rng));
1069     std::generate(bias.begin(), bias.end(), std::ref(f32rng));
1070     std::fill(c.begin(), c.end(), nanf(""));
1071     std::fill(c_ref.begin(), c_ref.end(), 0.0f);
1072 
1073     std::fill(packed_w.begin(), packed_w.end(), 0.0f);
1074     xnn_pack_f32_gemm_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), packed_w.data(), 0, nullptr);
1075 
1076     for (size_t m_index = 0; m_index < m(); m_index++) {
1077       for (size_t n_index = 0; n_index < n(); n_index++) {
1078         for (size_t k_index = 0; k_index < k(); k_index++) {
1079           ASSERT_LE(n(), packed_n());
1080           ASSERT_LT(m_index * n() + n_index, c_ref.size());
1081           c_ref[m_index * n() + n_index] +=
1082             a[m_index * a_stride() + k_index] *
1083             b[n_index * k() + k_index];
1084         }
1085         c_ref[m_index * n() + n_index] = std::max(0.0f, c_ref[m_index * n() + n_index] + bias[n_index]);
1086       }
1087     }
1088 
1089     gemm_relu(m(), n(), k() * sizeof(float),
1090       a.data(), a_stride() * sizeof(float),
1091       packed_w.data(),
1092       c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
1093       nullptr);
1094 
1095     // Validate micro-kernel outputs.
1096     for (size_t i = 0; i < m(); i++) {
1097       for (size_t j = 0; j < n(); j++) {
1098         ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], 0.0f)
1099             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1100             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1101             << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1102         ASSERT_NEAR(
1103             c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
1104             c_ref[i * n() + j],
1105             std::abs(c_ref[i * n() + j]) * 1.0e-6f)
1106             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1107             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1108             << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1109       }
1110     }
1111   }
1112 }
1113 
1114 void GemmMicrokernelTester::Test(xnn_f32_gemm_minmax_ukernel_function gemm_minmax, xnn_init_f32_minmax_params_fn init_params) const {
1115   ASSERT_LE(m(), mr());
1116   ASSERT_GE(a_stride(), k());
1117   ASSERT_GE(cm_stride(), n());
1118 
1119   std::random_device random_device;
1120   auto rng = std::mt19937(random_device());
1121   auto f32rng = std::bind(std::uniform_real_distribution<float>(), std::ref(rng));
1122 
1123   std::vector<float> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
1124   std::vector<float> b(n() * k());
1125   std::vector<float> bias(n());
1126   std::vector<float, AlignedAllocator<float, 64>> packed_w(packed_n() * packed_k() + packed_n());
1127   std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
1128   std::vector<float> c_ref(m() * n());
1129 
1130   for (size_t iteration = 0; iteration < iterations(); iteration++) {
1131     std::generate(a.begin(), a.end(), std::ref(f32rng));
1132     std::generate(b.begin(), b.end(), std::ref(f32rng));
1133     std::generate(bias.begin(), bias.end(), std::ref(f32rng));
1134     std::fill(c.begin(), c.end(), nanf(""));
1135     std::fill(c_ref.begin(), c_ref.end(), 0.0f);
1136 
1137     std::fill(packed_w.begin(), packed_w.end(), 0.0f);
1138     xnn_pack_f32_gemm_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), packed_w.data(), 0, nullptr);
1139 
1140     for (size_t m_index = 0; m_index < m(); m_index++) {
1141       for (size_t n_index = 0; n_index < n(); n_index++) {
1142         for (size_t k_index = 0; k_index < k(); k_index++) {
1143           ASSERT_LE(n(), packed_n());
1144           ASSERT_LT(m_index * n() + n_index, c_ref.size());
1145           c_ref[m_index * n() + n_index] +=
1146             a[m_index * a_stride() + k_index] *
1147             b[n_index * k() + k_index];
1148         }
1149         c_ref[m_index * n() + n_index] += bias[n_index];
1150       }
1151     }
1152 
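    // Derive output clamping bounds from the reference results: roughly qmin()/255 of the
    // accumulated range is clipped at the bottom and (255 - qmax())/255 at the top, so the
    // min/max activation path is exercised; qmin() == 0 or qmax() == 255 disables that side.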
1153     const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
1154     const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
1155     const float c_min =
1156         qmin() == std::numeric_limits<uint8_t>::min() ? -std::numeric_limits<float>::infinity()
1157                     : accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
1158     const float c_max =
1159         qmax() == std::numeric_limits<uint8_t>::max() ? +std::numeric_limits<float>::infinity()
1160                       : accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
1161 
1162     // Prepare parameters.
1163     xnn_f32_minmax_params params;
1164     init_params(&params, c_min, c_max);
1165 
1166     for (size_t m_index = 0; m_index < m(); m_index++) {
1167       for (size_t n_index = 0; n_index < n(); n_index++) {
1168         c_ref[m_index * n() + n_index] = std::max(std::min(c_ref[m_index * n() + n_index], c_max), c_min);
1169       }
1170     }
1171 
1172     gemm_minmax(m(), n(), k() * sizeof(float),
1173       a.data(), a_stride() * sizeof(float),
1174       packed_w.data(),
1175       c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
1176       &params);
1177 
1178     // Validate micro-kernel outputs.
1179     for (size_t i = 0; i < m(); i++) {
1180       for (size_t j = 0; j < n(); j++) {
1181         ASSERT_LE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_max)
1182             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1183             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1184             << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1185         ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_min)
1186             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1187             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1188             << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1189         ASSERT_NEAR(
1190             c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
1191             c_ref[i * n() + j],
1192             std::abs(c_ref[i * n() + j]) * 1.0e-6f)
1193             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1194             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1195             << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1196       }
1197     }
1198   }
1199 }
1200 
1201 void GemmMicrokernelTester::Test(xnn_f32_gemminc_minmax_ukernel_function gemminc, xnn_init_f32_minmax_params_fn init_params) const {
1202   ASSERT_LE(m(), mr());
1203   ASSERT_GE(a_stride(), k());
1204   ASSERT_GE(cm_stride(), n());
1205 
1206   std::random_device random_device;
1207   auto rng = std::mt19937(random_device());
1208   auto f32rng = std::bind(std::uniform_real_distribution<float>(), std::ref(rng));
1209 
1210   std::vector<float> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
1211   std::vector<float> b(n() * k());
1212   std::vector<float> bias(n());
1213   std::vector<float, AlignedAllocator<float, 64>> packed_w(packed_n() * packed_k());  // no packed_n()
1214   std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
1215   std::vector<float> c_ref(m() * n());
1216   std::vector<float, AlignedAllocator<float, 64>> acc(mr() * packed_n());
1217 
1218   for (size_t iteration = 0; iteration < iterations(); iteration++) {
1219     std::generate(a.begin(), a.end(), std::ref(f32rng));
1220     std::generate(b.begin(), b.end(), std::ref(f32rng));
1221     std::fill(c.begin(), c.end(), nanf(""));
1222     std::fill(c_ref.begin(), c_ref.end(), 0.0f);
1223     std::generate(acc.begin(), acc.end(), std::ref(f32rng));
1224 
1225     std::fill(packed_w.begin(), packed_w.end(), 0.0f);
1226     xnn_pack_f32_gemminc_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), packed_w.data(), nullptr);
1227 
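    // GEMMINC starts from externally supplied accumulators instead of a packed bias: the acc
    // buffer is laid out in MR x NR tiles, so the reference result adds
    // acc[(n / nr) * nr * mr + (m % mr) * nr + (n % nr)] to each output element.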
1228     for (size_t m_index = 0; m_index < m(); m_index++) {
1229       for (size_t n_index = 0; n_index < n(); n_index++) {
1230         for (size_t k_index = 0; k_index < k(); k_index++) {
1231           ASSERT_LE(n(), packed_n());
1232           ASSERT_LT(m_index * n() + n_index, c_ref.size());
1233           c_ref[m_index * n() + n_index] +=
1234             a[m_index * a_stride() + k_index] *
1235             b[n_index * k() + k_index];
1236         }
1237         c_ref[m_index * n() + n_index] += acc[n_index / nr() * nr() * mr() + m_index % mr() * nr() + n_index % nr()];
1238       }
1239     }
1240 
1241     const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
1242     const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
1243     const float c_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
1244     const float c_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
1245 
1246     // Prepare parameters.
1247     xnn_f32_minmax_params params;
1248     init_params(&params, c_min, c_max);
1249 
1250     for (size_t m_index = 0; m_index < m(); m_index++) {
1251       for (size_t n_index = 0; n_index < n(); n_index++) {
1252         c_ref[m_index * n() + n_index] = std::max(std::min(c_ref[m_index * n() + n_index], c_max), c_min);
1253       }
1254     }
1255 
1256     gemminc(m(), n(), k() * sizeof(float),
1257       a.data(), a_stride() * sizeof(float),
1258       packed_w.data(),
1259       c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
1260       acc.data(),
1261       &params);
1262 
1263     // Validate micro-kernel outputs.
1264     for (size_t i = 0; i < m(); i++) {
1265       for (size_t j = 0; j < n(); j++) {
1266         ASSERT_LE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_max)
1267             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1268             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1269             << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1270         ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_min)
1271             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1272             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1273             << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1274         ASSERT_NEAR(
1275             c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
1276             c_ref[i * n() + j],
1277             std::abs(c_ref[i * n() + j]) * 1.0e-6f)
1278             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1279             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1280             << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1281       }
1282     }
1283   }
1284 }
1285 
1286 void GemmMicrokernelTester::Test(xnn_f32_igemm_ukernel_function igemm) const {
1287   ASSERT_LE(m(), mr());
1288 
1289   std::random_device random_device;
1290   auto rng = std::mt19937(random_device());
1291   auto f32rng = std::bind(std::uniform_real_distribution<float>(), std::ref(rng));
1292 
1293   std::vector<float> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
1294   std::vector<float> b(n() * ks() * k());
1295   std::vector<float, AlignedAllocator<float, 64>> packed_w(ks() * packed_k() * packed_n() + packed_n());
1296   std::vector<float> bias(n());
1297   std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
1298   std::vector<float> c_ref(m() * n());
1299   std::vector<float> junk(k() + XNN_EXTRA_BYTES / sizeof(float));
1300   std::vector<const float*> im2col(mr() * ks());
1301   std::fill(junk.begin(), junk.end(), nanf(""));
1302 
1303   for (size_t iteration = 0; iteration < iterations(); iteration++) {
1304     std::generate(a.begin(), a.end(), std::ref(f32rng));
1305     std::generate(b.begin(), b.end(), std::ref(f32rng));
1306     std::generate(bias.begin(), bias.end(), std::ref(f32rng));
1307     std::fill(c.begin(), c.end(), nanf(""));
1308     std::fill(c_ref.begin(), c_ref.end(), 0.0f);
1309 
1310     std::fill(packed_w.begin(), packed_w.end(), 0.0f);
1311     xnn_pack_f32_conv_goki_w(
1312       1, n(), ks(), k(), nr(), kr(), sr(),
1313       b.data(), bias.data(), packed_w.data(), 0 /* extra bytes */, nullptr);
1314 
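    // Build the indirection (im2col) buffer: ks() groups of mr() row pointers into A, biased by
    // -a_offset() so the kernel can add a_offset back. Entries at zero_index() are redirected to
    // the start of A (the zero buffer), and rows beyond m() point at junk data that the kernel
    // is expected to ignore.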
1315     for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1316       for (size_t m_index = 0; m_index < mr(); m_index++) {
1317         im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
1318       }
1319     }
1320     std::shuffle(im2col.begin(), im2col.end(), rng);
1321     if (zero_index() != SIZE_MAX) {
1322       for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1323         im2col[ks_index * mr() + zero_index()] = a.data();
1324       }
1325     }
1326     for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1327       for (size_t m_index = m(); m_index < mr(); m_index++) {
1328         im2col[ks_index * mr() + m_index] = junk.data();
1329       }
1330     }
1331 
1332     std::fill(c_ref.begin(), c_ref.end(), 0.0);
1333     for (size_t m_index = 0; m_index < m(); m_index++) {
1334       for (size_t n_index = 0; n_index < n(); n_index++) {
1335         for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1336           for (size_t k_index = 0; k_index < k(); k_index++) {
1337             ASSERT_LT(ks_index * mr() + m_index, im2col.size());
1338             ASSERT_LT(k_index, k());
1339             ASSERT_LT(k_index, a_stride());
1340             if (im2col[ks_index * mr() + m_index] == a.data()) {
1341               c_ref[m_index * n() + n_index] +=
1342                 (im2col[ks_index * mr() + m_index][k_index]) *
1343                 (b[(n_index * ks() + ks_index) * k() + k_index]);
1344             } else {
1345               c_ref[m_index * n() + n_index] +=
1346                 (im2col[ks_index * mr() + m_index][k_index + a_offset()]) *
1347                 (b[(n_index * ks() + ks_index) * k() + k_index]);
1348             }
1349           }
1350         }
1351         c_ref[m_index * n() + n_index] += bias[n_index];
1352       }
1353     }
1354 
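    // When an indirection entry was redirected to the zero buffer, the kernel receives that raw
    // pointer (zero_pointer) and is expected to skip a_offset for it, matching the reference
    // computation above.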
1355     const float* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : NULL;
1356 
1357     igemm(
1358       m(), n(), k() * sizeof(float), ks() * mr() * sizeof(void*),
1359       im2col.data(), packed_w.data(),
1360       c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
1361       a_offset() * sizeof(float), zero_pointer,
1362       nullptr);
1363 
1364     for (size_t i = 0; i < m(); i++) {
1365       for (size_t j = 0; j < n(); j++) {
1366         ASSERT_NEAR(
1367             c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
1368             c_ref[i * n() + j],
1369             std::abs(c_ref[i * n() + j]) * 1.0e-6f)
1370             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1371             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1372             << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
1373       }
1374     }
1375   }
1376 }
1377 
1378 void GemmMicrokernelTester::Test(xnn_f32_igemm_relu_ukernel_function igemm_relu) const {
1379   ASSERT_LE(m(), mr());
1380 
1381   std::random_device random_device;
1382   auto rng = std::mt19937(random_device());
1383   auto f32rng = std::bind(std::uniform_real_distribution<float>(), std::ref(rng));
1384 
1385   std::vector<float> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
1386   std::vector<float> b(n() * ks() * k());
1387   std::vector<float, AlignedAllocator<float, 64>> packed_w(ks() * packed_k() * packed_n() + packed_n());
1388   std::vector<float> bias(n());
1389   std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
1390   std::vector<float> c_ref(m() * n());
1391   std::vector<float> junk(k() + XNN_EXTRA_BYTES / sizeof(float));
1392   std::vector<const float*> im2col(mr() * ks());
1393   std::fill(junk.begin(), junk.end(), nanf(""));
1394 
1395   for (size_t iteration = 0; iteration < iterations(); iteration++) {
1396     std::generate(a.begin(), a.end(), std::ref(f32rng));
1397     std::generate(b.begin(), b.end(), std::ref(f32rng));
1398     std::generate(bias.begin(), bias.end(), std::ref(f32rng));
1399     std::fill(c.begin(), c.end(), nanf(""));
1400     std::fill(c_ref.begin(), c_ref.end(), 0.0f);
1401 
1402     std::fill(packed_w.begin(), packed_w.end(), 0.0f);
1403     xnn_pack_f32_conv_goki_w(
1404       1, n(), ks(), k(), nr(), kr(), sr(),
1405       b.data(), bias.data(), packed_w.data(), 0 /* extra bytes */, nullptr);
1406 
1407     for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1408       for (size_t m_index = 0; m_index < mr(); m_index++) {
1409         im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
1410       }
1411     }
1412     std::shuffle(im2col.begin(), im2col.end(), rng);
1413     if (zero_index() != SIZE_MAX) {
1414       for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1415         im2col[ks_index * mr() + zero_index()] = a.data();
1416       }
1417     }
1418     for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1419       for (size_t m_index = m(); m_index < mr(); m_index++) {
1420         im2col[ks_index * mr() + m_index] = junk.data();
1421       }
1422     }
1423 
1424     std::fill(c_ref.begin(), c_ref.end(), 0.0);
1425     for (size_t m_index = 0; m_index < m(); m_index++) {
1426       for (size_t n_index = 0; n_index < n(); n_index++) {
1427         for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1428           for (size_t k_index = 0; k_index < k(); k_index++) {
1429             ASSERT_LT(ks_index * mr() + m_index, im2col.size());
1430             ASSERT_LT(k_index, k());
1431             ASSERT_LT(k_index, a_stride());
1432             if (im2col[ks_index * mr() + m_index] == a.data()) {
1433               c_ref[m_index * n() + n_index] +=
1434                 (im2col[ks_index * mr() + m_index][k_index]) *
1435                 (b[(n_index * ks() + ks_index) * k() + k_index]);
1436             } else {
1437               c_ref[m_index * n() + n_index] +=
1438                 (im2col[ks_index * mr() + m_index][k_index + a_offset()]) *
1439                 (b[(n_index * ks() + ks_index) * k() + k_index]);
1440             }
1441           }
1442         }
1443         c_ref[m_index * n() + n_index] = std::max(0.0f, bias[n_index] + c_ref[m_index * n() + n_index]);
1444       }
1445     }
1446 
1447     const float* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : NULL;
1448 
1449     igemm_relu(
1450       m(), n(), k() * sizeof(float), ks() * mr() * sizeof(void*),
1451       im2col.data(), packed_w.data(),
1452       c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
1453       a_offset() * sizeof(float), zero_pointer,
1454       nullptr);
1455 
1456     for (size_t i = 0; i < m(); i++) {
1457       for (size_t j = 0; j < n(); j++) {
1458         ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], 0.0f)
1459             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1460             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1461             << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
1462         ASSERT_NEAR(
1463             c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
1464             c_ref[i * n() + j],
1465             std::abs(c_ref[i * n() + j]) * 1.0e-6f)
1466             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1467             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1468             << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
1469       }
1470     }
1471   }
1472 }
1473 
1474 void GemmMicrokernelTester::Test(xnn_f32_igemm_minmax_ukernel_function igemm_minmax, xnn_init_f32_minmax_params_fn init_params) const {
1475   ASSERT_LE(m(), mr());
1476 
1477   std::random_device random_device;
1478   auto rng = std::mt19937(random_device());
1479   auto f32rng = std::bind(std::uniform_real_distribution<float>(), std::ref(rng));
1480 
1481   std::vector<float> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
1482   std::vector<float> b(n() * ks() * k());
1483   std::vector<float, AlignedAllocator<float, 64>> packed_w(ks() * packed_k() * packed_n() + packed_n());
1484   std::vector<float> bias(n());
1485   std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
1486   std::vector<float> c_ref(m() * n());
1487   std::vector<float> junk(k() + XNN_EXTRA_BYTES / sizeof(float));
1488   std::vector<const float*> im2col(mr() * ks());
1489   std::fill(junk.begin(), junk.end(), nanf(""));
1490 
1491   for (size_t iteration = 0; iteration < iterations(); iteration++) {
1492     std::generate(a.begin(), a.end(), std::ref(f32rng));
1493     std::generate(b.begin(), b.end(), std::ref(f32rng));
1494     std::generate(bias.begin(), bias.end(), std::ref(f32rng));
1495     std::fill(c.begin(), c.end(), nanf(""));
1496     std::fill(c_ref.begin(), c_ref.end(), 0.0f);
1497 
1498     std::fill(packed_w.begin(), packed_w.end(), 0.0f);
1499     xnn_pack_f32_conv_goki_w(
1500       1, n(), ks(), k(), nr(), kr(), sr(),
1501       b.data(), bias.data(), packed_w.data(), 0 /* extra bytes */, nullptr);
1502 
1503     for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1504       for (size_t m_index = 0; m_index < mr(); m_index++) {
1505         im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
1506       }
1507     }
1508     std::shuffle(im2col.begin(), im2col.end(), rng);
1509     if (zero_index() != SIZE_MAX) {
1510       for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1511         im2col[ks_index * mr() + zero_index()] = a.data();
1512       }
1513     }
1514     for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1515       for (size_t m_index = m(); m_index < mr(); m_index++) {
1516         im2col[ks_index * mr() + m_index] = junk.data();
1517       }
1518     }
1519 
1520     std::fill(c_ref.begin(), c_ref.end(), 0.0);
1521     for (size_t m_index = 0; m_index < m(); m_index++) {
1522       for (size_t n_index = 0; n_index < n(); n_index++) {
1523         for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1524           for (size_t k_index = 0; k_index < k(); k_index++) {
1525             ASSERT_LT(ks_index * mr() + m_index, im2col.size());
1526             ASSERT_LT(k_index, k());
1527             ASSERT_LT(k_index, a_stride());
1528             if (im2col[ks_index * mr() + m_index] == a.data()) {
1529               c_ref[m_index * n() + n_index] +=
1530                 (im2col[ks_index * mr() + m_index][k_index]) *
1531                 (b[(n_index * ks() + ks_index) * k() + k_index]);
1532             } else {
1533               c_ref[m_index * n() + n_index] +=
1534                 (im2col[ks_index * mr() + m_index][k_index + a_offset()]) *
1535                 (b[(n_index * ks() + ks_index) * k() + k_index]);
1536             }
1537           }
1538         }
1539         c_ref[m_index * n() + n_index] += bias[n_index];
1540       }
1541     }
1542 
1543     const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
1544     const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
1545     const float c_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
1546     const float c_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
1547     for (size_t m_index = 0; m_index < m(); m_index++) {
1548       for (size_t n_index = 0; n_index < n(); n_index++) {
1549         c_ref[m_index * n() + n_index] = std::min(c_ref[m_index * n() + n_index], c_max);
1550         c_ref[m_index * n() + n_index] = std::max(c_ref[m_index * n() + n_index], c_min);
1551       }
1552     }
1553 
1554     // Prepare parameters.
1555     xnn_f32_minmax_params params;
1556     init_params(&params, c_min, c_max);
1557 
1558     const float* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : NULL;
1559 
1560     igemm_minmax(
1561       m(), n(), k() * sizeof(float), ks() * mr() * sizeof(void*),
1562       im2col.data(), packed_w.data(),
1563       c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
1564       a_offset() * sizeof(float), zero_pointer,
1565       &params);
1566 
1567     for (size_t i = 0; i < m(); i++) {
1568       for (size_t j = 0; j < n(); j++) {
1569         ASSERT_LE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_max)
1570             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1571             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1572             << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
1573         ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_min)
1574             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1575             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1576             << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
1577         ASSERT_NEAR(
1578             c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
1579             c_ref[i * n() + j],
1580             std::abs(c_ref[i * n() + j]) * 1.0e-6f)
1581             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1582             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1583             << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
1584       }
1585     }
1586   }
1587 }
1588 
1589 #if XNN_PLATFORM_JIT
1590 void GemmMicrokernelTester::Test(xnn_jit_gemm_code_generator_function gemm_generator, xnn_init_f32_minmax_params_fn init_params) const {
1591   ASSERT_LE(m(), mr());
1592   ASSERT_GE(a_stride(), k());
1593   ASSERT_GE(cm_stride(), n());
1594 
1595   std::random_device random_device;
1596   auto rng = std::mt19937(random_device());
1597   auto f32rng = std::bind(std::uniform_real_distribution<float>(), std::ref(rng));
1598 
1599   std::vector<float> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
1600   std::vector<float> b(n() * k());
1601   std::vector<float> bias(n());
1602   std::vector<float, AlignedAllocator<float, 64>> packed_w(packed_n() * packed_k() + packed_n());
1603   std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
1604   std::vector<float> c_ref(m() * n());
1605 
1606   for (size_t iteration = 0; iteration < iterations(); iteration++) {
1607     std::generate(a.begin(), a.end(), std::ref(f32rng));
1608     std::generate(b.begin(), b.end(), std::ref(f32rng));
1609     std::generate(bias.begin(), bias.end(), std::ref(f32rng));
1610     std::fill(c.begin(), c.end(), nanf(""));
1611     std::fill(c_ref.begin(), c_ref.end(), 0.0f);
1612 
1613     std::fill(packed_w.begin(), packed_w.end(), 0.0f);
1614     xnn_pack_f32_gemm_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), packed_w.data(), 0, nullptr);
1615 
1616     for (size_t m_index = 0; m_index < m(); m_index++) {
1617       for (size_t n_index = 0; n_index < n(); n_index++) {
1618         for (size_t k_index = 0; k_index < k(); k_index++) {
1619           ASSERT_LE(n(), packed_n());
1620           ASSERT_LT(m_index * n() + n_index, c_ref.size());
1621           c_ref[m_index * n() + n_index] +=
1622             a[m_index * a_stride() + k_index] *
1623             b[n_index * k() + k_index];
1624         }
1625         c_ref[m_index * n() + n_index] += bias[n_index];
1626       }
1627     }
1628 
1629     const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
1630     const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
1631     const float c_min =
1632         qmin() == std::numeric_limits<uint8_t>::min() ? -std::numeric_limits<float>::infinity()
1633                     : accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
1634     const float c_max =
1635         qmax() == std::numeric_limits<uint8_t>::max() ? +std::numeric_limits<float>::infinity()
1636                       : accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
1637 
1638     // Prepare parameters.
1639     xnn_f32_minmax_params params;
1640     init_params(&params, c_min, c_max);
1641 
1642     for (size_t m_index = 0; m_index < m(); m_index++) {
1643       for (size_t n_index = 0; n_index < n(); n_index++) {
1644         c_ref[m_index * n() + n_index] = std::max(std::min(c_ref[m_index * n() + n_index], c_max), c_min);
1645       }
1646     }
1647 
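    // JIT path: allocate an executable code buffer, let the generator emit a minmax GEMM
    // micro-kernel specialized for this n() and k(), then call it through a function-pointer
    // cast exactly like the prebuilt kernels above.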
1648     struct xnn_code_buffer code_buffer;
1649     ASSERT_EQ(xnn_status_success, xnn_allocate_code_memory(&code_buffer, XNN_DEFAULT_CODE_BUFFER_SIZE));
1650     jit_gemm_params p = (jit_gemm_params) {
1651       .f32_minmax = {
1652         .min = c_min,
1653         .max = c_max
1654       }
1655     };
1656     ASSERT_EQ(xnn_status_success, gemm_generator(&code_buffer, n(), k() * sizeof(float), &p));
1657     xnn_f32_gemm_minmax_ukernel_function gemm_minmax = reinterpret_cast<xnn_f32_gemm_minmax_ukernel_function>(code_buffer.code);
1658 
1659     gemm_minmax(m(), n(), k() * sizeof(float),
1660       a.data(), a_stride() * sizeof(float),
1661       packed_w.data(),
1662       c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
1663       &params);
1664 
1665     ASSERT_EQ(xnn_status_success, xnn_release_code_memory(&code_buffer));
1666 
1667     // Validate micro-kernel outputs.
1668     for (size_t i = 0; i < m(); i++) {
1669       for (size_t j = 0; j < n(); j++) {
1670         ASSERT_LE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_max)
1671             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1672             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1673             << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1674         ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_min)
1675             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1676             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1677             << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1678         ASSERT_NEAR(
1679             c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
1680             c_ref[i * n() + j],
1681             std::abs(c_ref[i * n() + j]) * 1.0e-6f)
1682             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1683             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1684             << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1685       }
1686     }
1687   }
1688 }
1689 
1690 void GemmMicrokernelTester::Test(xnn_jit_igemm_code_generator_function igemm_generator, xnn_init_f32_minmax_params_fn init_params) const {
1691   ASSERT_LE(m(), mr());
1692 
1693   std::random_device random_device;
1694   auto rng = std::mt19937(random_device());
1695   auto f32rng = std::bind(std::uniform_real_distribution<float>(), std::ref(rng));
1696 
1697   std::vector<float> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
1698   std::vector<float> b(n() * ks() * k());
1699   std::vector<float, AlignedAllocator<float, 64>> packed_w(ks() * packed_k() * packed_n() + packed_n());
1700   std::vector<float> bias(n());
1701   std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
1702   std::vector<float> c_ref(m() * n());
1703   std::vector<float> junk(k() + XNN_EXTRA_BYTES / sizeof(float));
1704   std::vector<const float*> im2col(mr() * ks());
1705   std::fill(junk.begin(), junk.end(), nanf(""));
1706 
1707   for (size_t iteration = 0; iteration < iterations(); iteration++) {
1708     std::generate(a.begin(), a.end(), std::ref(f32rng));
1709     std::generate(b.begin(), b.end(), std::ref(f32rng));
1710     std::generate(bias.begin(), bias.end(), std::ref(f32rng));
1711     std::fill(c.begin(), c.end(), nanf(""));
1712     std::fill(c_ref.begin(), c_ref.end(), 0.0f);
1713 
1714     std::fill(packed_w.begin(), packed_w.end(), 0.0f);
1715     xnn_pack_f32_conv_goki_w(
1716       1, n(), ks(), k(), nr(), kr(), sr(),
1717       b.data(), bias.data(), packed_w.data(), 0 /* extra bytes */, nullptr);
1718 
1719     for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1720       for (size_t m_index = 0; m_index < mr(); m_index++) {
1721         im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
1722       }
1723     }
1724     std::shuffle(im2col.begin(), im2col.end(), rng);
1725     if (zero_index() != SIZE_MAX) {
1726       for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1727         im2col[ks_index * mr() + zero_index()] = a.data();
1728       }
1729     }
1730     for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1731       for (size_t m_index = m(); m_index < mr(); m_index++) {
1732         im2col[ks_index * mr() + m_index] = junk.data();
1733       }
1734     }
1735 
1736     std::fill(c_ref.begin(), c_ref.end(), 0.0);
1737     for (size_t m_index = 0; m_index < m(); m_index++) {
1738       for (size_t n_index = 0; n_index < n(); n_index++) {
1739         for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1740           for (size_t k_index = 0; k_index < k(); k_index++) {
1741             ASSERT_LT(ks_index * mr() + m_index, im2col.size());
1742             ASSERT_LT(k_index, k());
1743             ASSERT_LT(k_index, a_stride());
1744             if (im2col[ks_index * mr() + m_index] == a.data()) {
1745               c_ref[m_index * n() + n_index] +=
1746                 (im2col[ks_index * mr() + m_index][k_index]) *
1747                 (b[(n_index * ks() + ks_index) * k() + k_index]);
1748             } else {
1749               c_ref[m_index * n() + n_index] +=
1750                 (im2col[ks_index * mr() + m_index][k_index + a_offset()]) *
1751                 (b[(n_index * ks() + ks_index) * k() + k_index]);
1752             }
1753           }
1754         }
1755         c_ref[m_index * n() + n_index] += bias[n_index];
1756       }
1757     }
1758 
1759     const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
1760     const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
1761     const float c_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
1762     const float c_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
1763     for (size_t m_index = 0; m_index < m(); m_index++) {
1764       for (size_t n_index = 0; n_index < n(); n_index++) {
1765         c_ref[m_index * n() + n_index] = std::min(c_ref[m_index * n() + n_index], c_max);
1766         c_ref[m_index * n() + n_index] = std::max(c_ref[m_index * n() + n_index], c_min);
1767       }
1768     }
1769 
1770     // Prepare parameters.
1771     xnn_f32_minmax_params params;
1772     init_params(&params, c_min, c_max);
1773 
1774     const float* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : NULL;
1775 
1776     struct xnn_code_buffer code_buffer;
1777     ASSERT_EQ(xnn_status_success, xnn_allocate_code_memory(&code_buffer, XNN_DEFAULT_CODE_BUFFER_SIZE));
1778     jit_gemm_params p = (jit_gemm_params) {
1779       .f32_minmax = {
1780         .min = c_min,
1781         .max = c_max
1782       }
1783     };
1784     ASSERT_EQ(xnn_status_success, igemm_generator(&code_buffer, n(), k() * sizeof(float), ks() * mr() * sizeof(void*), &p));
1785     xnn_f32_igemm_minmax_ukernel_function igemm_minmax = reinterpret_cast<xnn_f32_igemm_minmax_ukernel_function>(code_buffer.code);
1786 
1787     igemm_minmax(
1788       m(), n(), k() * sizeof(float), ks() * mr() * sizeof(void*),
1789       im2col.data(), packed_w.data(),
1790       c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
1791       a_offset() * sizeof(float), zero_pointer,
1792       &params);
1793 
1794     ASSERT_EQ(xnn_status_success, xnn_release_code_memory(&code_buffer));
1795 
1796     for (size_t i = 0; i < m(); i++) {
1797       for (size_t j = 0; j < n(); j++) {
1798         ASSERT_LE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_max)
1799             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1800             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1801             << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
1802         ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_min)
1803             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1804             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1805             << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
1806         ASSERT_NEAR(
1807             c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
1808             c_ref[i * n() + j],
1809             std::abs(c_ref[i * n() + j]) * 1.0e-6f)
1810             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1811             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1812             << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
1813       }
1814     }
1815   }
1816 }
1817 
1818 void GemmMicrokernelTester::Test(
1819   xnn_jit_gemm_code_generator_function gemm_generator,
1820   xnn_init_qs8_minmax_params_fn init_params,
1821   xnn_qs8_requantize_fn requantize) const
1822 {
1823   ASSERT_LE(m(), mr());
1824 
1825   std::random_device random_device;
1826   auto rng = std::mt19937(random_device());
1827   auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), std::ref(rng));
1828   auto i8rng = std::bind(
1829     std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()),
1830     std::ref(rng));
1831   auto w8rng = std::bind(
1832     std::uniform_int_distribution<int32_t>(-std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max()),
1833     std::ref(rng));
1834 
1835   std::vector<int8_t> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(int8_t));
1836   std::vector<int8_t> b(n() * k());
1837   std::vector<int32_t> bias(n());
1838   std::vector<int8_t, AlignedAllocator<int8_t, 64>> packed_w(packed_n() * packed_k() + packed_n() * (sizeof(int32_t) + sizeof(float)) / sizeof(int8_t));
1839   std::vector<int16_t, AlignedAllocator<int16_t, 64>> packed_xw(packed_n() * packed_k() + packed_n() * (sizeof(int32_t) + sizeof(float)) / sizeof(int16_t));
1840   std::vector<int8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
1841   std::vector<int32_t> acc(m() * n());
1842   std::vector<float> scale(n());
1843   std::vector<int8_t> c_ref(m() * n());
1844 
1845   for (size_t iteration = 0; iteration < iterations(); iteration++) {
1846     do {
1847       std::generate(a.begin(), a.end(), std::ref(i8rng));
1848     } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
1849     do {
1850       std::generate(b.begin(), b.end(), std::ref(w8rng));
1851     } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
1852     std::generate(bias.begin(), bias.end(), std::ref(i32rng));
1853     std::fill(c.begin(), c.end(), 0xA5);
1854 
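    // Pack weights with nr() * sizeof(float) extra bytes per tile reserved for the per-channel
    // scales; extended_weights() selects the int16 ("xw") packing variant instead of int8.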
1855     std::fill(packed_w.begin(), packed_w.end(), 0);
1856     const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) };
1857     if (extended_weights()) {
1858       xnn_pack_qs8_gemm_xw_goi_w(1, n(), k(), nr(), kr(), sr(),
1859         b.data(), bias.data(), packed_xw.data(), nr() * sizeof(float), &packing_params);
1860     } else {
1861       xnn_pack_qs8_gemm_goi_w(1, n(), k(), nr(), kr(), sr(),
1862         b.data(), bias.data(), packed_w.data(), nr() * sizeof(float), &packing_params);
1863     }
1864 
1865     // Compute 32-bit results and output quantization arguments.
1866     std::fill(acc.begin(), acc.end(), 0);
1867     for (size_t m_index = 0; m_index < m(); m_index++) {
1868       for (size_t n_index = 0; n_index < n(); n_index++) {
1869         for (size_t k_index = 0; k_index < k(); k_index++) {
1870           acc[m_index * n() + n_index] +=
1871               (int32_t(a[m_index * a_stride() + k_index]) - int32_t(a_zero_point() - 0x80)) *
1872               int32_t(b[n_index * k() + k_index]);
1873         }
1874         acc[m_index * n() + n_index] += bias[n_index];
1875       }
1876     }
1877 
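    // Per-channel (QC8) quantization: derive one requantization scale per output channel from
    // that column's accumulator range, with a fixed output zero point of -1.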
1878     const int8_t c_zero_point = -1;
1879     for (size_t n_index = 0; n_index < n(); n_index++) {
1880       int32_t accumulated_min = acc[n_index];
1881       int32_t accumulated_max = acc[n_index];
1882       for (size_t m_index = 0; m_index < m(); m_index++) {
1883         accumulated_min = std::min(accumulated_min, acc[m_index * n() + n_index]);
1884         accumulated_max = std::max(accumulated_max, acc[m_index * n() + n_index]);
1885       }
1886       const uint32_t accumulated_range = uint32_t(accumulated_max - accumulated_min);
1887       const float c_scale = accumulated_range >= 256 ? double(accumulated_range) / 255.0 : 1.00001;
1888       scale[n_index] = 1.0f / c_scale;
1889     }
1890 
1891     if (extended_weights()) {
1892       xnn_init_qc8_scale_fp32_params(
1893         n(), nr(),
1894         nr() * (packed_k() * sizeof(int16_t) + (sizeof(int32_t) + sizeof(float))), scale.data(),
1895         (void*) ((uintptr_t) packed_xw.data() + nr() * (packed_k() * sizeof(int16_t) + sizeof(int32_t))));
1896     } else {
1897       xnn_init_qc8_scale_fp32_params(
1898         n(), nr(),
1899         nr() * (packed_k() * sizeof(int8_t) + (sizeof(int32_t) + sizeof(float))), scale.data(),
1900         (void*) ((uintptr_t) packed_w.data() + nr() * (packed_k() * sizeof(int8_t) + sizeof(int32_t))));
1901     }
1902 
1903     union xnn_qs8_minmax_params minmax_params;
1904     init_params(&minmax_params,
1905       c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
1906 
1907     struct xnn_code_buffer code_buffer;
1908     ASSERT_EQ(xnn_status_success, xnn_allocate_code_memory(&code_buffer, XNN_DEFAULT_CODE_BUFFER_SIZE));
1909     ASSERT_EQ(xnn_status_success, gemm_generator(&code_buffer, n(), k(), nullptr));
1910     xnn_qc8_gemm_minmax_ukernel_function gemm = reinterpret_cast<xnn_qc8_gemm_minmax_ukernel_function>(code_buffer.code);
1911 
1912     gemm(
1913       m(), n(), k(),
1914       a.data(), a_stride() * sizeof(int8_t),
1915       extended_weights() ? static_cast<const void*>(packed_xw.data()) : static_cast<const void*>(packed_w.data()),
1916       c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t),
1917       &minmax_params);
1918 
1919     ASSERT_EQ(xnn_status_success, xnn_release_code_memory(&code_buffer));
1920 
1921     for (size_t m_index = 0; m_index < m(); m_index++) {
1922       for (size_t n_index = 0; n_index < n(); n_index++) {
1923         c_ref[m_index * n() + n_index] = requantize(
1924           acc[m_index * n() + n_index], scale[n_index], c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
1925       }
1926     }
1927 
1928     for (size_t i = 0; i < m(); i++) {
1929       for (size_t j = 0; j < n(); j++) {
1930         ASSERT_LE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmax()) - 0x80);
1931         ASSERT_GE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmin()) - 0x80);
1932         ASSERT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(c_ref[i * n() + j]))
1933             << "at " << i << ", " << j << ": reference = " << int32_t(c_ref[i * n() + j])
1934             << " (accumulator = " << acc[i * n() + j]
1935             << "), optimized = " << int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x "
1936             << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
1937             << ", requantization scale = " << scale[j] << ", output zero point = " << int32_t(c_zero_point);
1938       }
1939     }
1940   }
1941 }
1942 
1943 void GemmMicrokernelTester::Test(
1944   xnn_jit_igemm_code_generator_function igemm_generator,
1945   xnn_init_qs8_minmax_params_fn init_params,
1946   xnn_qs8_requantize_fn requantize) const
1947 {
1948   ASSERT_LE(m(), mr());
1949 
1950   std::random_device random_device;
1951   auto rng = std::mt19937(random_device());
1952   auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), std::ref(rng));
1953   auto i8rng = std::bind(
1954     std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()),
1955     std::ref(rng));
1956   auto w8rng = std::bind(
1957     std::uniform_int_distribution<int32_t>(-std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max()),
1958     std::ref(rng));
1959 
1960   std::vector<int8_t> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint8_t));
1961   std::vector<int8_t> b(n() * ks() * k());
1962   std::vector<int8_t, AlignedAllocator<int8_t, 64>> packed_w(ks() * packed_n() * packed_k() + packed_n() * (sizeof(int32_t) + sizeof(float)) / sizeof(int8_t));
1963   std::vector<int32_t> bias(n());
1964   std::vector<int8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
1965   std::vector<int32_t> acc(m() * n());
1966   std::vector<float> scale(n());
1967   std::vector<int8_t> c_ref(m() * n());
1968   std::vector<int8_t> junk(k() + 8);
1969   std::vector<const int8_t*> im2col(mr() * ks());
1970 
1971   std::fill(junk.begin(), junk.end(), 0xA5);
1972 
1973   for (size_t iteration = 0; iteration < iterations(); iteration++) {
1974     do {
1975       std::generate(a.begin(), a.end(), std::ref(i8rng));
1976     } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
1977     do {
1978       std::generate(b.begin(), b.end(), std::ref(w8rng));
1979     } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
1980     std::generate(bias.begin(), bias.end(), std::ref(i32rng));
1981     std::fill(c.begin(), c.end(), 0xA5);
1982 
1983     std::fill(packed_w.begin(), packed_w.end(), 0);
1984     const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) };
1985     xnn_pack_qs8_conv_goki_w(
1986       1, n(), ks(), k(), nr(), kr(), sr(),
1987       b.data(), bias.data(), packed_w.data(), nr() * sizeof(float), &packing_params);
1988 
1989     for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1990       for (size_t m_index = 0; m_index < mr(); m_index++) {
1991         im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
1992       }
1993     }
1994     std::shuffle(im2col.begin(), im2col.end(), rng);
1995     if (zero_index() != SIZE_MAX) {
1996       for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1997         im2col[ks_index * mr() + zero_index()] = a.data();
1998       }
1999     }
2000     for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
2001       for (size_t m_index = m(); m_index < mr(); m_index++) {
2002         im2col[ks_index * mr() + m_index] = junk.data();
2003       }
2004     }
2005 
2006     // Compute 32-bit results and output quantization arguments.
2007     std::fill(acc.begin(), acc.end(), 0);
2008     for (size_t m_index = 0; m_index < m(); m_index++) {
2009       for (size_t n_index = 0; n_index < n(); n_index++) {
2010         for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
2011           for (size_t k_index = 0; k_index < k(); k_index++) {
2012             if (im2col[ks_index * mr() + m_index] == a.data()) {
2013               acc[m_index * n() + n_index] +=
2014                 (int32_t(im2col[ks_index * mr() + m_index][k_index]) - int32_t(a_zero_point() - 0x80)) *
2015                 int32_t(b[(n_index * ks() + ks_index) * k() + k_index]);
2016             } else {
2017               acc[m_index * n() + n_index] +=
2018                 (int32_t(im2col[ks_index * mr() + m_index][k_index + a_offset()]) - int32_t(a_zero_point() - 0x80)) *
2019                 int32_t(b[(n_index * ks() + ks_index) * k() + k_index]);
2020             }
2021           }
2022         }
2023         acc[m_index * n() + n_index] += bias[n_index];
2024       }
2025     }
2026 
2027     const int8_t c_zero_point = -1;
2028     for (size_t n_index = 0; n_index < n(); n_index++) {
2029       int32_t accumulated_min = acc[n_index];
2030       int32_t accumulated_max = acc[n_index];
2031       for (size_t m_index = 0; m_index < m(); m_index++) {
2032         accumulated_min = std::min(accumulated_min, acc[m_index * n() + n_index]);
2033         accumulated_max = std::max(accumulated_max, acc[m_index * n() + n_index]);
2034       }
2035       const uint32_t accumulated_range = uint32_t(accumulated_max - accumulated_min);
2036       const float c_scale = accumulated_range >= 256 ? double(accumulated_range) / 255.0 : 1.00001;
2037       scale[n_index] = 1.0f / c_scale;
2038     }
2039 
2040     xnn_init_qc8_scale_fp32_params(
2041       n(), nr(),
2042       nr() * (ks() * packed_k() * sizeof(int8_t) + (sizeof(int32_t) + sizeof(float))), scale.data(),
2043       (void*) ((uintptr_t) packed_w.data() + nr() * (ks() * packed_k() * sizeof(int8_t) + sizeof(int32_t))));
2044 
2045     union xnn_qs8_minmax_params minmax_params;
2046     init_params(&minmax_params,
2047       c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
2048 
2049     const int8_t* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : NULL;
2050 
2051     struct xnn_code_buffer code_buffer;
2052     ASSERT_EQ(xnn_status_success, xnn_allocate_code_memory(&code_buffer, XNN_DEFAULT_CODE_BUFFER_SIZE));
2053     ASSERT_EQ(xnn_status_success, igemm_generator(&code_buffer, n(), k(), ks() * mr() * sizeof(void*), nullptr));
2054     xnn_qc8_igemm_minmax_ukernel_function igemm = reinterpret_cast<xnn_qc8_igemm_minmax_ukernel_function>(code_buffer.code);
2055 
2056     igemm(
2057       m(), n(), k(), ks() * mr() * sizeof(void*),
2058       im2col.data(), packed_w.data(),
2059       c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t),
2060       a_offset() * sizeof(uint8_t), zero_pointer,
2061       &minmax_params);
2062 
2063     ASSERT_EQ(xnn_status_success, xnn_release_code_memory(&code_buffer));
2064 
2065     for (size_t m_index = 0; m_index < m(); m_index++) {
2066       for (size_t n_index = 0; n_index < n(); n_index++) {
2067         c_ref[m_index * n() + n_index] = requantize(
2068           acc[m_index * n() + n_index], scale[n_index], c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
2069       }
2070     }
2071 
2072     for (size_t i = 0; i < m(); i++) {
2073       for (size_t j = 0; j < n(); j++) {
2074         ASSERT_LE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmax()) - 0x80);
2075         ASSERT_GE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmin()) - 0x80);
2076         ASSERT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(c_ref[i * n() + j]))
2077             << "at " << i << ", " << j << ": reference = " << int32_t(c_ref[i * n() + j])
2078             << " (accumulator = " << acc[i * n() + j]
2079             << "), optimized = " << int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x "
2080             << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
2081             << ", requantization scale = " << scale[j] << ", output zero point = " << int32_t(c_zero_point);
2082       }
2083     }
2084   }
2085 }
2086 
2087 void GemmMicrokernelTester::Test(
2088   xnn_jit_gemm_code_generator_function gemm_generator,
2089   xnn_init_qs8_conv_minmax_params_fn init_params,
2090   xnn_qs8_requantize_fn requantize) const
2091 {
2092   ASSERT_LE(m(), mr());
2093 
2094   std::random_device random_device;
2095   auto rng = std::mt19937(random_device());
2096   auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), std::ref(rng));
2097   auto i8rng = std::bind(
2098     std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()),
2099     std::ref(rng));
2100   auto w8rng = std::bind(
2101     std::uniform_int_distribution<int32_t>(-std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max()),
2102     std::ref(rng));

  std::vector<int8_t> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(int8_t));
  std::vector<int8_t> b(n() * k());
  std::vector<int32_t> bias(n());
  std::vector<int8_t, AlignedAllocator<int8_t, 64>> packed_w(packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(int8_t));
  std::vector<int16_t, AlignedAllocator<int16_t, 64>> packed_xw(packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(int16_t));
  std::vector<int8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
  std::vector<int32_t> acc(m() * n());
  std::vector<int8_t> c_ref(m() * n());

  for (size_t iteration = 0; iteration < iterations(); iteration++) {
    do {
      std::generate(a.begin(), a.end(), std::ref(i8rng));
    } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
    do {
      std::generate(b.begin(), b.end(), std::ref(w8rng));
    } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
    std::generate(bias.begin(), bias.end(), std::ref(i32rng));
    std::fill(c.begin(), c.end(), 0xA5);

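    // Pack the weights and biases into the blocked layout expected by the microkernel;
    // the extended_weights() path pre-widens the int8 weights to int16 (packed_xw).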
    std::fill(packed_w.begin(), packed_w.end(), 0);
    const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) };
    if (extended_weights()) {
      xnn_pack_qs8_gemm_xw_goi_w(1, n(), k(), nr(), kr(), sr(),
        b.data(), bias.data(), packed_xw.data(), 0, &packing_params);
    } else {
      xnn_pack_qs8_gemm_goi_w(1, n(), k(), nr(), kr(), sr(),
        b.data(), bias.data(), packed_w.data(), 0, &packing_params);
    }

    // Compute 32-bit results and output quantization arguments.
    std::fill(acc.begin(), acc.end(), 0);
    for (size_t m_index = 0; m_index < m(); m_index++) {
      for (size_t n_index = 0; n_index < n(); n_index++) {
        for (size_t k_index = 0; k_index < k(); k_index++) {
          acc[m_index * n() + n_index] +=
              (int32_t(a[m_index * a_stride() + k_index]) - int32_t(a_zero_point() - 0x80)) *
              int32_t(b[n_index * k() + k_index]);
        }
        acc[m_index * n() + n_index] += bias[n_index];
      }
    }

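    // Derive an output scale and zero point from the observed accumulator range so
    // that the requantized outputs roughly span the int8 range.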
    const int32_t accumulated_min = *std::min_element(acc.cbegin(), acc.cend());
    const int32_t accumulated_max = *std::max_element(acc.cbegin(), acc.cend());
    const double c_scale = uint32_t(accumulated_max - accumulated_min) >= 256 ? double(uint32_t(accumulated_max - accumulated_min)) / 255.0 : 1.00001;
    const int8_t c_zero_point = int8_t(std::max(std::min(
      lrint(-0.5 - 0.5 * double(accumulated_min + accumulated_max) / c_scale),
      long(std::numeric_limits<int8_t>::max())), long(std::numeric_limits<int8_t>::min())));

    const float requantization_scale = 1.0f / float(c_scale);
    union xnn_qs8_conv_minmax_params quantization_params;
    init_params(&quantization_params,
      requantization_scale, c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));

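    // JIT path: allocate executable code memory, generate the QS8 GEMM microkernel
    // into it, and call the kernel through a pointer to the generated entry point.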
    struct xnn_code_buffer code_buffer;
    ASSERT_EQ(xnn_status_success, xnn_allocate_code_memory(&code_buffer, XNN_DEFAULT_CODE_BUFFER_SIZE));
    ASSERT_EQ(xnn_status_success, gemm_generator(&code_buffer, n(), k(), nullptr));
    xnn_qs8_gemm_minmax_ukernel_function gemm = reinterpret_cast<xnn_qs8_gemm_minmax_ukernel_function>(code_buffer.code);

    gemm(
      m(), n(), k(),
      a.data(), a_stride() * sizeof(int8_t),
      extended_weights() ? static_cast<const void*>(packed_xw.data()) : static_cast<const void*>(packed_w.data()),
      c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t),
      &quantization_params);

    ASSERT_EQ(xnn_status_success, xnn_release_code_memory(&code_buffer));

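    // Requantize the reference accumulators with the same scale and zero point that
    // were passed to the kernel to obtain the expected int8 outputs.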
    for (size_t m_index = 0; m_index < m(); m_index++) {
      for (size_t n_index = 0; n_index < n(); n_index++) {
        c_ref[m_index * n() + n_index] = requantize(
          acc[m_index * n() + n_index], requantization_scale, c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
      }
    }

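    // Check range and bit-exact agreement between the JIT-generated GEMM and the
    // requantized reference results.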
    for (size_t i = 0; i < m(); i++) {
      for (size_t j = 0; j < n(); j++) {
        ASSERT_LE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmax()) - 0x80);
        ASSERT_GE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmin()) - 0x80);
        ASSERT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(c_ref[i * n() + j]))
            << "at " << i << ", " << j << ": reference = " << int32_t(c_ref[i * n() + j])
            << " (accumulator = " << acc[i * n() + j]
            << "), optimized = " << int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x "
            << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
            << ", requantization scale = " << requantization_scale << ", output zero point = " << int32_t(c_zero_point);
      }
    }
  }
}

void GemmMicrokernelTester::Test(
  xnn_jit_igemm_code_generator_function igemm_generator,
  xnn_init_qs8_conv_minmax_params_fn init_params,
  xnn_qs8_requantize_fn requantize) const
{
  ASSERT_LE(m(), mr());

  std::random_device random_device;
  auto rng = std::mt19937(random_device());
  auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), std::ref(rng));
  auto i8rng = std::bind(
    std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()),
    std::ref(rng));
  auto w8rng = std::bind(
    std::uniform_int_distribution<int32_t>(-std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max()),
    std::ref(rng));

  std::vector<int8_t> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint8_t));
  std::vector<int8_t> b(n() * ks() * k());
  std::vector<int8_t, AlignedAllocator<int8_t, 64>> packed_w(ks() * packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(int8_t));
  std::vector<int32_t> bias(n());
  std::vector<int8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
  std::vector<int32_t> acc(m() * n());
  std::vector<int8_t> c_ref(m() * n());
  std::vector<int8_t> junk(k() + 8);
  std::vector<const int8_t*> im2col(mr() * ks());

  std::fill(junk.begin(), junk.end(), 0xA5);

  for (size_t iteration = 0; iteration < iterations(); iteration++) {
    do {
      std::generate(a.begin(), a.end(), std::ref(i8rng));
    } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
    do {
      std::generate(b.begin(), b.end(), std::ref(w8rng));
    } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
    std::generate(bias.begin(), bias.end(), std::ref(i32rng));
    std::fill(c.begin(), c.end(), 0xA5);

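    // Pack the convolution weights and biases in GOKI order (groups, output channels,
    // kernel spatial positions, input channels) for the indirect GEMM.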
    std::fill(packed_w.begin(), packed_w.end(), 0);
    const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) };
    xnn_pack_qs8_conv_goki_w(
      1, n(), ks(), k(), nr(), kr(), sr(),
      b.data(), bias.data(), packed_w.data(), 0 /* extra bytes */, &packing_params);

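    // Build the indirection buffer: each of the ks() * mr() entries points at one
    // input row. When a zero index is in use, its entries are redirected to a.data()
    // (matched by zero_pointer below), and rows beyond m() point at junk data that
    // the kernel may read but whose results are never checked.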
    for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
      for (size_t m_index = 0; m_index < mr(); m_index++) {
        im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
      }
    }
    std::shuffle(im2col.begin(), im2col.end(), rng);
    if (zero_index() != SIZE_MAX) {
      for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
        im2col[ks_index * mr() + zero_index()] = a.data();
      }
    }
    for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
      for (size_t m_index = m(); m_index < mr(); m_index++) {
        im2col[ks_index * mr() + m_index] = junk.data();
      }
    }

    // Compute 32-bit results and output quantization arguments.
    std::fill(acc.begin(), acc.end(), 0);
    for (size_t m_index = 0; m_index < m(); m_index++) {
      for (size_t n_index = 0; n_index < n(); n_index++) {
        for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
          for (size_t k_index = 0; k_index < k(); k_index++) {
            if (im2col[ks_index * mr() + m_index] == a.data()) {
              acc[m_index * n() + n_index] +=
                (int32_t(im2col[ks_index * mr() + m_index][k_index]) - int32_t(a_zero_point() - 0x80)) *
                int32_t(b[(n_index * ks() + ks_index) * k() + k_index]);
            } else {
              acc[m_index * n() + n_index] +=
                (int32_t(im2col[ks_index * mr() + m_index][k_index + a_offset()]) - int32_t(a_zero_point() - 0x80)) *
                int32_t(b[(n_index * ks() + ks_index) * k() + k_index]);
            }
          }
        }
        acc[m_index * n() + n_index] += bias[n_index];
      }
    }

    const int32_t accumulated_min = *std::min_element(acc.cbegin(), acc.cend());
    const int32_t accumulated_max = *std::max_element(acc.cbegin(), acc.cend());
    const double c_scale = uint32_t(accumulated_max - accumulated_min) >= 256 ? double(uint32_t(accumulated_max - accumulated_min)) / 255.0 : 1.00001;
    const int8_t c_zero_point = int8_t(std::max(std::min(
      lrint(-0.5 - 0.5 * double(accumulated_min + accumulated_max) / c_scale),
      long(std::numeric_limits<int8_t>::max())), long(std::numeric_limits<int8_t>::min())));

    const float requantization_scale = 1.0f / float(c_scale);
    union xnn_qs8_conv_minmax_params quantization_params;
    init_params(&quantization_params,
      requantization_scale, c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));

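    // zero_pointer marks the "zero page": indirection entries equal to a.data() are
    // read without a_offset applied, mirroring the reference loop above.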
    const int8_t* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : NULL;

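    // JIT path: allocate executable code memory, generate the QS8 IGEMM microkernel
    // into it, and call the kernel through a pointer to the generated entry point.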
    struct xnn_code_buffer code_buffer;
    ASSERT_EQ(xnn_status_success, xnn_allocate_code_memory(&code_buffer, XNN_DEFAULT_CODE_BUFFER_SIZE));
    ASSERT_EQ(xnn_status_success, igemm_generator(&code_buffer, n(), k(), ks() * mr() * sizeof(void*), nullptr));
    xnn_qs8_igemm_minmax_ukernel_function igemm = reinterpret_cast<xnn_qs8_igemm_minmax_ukernel_function>(code_buffer.code);

    igemm(
      m(), n(), k(), ks() * mr() * sizeof(void*),
      im2col.data(), packed_w.data(),
      c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t),
      a_offset() * sizeof(uint8_t), zero_pointer,
      &quantization_params);

    ASSERT_EQ(xnn_status_success, xnn_release_code_memory(&code_buffer));

    for (size_t m_index = 0; m_index < m(); m_index++) {
      for (size_t n_index = 0; n_index < n(); n_index++) {
        c_ref[m_index * n() + n_index] = requantize(
          acc[m_index * n() + n_index], requantization_scale, c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
      }
    }

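    // Check range and bit-exact agreement between the JIT-generated IGEMM and the
    // requantized reference results.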
    for (size_t i = 0; i < m(); i++) {
      for (size_t j = 0; j < n(); j++) {
        ASSERT_LE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmax()) - 0x80);
        ASSERT_GE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmin()) - 0x80);
        ASSERT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(c_ref[i * n() + j]))
            << "at " << i << ", " << j << ": reference = " << int32_t(c_ref[i * n() + j])
            << " (accumulator = " << acc[i * n() + j]
            << "), optimized = " << int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x "
            << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
            << ", requantization scale = " << requantization_scale << ", output zero point = " << int32_t(c_zero_point);
      }
    }
  }
}

#endif  // XNN_PLATFORM_JIT