• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8 
#include <algorithm>
#include <cfloat>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <functional>
#include <limits>
#include <mutex>
#include <random>
#include <vector>
17 
18 #include <cpuinfo.h>
19 
20 #include <benchmark/benchmark.h>
21 #ifdef BENCHMARK_RUY
22 #include "tensorflow/lite/experimental/ruy/ruy.h"
23 #endif  // BENCHMARK_RUY
24 #include "bench/gemm.h"
25 #include "bench/utils.h"
26 #include <xnnpack/AlignedAllocator.h>
27 #include <xnnpack/common.h>
28 #include <xnnpack/gemm.h>
29 #include <xnnpack/pack.h>
30 #include <xnnpack/packx.h>
31 #include <xnnpack/params-init.h>
32 #include <xnnpack/params.h>
33 #include <xnnpack/ppmm.h>
34 
35 
GEMMBenchmark(benchmark::State & state,xnn_f32_gemm_ukernel_function gemm,size_t mr,size_t nr,size_t kr,size_t sr,benchmark::utils::IsaCheckFunction isa_check=nullptr)36 static void GEMMBenchmark(benchmark::State& state,
37   xnn_f32_gemm_ukernel_function gemm,
38   size_t mr, size_t nr, size_t kr, size_t sr,
39   benchmark::utils::IsaCheckFunction isa_check = nullptr)
40 {
41   if (!cpuinfo_initialize()) {
42     state.SkipWithError("cpuinfo initialization failed");
43     return;
44   }
45   if (isa_check && !isa_check(state)) {
46     return;
47   }
48 
49   const size_t mc = state.range(0);
50   const size_t nc = state.range(1);
51   const size_t kc = state.range(2);
52 
53   const size_t nc_stride = benchmark::utils::RoundUp(nc, nr);
54   const size_t kc_stride = benchmark::utils::RoundUp(kc, kr);
55 
56   std::random_device random_device;
57   auto rng = std::mt19937(random_device());
58   auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);
59 
60   std::vector<float> a(mc * kc);
61   std::generate(a.begin(), a.end(), std::ref(f32rng));
62   std::vector<float> k(nc * kc);
63   std::generate(k.begin(), k.end(), std::ref(f32rng));
64   std::vector<float> b(nc);
65   std::generate(b.begin(), b.end(), std::ref(f32rng));
66 
67   const size_t w_elements = nc_stride * kc_stride + nc_stride;
68   const size_t c_elements = mc * nc;
69   const size_t num_buffers = 1 +
70     benchmark::utils::DivideRoundUp<size_t>(benchmark::utils::GetMaxCacheSize(),
71       sizeof(float) * (w_elements + c_elements));
72 
73   std::vector<float, AlignedAllocator<float, 32>> w(w_elements * num_buffers);
74   std::fill(w.begin(), w.end(), 0.0f);
75   xnn_pack_f32_gemm_goi_w(1 /* groups */, nc, kc, nr, kr, sr, k.data(), b.data(), w.data());
76   std::vector<float> c(c_elements * num_buffers);
77   std::fill(c.begin(), c.end(), std::nanf(""));
78 
79   xnn_f32_output_params output_params =
80     xnn_init_f32_output_params(-std::numeric_limits<float>::infinity(), +std::numeric_limits<float>::infinity());
81 
82   size_t buffer_index = 0;
83   for (auto _ : state) {
84     // Use circular buffers (exceeding cache size) and prefetch to control cache state:
85     // - A is always in L1 cache (if fits, otherwise L2, L3, etc)
86     // - W is not in cache (for any cache level)
87     // - C is not in cache (for any cache level)
88     state.PauseTiming();
89     benchmark::utils::PrefetchToL1(a.data(), a.size() * sizeof(float));
90     buffer_index = (buffer_index + 1) % num_buffers;
91     state.ResumeTiming();
92 
93     for (uint32_t m = 0; m < mc; m += mr) {
94       const uint32_t mb = min(mc - m, mr);
95       gemm(
96         mb, nc, kc * sizeof(float),
97         a.data() + m * kc, kc * sizeof(float),
98         w.data() + buffer_index * nc_stride * (kc_stride + 1),
99         c.data() + (buffer_index * mc + m) * nc, nc * sizeof(float), nr * sizeof(float),
100         &output_params);
101     }
102   }
103 
104   state.counters["Freq"] = benchmark::utils::GetCurrentCpuFrequency();
105   state.counters["FLOPS"] = benchmark::Counter(
106     uint64_t(state.iterations()) * 2 * mc * nc * kc, benchmark::Counter::kIsRate);
107 }
108 
PPMM1PBenchmark(benchmark::State & state,xnn_f32_ppmm_ukernel_function ppmm,xnn_x32_packx_ukernel_function packx,size_t mr,size_t nr,benchmark::utils::IsaCheckFunction isa_check=nullptr)109 static void PPMM1PBenchmark(benchmark::State& state,
110   xnn_f32_ppmm_ukernel_function ppmm,
111   xnn_x32_packx_ukernel_function packx,
112   size_t mr, size_t nr,
113   benchmark::utils::IsaCheckFunction isa_check = nullptr)
114 {
115   if (!cpuinfo_initialize()) {
116     state.SkipWithError("cpuinfo initialization failed");
117     return;
118   }
119   if (isa_check && !isa_check(state)) {
120     return;
121   }
122 
123   const size_t mc = state.range(0);
124   const size_t nc = state.range(1);
125   const size_t kc = state.range(2);
126 
127   const size_t nc_stride = benchmark::utils::RoundUp(nc, nr);
128 
129   std::random_device random_device;
130   auto rng = std::mt19937(random_device());
131   auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);
132 
133   std::vector<float> a(mc * kc);
134   std::generate(a.begin(), a.end(), std::ref(f32rng));
135   std::vector<float> k(nc * kc);
136   std::generate(k.begin(), k.end(), std::ref(f32rng));
137   std::vector<float> b(nc);
138   std::generate(b.begin(), b.end(), std::ref(f32rng));
139 
140   std::vector<uint32_t, AlignedAllocator<uint32_t, 32>> t(mr * kc);
141 
142   const size_t w_elements = nc_stride * kc + nc_stride;
143   const size_t c_elements = mc * nc;
144   const size_t num_buffers = 1 +
145     benchmark::utils::DivideRoundUp<size_t>(benchmark::utils::GetMaxCacheSize(),
146       sizeof(float) * (w_elements + c_elements));
147 
148   std::vector<float, AlignedAllocator<float, 32>> w(w_elements * num_buffers);
149   std::fill(w.begin(), w.end(), 0.0f);
150   xnn_pack_f32_gemm_goi_w(1 /* groups */, nc, kc, nr, 1 /* kr */, 1 /* sr */, k.data(), b.data(), w.data());
151   std::vector<float> c(c_elements * num_buffers);
152   std::fill(c.begin(), c.end(), std::nanf(""));
153 
154   xnn_f32_output_params output_params =
155     xnn_init_f32_output_params(-std::numeric_limits<float>::infinity(), +std::numeric_limits<float>::infinity());
156 
157   size_t buffer_index = 0;
158   for (auto _ : state) {
159     // Use circular buffers (exceeding cache size) and prefetch to control cache state:
160     // - A is always in L1 cache (if fits, otherwise L2, L3, etc)
161     // - W is not in cache (for any cache level)
162     // - C is not in cache (for any cache level)
163     state.PauseTiming();
164     benchmark::utils::PrefetchToL1(a.data(), a.size() * sizeof(float));
165     buffer_index = (buffer_index + 1) % num_buffers;
166     state.ResumeTiming();
167 
168     for (uint32_t m = 0; m < mc; m += mr) {
169       const uint32_t mb = min(mc - m, mr);
170       packx(mb, kc, reinterpret_cast<const uint32_t*>(a.data() + m * kc), kc, t.data());
171       ppmm(
172         mb, nc, kc * sizeof(float),
173         reinterpret_cast<const float*>(t.data()),
174         w.data() + nc_stride * buffer_index * (kc + 1),
175         c.data() + (mc * buffer_index + m) * nc, nc * sizeof(float), nr * sizeof(float),
176         &output_params);
177     }
178   }
179 
180   state.counters["Freq"] = benchmark::utils::GetCurrentCpuFrequency();
181   state.counters["FLOPS"] = benchmark::Counter(
182     uint64_t(state.iterations()) * 2 * mc * nc * kc, benchmark::Counter::kIsRate);
183 }
184 
PPMM2PBenchmark(benchmark::State & state,xnn_f32_ppmm_ukernel_function ppmm,xnn_x32_packx_ukernel_function packx,size_t mr,size_t nr,benchmark::utils::IsaCheckFunction isa_check=nullptr)185 static void PPMM2PBenchmark(benchmark::State& state,
186   xnn_f32_ppmm_ukernel_function ppmm,
187   xnn_x32_packx_ukernel_function packx,
188   size_t mr, size_t nr,
189   benchmark::utils::IsaCheckFunction isa_check = nullptr)
190 {
191   if (!cpuinfo_initialize()) {
192     state.SkipWithError("cpuinfo initialization failed");
193     return;
194   }
195   if (isa_check && !isa_check(state)) {
196     return;
197   }
198 
199   const size_t mc = state.range(0);
200   const size_t nc = state.range(1);
201   const size_t kc = state.range(2);
202 
203   const size_t mc_stride = benchmark::utils::RoundUp(mc, mr);
204   const size_t nc_stride = benchmark::utils::RoundUp(nc, nr);
205 
206   std::random_device random_device;
207   auto rng = std::mt19937(random_device());
208   auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);
209 
210   std::vector<float> a(mc * kc);
211   std::generate(a.begin(), a.end(), std::ref(f32rng));
212   std::vector<float> k(nc * kc);
213   std::generate(k.begin(), k.end(), std::ref(f32rng));
214   std::vector<float> b(nc);
215   std::generate(b.begin(), b.end(), std::ref(f32rng));
216 
217   std::vector<uint32_t, AlignedAllocator<uint32_t, 32>> t(mc_stride * kc);
218 
219   const size_t w_elements = nc_stride * kc + nc_stride;
220   const size_t c_elements = mc * nc;
221   const size_t num_buffers = 1 +
222     benchmark::utils::DivideRoundUp<size_t>(benchmark::utils::GetMaxCacheSize(),
223       sizeof(float) * (w_elements + c_elements));
224 
225   std::vector<float, AlignedAllocator<float, 32>> w(w_elements * num_buffers);
226   std::fill(w.begin(), w.end(), 0.0f);
227   xnn_pack_f32_gemm_goi_w(1 /* groups */, nc, kc, nr, 1 /* kr */, 1 /* sr */, k.data(), b.data(), w.data());
228   std::vector<float> c(c_elements * num_buffers);
229   std::fill(c.begin(), c.end(), std::nanf(""));
230 
231   xnn_f32_output_params output_params =
232     xnn_init_f32_output_params(-std::numeric_limits<float>::infinity(), +std::numeric_limits<float>::infinity());
233 
234   size_t buffer_index = 0;
235   for (auto _ : state) {
236     // Use circular buffers (exceeding cache size) and prefetch to control cache state:
237     // - A is always in L1 cache (if fits, otherwise L2, L3, etc)
238     // - W is not in cache (for any cache level)
239     // - C is not in cache (for any cache level)
240     state.PauseTiming();
241     benchmark::utils::PrefetchToL1(a.data(), a.size() * sizeof(float));
242     buffer_index = (buffer_index + 1) % num_buffers;
243     state.ResumeTiming();
244 
245     for (uint32_t m = 0; m < mc; m += mr) {
246       const uint32_t mb = min(mc - m, mr);
247       packx(mb, kc, reinterpret_cast<const uint32_t*>(a.data() + m * kc), kc, t.data() + m * kc);
248     }
249     for (uint32_t m = 0; m < mc; m += mr) {
250       const uint32_t mb = min(mc - m, mr);
251       ppmm(
252         mb, nc, kc * sizeof(float),
253         reinterpret_cast<const float*>(t.data() + m * kc),
254         w.data() + nc_stride * buffer_index * (kc + 1),
255         c.data() + (mc * buffer_index + m) * nc, nc * sizeof(float), nr * sizeof(float),
256         &output_params);
257     }
258   }
259 
260   state.counters["Freq"] = benchmark::utils::GetCurrentCpuFrequency();
261   state.counters["FLOPS"] = benchmark::Counter(
262     uint64_t(state.iterations()) * 2 * mc * nc * kc, benchmark::Counter::kIsRate);
263 }
264 
265 #ifdef BENCHMARK_RUY
RuyBenchmark(benchmark::State & state,uint32_t threads)266 static void RuyBenchmark(benchmark::State& state, uint32_t threads)
267 {
268   std::random_device random_device;
269   auto rng = std::mt19937(random_device());
270   auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);
271 
272   const size_t mc = state.range(0);
273   const size_t nc = state.range(1);
274   const size_t kc = state.range(2);
275 
276   const size_t num_buffers = 1 +
277     benchmark::utils::DivideRoundUp<size_t>(benchmark::utils::GetMaxCacheSize(),
278       sizeof(float) * (nc * (mc + kc + 1)));
279 
280   std::vector<float> a(mc * kc);
281   std::generate(a.begin(), a.end(), std::ref(f32rng));
282   std::vector<float> k(num_buffers * nc * kc);
283   std::generate(k.begin(), k.end(), std::ref(f32rng));
284   std::vector<float> b(num_buffers * nc);
285   std::generate(b.begin(), b.end(), std::ref(f32rng));
286   std::vector<float> c(num_buffers * nc * mc);
287   std::fill(c.begin(), c.end(), std::nanf(""));
288 
289   // Note: context must be static to avoid the cost of re-creating it for each benchmark.
290   static ruy::Context context;
291   context.max_num_threads = threads;
292 
293   ruy::Matrix<float> ruy_a;
294   ruy::MakeSimpleLayout(nc, kc, ruy::Order::kRowMajor, &ruy_a.layout);
295   ruy::Matrix<float> ruy_b;
296   ruy::MakeSimpleLayout(kc, mc, ruy::Order::kColMajor, &ruy_b.layout);
297   ruy_b.data = a.data();
298   ruy::Matrix<float> ruy_c;
299   ruy::MakeSimpleLayout(nc, mc, ruy::Order::kColMajor, &ruy_c.layout);
300 
301   ruy::BasicSpec<float, float> spec;
302 
303   // ruy::Context uses deferred initialization, which affects percieved GEMM performance. Initialization happens during
304   // the first GEMM calls, and per Benoit Jacob it takes up to ~250 milliseconds for performance to stabilize.
305   // Thus, on the first benchmark, we compute GEMM for 500 milliseconds (to be safe) without recording performance, and
306   // keep the ruy::Context object initialized (by being static) between subsequent benchmarks.
307   static std::once_flag warmup;
308   std::call_once(warmup, [&](){
309     auto start = std::chrono::steady_clock::now();
310     do {
311       ruy_a.data = k.data();
312       ruy_c.data = c.data();
313       spec.bias = b.data();
314 
315       ruy::Mul<ruy::kAllPaths>(ruy_a, ruy_b, spec, &context, &ruy_c);
316     } while (std::chrono::duration<double>(std::chrono::steady_clock::now() - start).count() < 0.5);
317   });
318 
319   size_t buffer_index = 0;
320   for (auto _ : state) {
321     // Use circular buffers (exceeding cache size) and prefetch to control cache state:
322     // - A is always in L1 cache (if fits, otherwise L2, L3, etc)
323     // - K is not in cache (for any cache level)
324     // - B is not in cache (for any cache level)
325     // - C is not in cache (for any cache level)
326     state.PauseTiming();
327     benchmark::utils::PrefetchToL1(a.data(), a.size() * sizeof(float));
328     buffer_index = (buffer_index + 1) % num_buffers;
329     state.ResumeTiming();
330 
331     ruy_a.data = k.data() + buffer_index * nc * kc;
332     ruy_c.data = c.data() + buffer_index * mc * nc;
333     spec.bias = b.data() + buffer_index * nc;
334 
335     ruy::Mul<ruy::kAllPaths>(ruy_a, ruy_b, spec, &context, &ruy_c);
336   }
337 
338   state.counters["Freq"] = benchmark::utils::GetCurrentCpuFrequency();
339   state.counters["FLOPS"] = benchmark::Counter(
340     uint64_t(state.iterations()) * 2 * mc * nc * kc, benchmark::Counter::kIsRate);
341 }
342 
ruy_st(benchmark::State & state,const char * net)343 static void ruy_st(benchmark::State& state, const char* net)
344 {
345   RuyBenchmark(state, 1);
346 }
347 #endif  // BENCHMARK_RUY
348 
349 
#if XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY
  // AArch64 assembly micro-kernels: each wrapper binds one micro-kernel and
  // its mr/nr/kr/sr tile parameters to the shared GEMMBenchmark driver.
  static void f32_gemm_1x12__aarch64_neonfma_cortex_a53(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x12__aarch64_neonfma_cortex_a53, 1, 12, 1, 1);
  }
  static void f32_gemm_1x8__aarch64_neonfma_cortex_a53(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a53, 1, 8, 1, 1);
  }
  static void f32_gemm_1x8__aarch64_neonfma_cortex_a57(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a57, 1, 8, 1, 1);
  }
  static void f32_gemm_1x8__aarch64_neonfma_cortex_a75(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a75, 1, 8, 1, 1);
  }
  static void f32_gemm_4x12__aarch64_neonfma_cortex_a53(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x12__aarch64_neonfma_cortex_a53, 4, 12, 1, 1);
  }
  static void f32_gemm_4x8__aarch64_neonfma_cortex_a53(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a53, 4, 8, 1, 1);
  }
  static void f32_gemm_4x8__aarch64_neonfma_cortex_a57(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a57, 4, 8, 1, 1);
  }
  static void f32_gemm_4x8__aarch64_neonfma_cortex_a75(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a75, 4, 8, 1, 1);
  }
  static void f32_gemm_4x8__aarch64_neonfma_ld64(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_ld64, 4, 8, 1, 1);
  }
  static void f32_gemm_4x8__aarch64_neonfma_ld128(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_ld128, 4, 8, 1, 1);
  }
  static void f32_gemm_5x8__aarch64_neonfma_cortex_a57(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_5x8__aarch64_neonfma_cortex_a57, 5, 8, 1, 1);
  }
  static void f32_gemm_5x8__aarch64_neonfma_cortex_a75(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_5x8__aarch64_neonfma_cortex_a75, 5, 8, 1, 1);
  }
  static void f32_gemm_6x8__aarch64_neonfma_ld64(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld64, 6, 8, 1, 1);
  }
  static void f32_gemm_6x8__aarch64_neonfma_ld128(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld128, 6, 8, 1, 1);
  }
  static void f32_gemm_6x8__aarch64_neonfma_cortex_a53(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a53, 6, 8, 1, 1);
  }
  static void f32_gemm_6x8__aarch64_neonfma_cortex_a73(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a73, 6, 8, 1, 1);
  }
  static void f32_gemm_6x8__aarch64_neonfma_cortex_a57(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a57, 6, 8, 1, 1);
  }
  static void f32_gemm_6x8__aarch64_neonfma_cortex_a75(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a75, 6, 8, 1, 1);
  }
  static void f32_gemm_1x8__neonfma_lane_ld64(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x8__neonfma_lane_ld64, 1, 8, 1, 1);
  }
  static void f32_gemm_4x8__neonfma_lane_ld64(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__neonfma_lane_ld64, 4, 8, 1, 1);
  }
  static void f32_gemm_4x8__neonfma_lane_ld128(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__neonfma_lane_ld128, 4, 8, 1, 1);
  }
  static void f32_gemm_5x8__neonfma_lane_ld64(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_5x8__neonfma_lane_ld64, 5, 8, 1, 1);
  }
  static void f32_gemm_6x8__neonfma_lane_ld64(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8__neonfma_lane_ld64, 6, 8, 1, 1);
  }
  static void f32_gemm_6x8__neonfma_lane_ld128(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8__neonfma_lane_ld128, 6, 8, 1, 1);
  }

  BENCHMARK_GEMM(f32_gemm_1x12__aarch64_neonfma_cortex_a53)
  BENCHMARK_GEMM(f32_gemm_1x8__aarch64_neonfma_cortex_a53)
  BENCHMARK_GEMM(f32_gemm_1x8__aarch64_neonfma_cortex_a57)
  BENCHMARK_GEMM(f32_gemm_1x8__aarch64_neonfma_cortex_a75)
  BENCHMARK_GEMM(f32_gemm_4x12__aarch64_neonfma_cortex_a53)
  BENCHMARK_GEMM(f32_gemm_4x8__aarch64_neonfma_cortex_a53)
  BENCHMARK_GEMM(f32_gemm_4x8__aarch64_neonfma_cortex_a57)
  BENCHMARK_GEMM(f32_gemm_4x8__aarch64_neonfma_cortex_a75)
  BENCHMARK_GEMM(f32_gemm_4x8__aarch64_neonfma_ld128)
  BENCHMARK_GEMM(f32_gemm_4x8__aarch64_neonfma_ld64)
  BENCHMARK_GEMM(f32_gemm_5x8__aarch64_neonfma_cortex_a57)
  BENCHMARK_GEMM(f32_gemm_5x8__aarch64_neonfma_cortex_a75)
  BENCHMARK_GEMM(f32_gemm_6x8__aarch64_neonfma_cortex_a53)
  BENCHMARK_GEMM(f32_gemm_6x8__aarch64_neonfma_cortex_a73)
  BENCHMARK_GEMM(f32_gemm_6x8__aarch64_neonfma_cortex_a57)
  BENCHMARK_GEMM(f32_gemm_6x8__aarch64_neonfma_cortex_a75)
  BENCHMARK_GEMM(f32_gemm_6x8__aarch64_neonfma_ld64)
  BENCHMARK_GEMM(f32_gemm_6x8__aarch64_neonfma_ld128)
  BENCHMARK_GEMM(f32_gemm_1x8__neonfma_lane_ld64)
  BENCHMARK_GEMM(f32_gemm_4x8__neonfma_lane_ld64)
  BENCHMARK_GEMM(f32_gemm_4x8__neonfma_lane_ld128)
  BENCHMARK_GEMM(f32_gemm_5x8__neonfma_lane_ld64)
  BENCHMARK_GEMM(f32_gemm_6x8__neonfma_lane_ld64)
  BENCHMARK_GEMM(f32_gemm_6x8__neonfma_lane_ld128)
#endif  // XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY
448 
#if XNN_ARCH_ARM && XNN_ENABLE_ASSEMBLY
  // AArch32 assembly micro-kernels; all gated on a runtime NEON check.
  static void f32_gemm_4x8__aarch32_neon_ld64(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__aarch32_neon_ld64, 4, 8, 1, 1, benchmark::utils::CheckNEON);
  }
  static void f32_gemm_4x8__aarch32_neon_cortex_a53(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__aarch32_neon_cortex_a53, 4, 8, 1, 1, benchmark::utils::CheckNEON);
  }
  static void f32_gemm_4x8__aarch32_neon_cortex_a75(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__aarch32_neon_cortex_a75, 4, 8, 1, 1, benchmark::utils::CheckNEON);
  }
  static void f32_gemm_4x8__aarch32_neon_pld_cortex_a75(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__aarch32_neon_pld_cortex_a75, 4, 8, 1, 1, benchmark::utils::CheckNEON);
  }

  BENCHMARK_GEMM(f32_gemm_4x8__aarch32_neon_ld64)
  BENCHMARK_GEMM(f32_gemm_4x8__aarch32_neon_cortex_a53)
  BENCHMARK_GEMM(f32_gemm_4x8__aarch32_neon_cortex_a75)
  BENCHMARK_GEMM(f32_gemm_4x8__aarch32_neon_pld_cortex_a75)
#endif  // XNN_ARCH_ARM && XNN_ENABLE_ASSEMBLY
468 
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
  // NEON/NEONFMA intrinsics micro-kernels (both ARM and ARM64), gated on
  // runtime ISA checks. s4 variants use a shuffle factor sr = 4.
  static void f32_gemm_1x8__neon_lane_ld64(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x8__neon_lane_ld64, 1, 8, 1, 1, benchmark::utils::CheckNEON);
  }
  static void f32_gemm_4x8__neon_lane_ld64(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__neon_lane_ld64, 4, 8, 1, 1, benchmark::utils::CheckNEON);
  }
  static void f32_gemm_4x8__neon_lane_ld128(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__neon_lane_ld128, 4, 8, 1, 1, benchmark::utils::CheckNEON);
  }
  static void f32_gemm_5x8__neon_lane_ld64(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_5x8__neon_lane_ld64, 5, 8, 1, 1, benchmark::utils::CheckNEON);
  }
  static void f32_gemm_6x8__neon_lane_ld64(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8__neon_lane_ld64, 6, 8, 1, 1, benchmark::utils::CheckNEON);
  }
  static void f32_gemm_6x8__neon_lane_ld128(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8__neon_lane_ld128, 6, 8, 1, 1, benchmark::utils::CheckNEON);
  }
  static void f32_gemm_1x8__neonfma_dup_ld64(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x8__neonfma_dup_ld64, 1, 8, 1, 1, benchmark::utils::CheckNEONFMA);
  }
  static void f32_gemm_4x8__neonfma_dup_ld64(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__neonfma_dup_ld64, 4, 8, 1, 1, benchmark::utils::CheckNEONFMA);
  }
  static void f32_gemm_4x8__neonfma_dup_ld128(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__neonfma_dup_ld128, 4, 8, 1, 1, benchmark::utils::CheckNEONFMA);
  }
  static void f32_gemm_6x8__neonfma_dup_ld64(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8__neonfma_dup_ld64, 6, 8, 1, 1, benchmark::utils::CheckNEONFMA);
  }
  static void f32_gemm_6x8__neonfma_dup_ld128(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8__neonfma_dup_ld128, 6, 8, 1, 1, benchmark::utils::CheckNEONFMA);
  }
  static void f32_gemm_1x8s4__neon(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x8s4__neon, 1, 8, 1, 4, benchmark::utils::CheckNEON);
  }
  static void f32_gemm_1x8s4__neonfma(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x8s4__neonfma, 1, 8, 1, 4, benchmark::utils::CheckNEONFMA);
  }
  static void f32_gemm_4x8s4__neon(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8s4__neon, 4, 8, 1, 4, benchmark::utils::CheckNEON);
  }
  static void f32_gemm_4x8s4__neonfma(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8s4__neonfma, 4, 8, 1, 4, benchmark::utils::CheckNEONFMA);
  }
  static void f32_gemm_6x8s4__neon(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8s4__neon, 6, 8, 1, 4, benchmark::utils::CheckNEON);
  }
  static void f32_gemm_6x8s4__neonfma(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8s4__neonfma, 6, 8, 1, 4, benchmark::utils::CheckNEONFMA);
  }
  static void f32_gemm_8x8s4__neon(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_8x8s4__neon, 8, 8, 1, 4, benchmark::utils::CheckNEON);
  }
  static void f32_gemm_8x8s4__neonfma(benchmark::State& state, const char* net) {
    GEMMBenchmark(state, xnn_f32_gemm_ukernel_8x8s4__neonfma, 8, 8, 1, 4, benchmark::utils::CheckNEONFMA);
  }
  static void f32_ppmm_4x8_unipass__neonfma(benchmark::State& state, const char* net) {
    PPMM1PBenchmark(state, xnn_f32_ppmm_ukernel_4x8__neonfma, xnn_x32_packx_ukernel_4x__neon_st4, 4, 8, benchmark::utils::CheckNEONFMA);
  }
  static void f32_ppmm_4x8_twopass__neonfma(benchmark::State& state, const char* net) {
    PPMM2PBenchmark(state, xnn_f32_ppmm_ukernel_4x8__neonfma, xnn_x32_packx_ukernel_4x__neon_st4, 4, 8, benchmark::utils::CheckNEONFMA);
  }

  BENCHMARK_GEMM(f32_gemm_1x8__neon_lane_ld64)
  BENCHMARK_GEMM(f32_gemm_4x8__neon_lane_ld64)
  BENCHMARK_GEMM(f32_gemm_4x8__neon_lane_ld128)
  BENCHMARK_GEMM(f32_gemm_5x8__neon_lane_ld64)
  BENCHMARK_GEMM(f32_gemm_6x8__neon_lane_ld64)
  BENCHMARK_GEMM(f32_gemm_6x8__neon_lane_ld128)
  BENCHMARK_GEMM(f32_gemm_1x8__neonfma_dup_ld64)
  BENCHMARK_GEMM(f32_gemm_4x8__neonfma_dup_ld128)
  BENCHMARK_GEMM(f32_gemm_4x8__neonfma_dup_ld64)
  BENCHMARK_GEMM(f32_gemm_6x8__neonfma_dup_ld64)
  BENCHMARK_GEMM(f32_gemm_6x8__neonfma_dup_ld128)
  BENCHMARK_GEMM(f32_gemm_1x8s4__neon)
  BENCHMARK_GEMM(f32_gemm_1x8s4__neonfma)
  BENCHMARK_GEMM(f32_gemm_4x8s4__neon)
  BENCHMARK_GEMM(f32_gemm_4x8s4__neonfma)
  BENCHMARK_GEMM(f32_gemm_6x8s4__neon)
  BENCHMARK_GEMM(f32_gemm_6x8s4__neonfma)
  BENCHMARK_GEMM(f32_gemm_8x8s4__neon)
  BENCHMARK_GEMM(f32_gemm_8x8s4__neonfma)
  BENCHMARK_GEMM(f32_ppmm_4x8_unipass__neonfma)
  BENCHMARK_GEMM(f32_ppmm_4x8_twopass__neonfma)
#endif  // XNN_ARCH_ARM || XNN_ARCH_ARM64
556 
557 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
558   static void f32_gemm_1x8__sse_load1(benchmark::State& state, const char* net) {
559     GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x8__sse_load1, 1, 8, 1, 1);
560   }
f32_gemm_4x8__sse_load1(benchmark::State & state,const char * net)561   static void f32_gemm_4x8__sse_load1(benchmark::State& state, const char* net) {
562     GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__sse_load1, 4, 8, 1, 1);
563   }
564 
f32_gemm_1x8__sse_dup(benchmark::State & state,const char * net)565   static void f32_gemm_1x8__sse_dup(benchmark::State& state, const char* net) {
566     GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x8__sse_dup, 1, 8, 1, 1);
567   }
f32_gemm_4x8__sse_dup(benchmark::State & state,const char * net)568   static void f32_gemm_4x8__sse_dup(benchmark::State& state, const char* net) {
569     GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__sse_dup, 4, 8, 1, 1);
570   }
571 
f32_gemm_1x8s4__sse(benchmark::State & state,const char * net)572   static void f32_gemm_1x8s4__sse(benchmark::State& state, const char* net) {
573     GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x8s4__sse, 1, 8, 1, 4);
574   }
f32_gemm_4x8s4__sse(benchmark::State & state,const char * net)575   static void f32_gemm_4x8s4__sse(benchmark::State& state, const char* net) {
576     GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8s4__sse, 4, 8, 1, 4);
577   }
578 
f32_ppmm_4x8_unipass__sse(benchmark::State & state,const char * net)579   static void f32_ppmm_4x8_unipass__sse(benchmark::State& state, const char* net) {
580     PPMM1PBenchmark(state, xnn_f32_ppmm_ukernel_4x8__sse, xnn_x32_packx_ukernel_4x__sse, 4, 8);
581   }
f32_ppmm_4x8_twopass__sse(benchmark::State & state,const char * net)582   static void f32_ppmm_4x8_twopass__sse(benchmark::State& state, const char* net) {
583     PPMM2PBenchmark(state, xnn_f32_ppmm_ukernel_4x8__sse, xnn_x32_packx_ukernel_4x__sse, 4, 8);
584   }
585 
f32_gemm_1x8__avx_broadcast(benchmark::State & state,const char * net)586   static void f32_gemm_1x8__avx_broadcast(benchmark::State& state, const char* net) {
587     GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x8__avx_broadcast, 1, 8, 1, 1, benchmark::utils::CheckAVX);
588   }
f32_gemm_4x8__avx_broadcast(benchmark::State & state,const char * net)589   static void f32_gemm_4x8__avx_broadcast(benchmark::State& state, const char* net) {
590     GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__avx_broadcast, 4, 8, 1, 1, benchmark::utils::CheckAVX);
591   }
f32_gemm_5x8__avx_broadcast(benchmark::State & state,const char * net)592   static void f32_gemm_5x8__avx_broadcast(benchmark::State& state, const char* net) {
593     GEMMBenchmark(state, xnn_f32_gemm_ukernel_5x8__avx_broadcast, 5, 8, 1, 1, benchmark::utils::CheckAVX);
594   }
f32_gemm_6x8__avx_broadcast(benchmark::State & state,const char * net)595   static void f32_gemm_6x8__avx_broadcast(benchmark::State& state, const char* net) {
596     GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8__avx_broadcast, 6, 8, 1, 1, benchmark::utils::CheckAVX);
597   }
f32_gemm_7x8__avx_broadcast(benchmark::State & state,const char * net)598   static void f32_gemm_7x8__avx_broadcast(benchmark::State& state, const char* net) {
599     GEMMBenchmark(state, xnn_f32_gemm_ukernel_7x8__avx_broadcast, 7, 8, 1, 1, benchmark::utils::CheckAVX);
600   }
f32_gemm_1x16__avx_broadcast(benchmark::State & state,const char * net)601   static void f32_gemm_1x16__avx_broadcast(benchmark::State& state, const char* net) {
602     GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x16__avx_broadcast, 1, 16, 1, 1, benchmark::utils::CheckAVX);
603   }
f32_gemm_3x16__avx_broadcast(benchmark::State & state,const char * net)604   static void f32_gemm_3x16__avx_broadcast(benchmark::State& state, const char* net) {
605     GEMMBenchmark(state, xnn_f32_gemm_ukernel_3x16__avx_broadcast, 4, 16, 1, 1, benchmark::utils::CheckAVX);
606   }
f32_gemm_4x16__avx_broadcast(benchmark::State & state,const char * net)607   static void f32_gemm_4x16__avx_broadcast(benchmark::State& state, const char* net) {
608     GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x16__avx_broadcast, 4, 16, 1, 1, benchmark::utils::CheckAVX);
609   }
f32_gemm_5x16__avx_broadcast(benchmark::State & state,const char * net)610   static void f32_gemm_5x16__avx_broadcast(benchmark::State& state, const char* net) {
611     GEMMBenchmark(state, xnn_f32_gemm_ukernel_5x16__avx_broadcast, 5, 16, 1, 1, benchmark::utils::CheckAVX);
612   }
613 
f32_gemm_1x8__fma3_broadcast(benchmark::State & state,const char * net)614   static void f32_gemm_1x8__fma3_broadcast(benchmark::State& state, const char* net) {
615     GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x8__fma3_broadcast, 1, 8, 1, 1, benchmark::utils::CheckFMA3);
616   }
f32_gemm_4x8__fma3_broadcast(benchmark::State & state,const char * net)617   static void f32_gemm_4x8__fma3_broadcast(benchmark::State& state, const char* net) {
618     GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__fma3_broadcast, 4, 8, 1, 1, benchmark::utils::CheckFMA3);
619   }
f32_gemm_5x8__fma3_broadcast(benchmark::State & state,const char * net)620   static void f32_gemm_5x8__fma3_broadcast(benchmark::State& state, const char* net) {
621     GEMMBenchmark(state, xnn_f32_gemm_ukernel_5x8__fma3_broadcast, 5, 8, 1, 1, benchmark::utils::CheckFMA3);
622   }
f32_gemm_6x8__fma3_broadcast(benchmark::State & state,const char * net)623   static void f32_gemm_6x8__fma3_broadcast(benchmark::State& state, const char* net) {
624     GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8__fma3_broadcast, 6, 8, 1, 1, benchmark::utils::CheckFMA3);
625   }
f32_gemm_7x8__fma3_broadcast(benchmark::State & state,const char * net)626   static void f32_gemm_7x8__fma3_broadcast(benchmark::State& state, const char* net) {
627     GEMMBenchmark(state, xnn_f32_gemm_ukernel_7x8__fma3_broadcast, 7, 8, 1, 1, benchmark::utils::CheckFMA3);
628   }
f32_gemm_8x8__fma3_broadcast(benchmark::State & state,const char * net)629   static void f32_gemm_8x8__fma3_broadcast(benchmark::State& state, const char* net) {
630     GEMMBenchmark(state, xnn_f32_gemm_ukernel_8x8__fma3_broadcast, 8, 8, 1, 1, benchmark::utils::CheckFMA3);
631   }
f32_gemm_1x16__fma3_broadcast(benchmark::State & state,const char * net)632   static void f32_gemm_1x16__fma3_broadcast(benchmark::State& state, const char* net) {
633     GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x16__fma3_broadcast, 1, 16, 1, 1, benchmark::utils::CheckFMA3);
634   }
f32_gemm_3x16__fma3_broadcast(benchmark::State & state,const char * net)635   static void f32_gemm_3x16__fma3_broadcast(benchmark::State& state, const char* net) {
636     GEMMBenchmark(state, xnn_f32_gemm_ukernel_3x16__fma3_broadcast, 4, 16, 1, 1, benchmark::utils::CheckFMA3);
637   }
f32_gemm_4x16__fma3_broadcast(benchmark::State & state,const char * net)638   static void f32_gemm_4x16__fma3_broadcast(benchmark::State& state, const char* net) {
639     GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x16__fma3_broadcast, 4, 16, 1, 1, benchmark::utils::CheckFMA3);
640   }
f32_gemm_5x16__fma3_broadcast(benchmark::State & state,const char * net)641   static void f32_gemm_5x16__fma3_broadcast(benchmark::State& state, const char* net) {
642     GEMMBenchmark(state, xnn_f32_gemm_ukernel_5x16__fma3_broadcast, 5, 16, 1, 1, benchmark::utils::CheckFMA3);
643   }
644 
f32_gemm_1x16s4__fma3_broadcast(benchmark::State & state,const char * net)645   static void f32_gemm_1x16s4__fma3_broadcast(benchmark::State& state, const char* net) {
646     GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x16s4__fma3_broadcast, 1, 16, 1, 4, benchmark::utils::CheckFMA3);
647   }
f32_gemm_3x16s4__fma3_broadcast(benchmark::State & state,const char * net)648   static void f32_gemm_3x16s4__fma3_broadcast(benchmark::State& state, const char* net) {
649     GEMMBenchmark(state, xnn_f32_gemm_ukernel_3x16s4__fma3_broadcast, 4, 16, 1, 4, benchmark::utils::CheckFMA3);
650   }
f32_gemm_4x16s4__fma3_broadcast(benchmark::State & state,const char * net)651   static void f32_gemm_4x16s4__fma3_broadcast(benchmark::State& state, const char* net) {
652     GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x16s4__fma3_broadcast, 4, 16, 1, 4, benchmark::utils::CheckFMA3);
653   }
f32_gemm_5x16s4__fma3_broadcast(benchmark::State & state,const char * net)654   static void f32_gemm_5x16s4__fma3_broadcast(benchmark::State& state, const char* net) {
655     GEMMBenchmark(state, xnn_f32_gemm_ukernel_5x16s4__fma3_broadcast, 5, 16, 1, 4, benchmark::utils::CheckFMA3);
656   }
657 
f32_gemm_1x16__avx512f_broadcast(benchmark::State & state,const char * net)658   static void f32_gemm_1x16__avx512f_broadcast(benchmark::State& state, const char* net) {
659     GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x16__avx512f_broadcast, 1, 16, 1, 1, benchmark::utils::CheckAVX512F);
660   }
f32_gemm_4x16__avx512f_broadcast(benchmark::State & state,const char * net)661   static void f32_gemm_4x16__avx512f_broadcast(benchmark::State& state, const char* net) {
662     GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x16__avx512f_broadcast, 4, 16, 1, 1, benchmark::utils::CheckAVX512F);
663   }
f32_gemm_5x16__avx512f_broadcast(benchmark::State & state,const char * net)664   static void f32_gemm_5x16__avx512f_broadcast(benchmark::State& state, const char* net) {
665     GEMMBenchmark(state, xnn_f32_gemm_ukernel_5x16__avx512f_broadcast, 5, 16, 1, 1, benchmark::utils::CheckAVX512F);
666   }
f32_gemm_6x16__avx512f_broadcast(benchmark::State & state,const char * net)667   static void f32_gemm_6x16__avx512f_broadcast(benchmark::State& state, const char* net) {
668     GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x16__avx512f_broadcast, 6, 16, 1, 1, benchmark::utils::CheckAVX512F);
669   }
f32_gemm_7x16__avx512f_broadcast(benchmark::State & state,const char * net)670   static void f32_gemm_7x16__avx512f_broadcast(benchmark::State& state, const char* net) {
671     GEMMBenchmark(state, xnn_f32_gemm_ukernel_7x16__avx512f_broadcast, 7, 16, 1, 1, benchmark::utils::CheckAVX512F);
672   }
f32_gemm_8x16__avx512f_broadcast(benchmark::State & state,const char * net)673   static void f32_gemm_8x16__avx512f_broadcast(benchmark::State& state, const char* net) {
674     GEMMBenchmark(state, xnn_f32_gemm_ukernel_8x16__avx512f_broadcast, 8, 16, 1, 1, benchmark::utils::CheckAVX512F);
675   }
676 
  // Register the x86 GEMM and PPMM benchmark wrappers defined above,
  // grouped by instruction-set variant.
  BENCHMARK_GEMM(f32_gemm_1x8__sse_load1)
  BENCHMARK_GEMM(f32_gemm_4x8__sse_load1)

  BENCHMARK_GEMM(f32_gemm_1x8__sse_dup)
  BENCHMARK_GEMM(f32_gemm_4x8__sse_dup)

  BENCHMARK_GEMM(f32_gemm_1x8s4__sse)
  BENCHMARK_GEMM(f32_gemm_4x8s4__sse)

  BENCHMARK_GEMM(f32_ppmm_4x8_unipass__sse)
  BENCHMARK_GEMM(f32_ppmm_4x8_twopass__sse)

  BENCHMARK_GEMM(f32_gemm_1x8__avx_broadcast)
  BENCHMARK_GEMM(f32_gemm_4x8__avx_broadcast)
  BENCHMARK_GEMM(f32_gemm_5x8__avx_broadcast)
  BENCHMARK_GEMM(f32_gemm_6x8__avx_broadcast)
  BENCHMARK_GEMM(f32_gemm_7x8__avx_broadcast)
  BENCHMARK_GEMM(f32_gemm_1x16__avx_broadcast)
  BENCHMARK_GEMM(f32_gemm_3x16__avx_broadcast)
  BENCHMARK_GEMM(f32_gemm_4x16__avx_broadcast)
  BENCHMARK_GEMM(f32_gemm_5x16__avx_broadcast)

  BENCHMARK_GEMM(f32_gemm_1x8__fma3_broadcast)
  BENCHMARK_GEMM(f32_gemm_4x8__fma3_broadcast)
  BENCHMARK_GEMM(f32_gemm_5x8__fma3_broadcast)
  BENCHMARK_GEMM(f32_gemm_6x8__fma3_broadcast)
  BENCHMARK_GEMM(f32_gemm_7x8__fma3_broadcast)
  BENCHMARK_GEMM(f32_gemm_8x8__fma3_broadcast)
  BENCHMARK_GEMM(f32_gemm_1x16__fma3_broadcast)
  BENCHMARK_GEMM(f32_gemm_3x16__fma3_broadcast)
  BENCHMARK_GEMM(f32_gemm_4x16__fma3_broadcast)
  BENCHMARK_GEMM(f32_gemm_5x16__fma3_broadcast)

  BENCHMARK_GEMM(f32_gemm_1x16s4__fma3_broadcast)
  BENCHMARK_GEMM(f32_gemm_3x16s4__fma3_broadcast)
  BENCHMARK_GEMM(f32_gemm_4x16s4__fma3_broadcast)
  BENCHMARK_GEMM(f32_gemm_5x16s4__fma3_broadcast)

  BENCHMARK_GEMM(f32_gemm_1x16__avx512f_broadcast)
  BENCHMARK_GEMM(f32_gemm_4x16__avx512f_broadcast)
  BENCHMARK_GEMM(f32_gemm_5x16__avx512f_broadcast)
  BENCHMARK_GEMM(f32_gemm_6x16__avx512f_broadcast)
  BENCHMARK_GEMM(f32_gemm_7x16__avx512f_broadcast)
  BENCHMARK_GEMM(f32_gemm_8x16__avx512f_broadcast)
#endif  // XNN_ARCH_X86 || XNN_ARCH_X86_64
722 
723 #if !XNN_ARCH_WASM && !XNN_ARCH_ASMJS
724   static void f32_gemm_4x8__psimd_loadsplat(benchmark::State& state, const char* net) {
725     GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__psimd_loadsplat, 4, 8, 1, 1);
726   }
727 
f32_gemm_6x8__psimd_loadsplat(benchmark::State & state,const char * net)728   static void f32_gemm_6x8__psimd_loadsplat(benchmark::State& state, const char* net) {
729     GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8__psimd_loadsplat, 6, 8, 1, 1);
730   }
731 
f32_gemm_4x8__psimd_splat(benchmark::State & state,const char * net)732   static void f32_gemm_4x8__psimd_splat(benchmark::State& state, const char* net) {
733     GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__psimd_splat, 4, 8, 1, 1);
734   }
735 
f32_gemm_6x8__psimd_splat(benchmark::State & state,const char * net)736   static void f32_gemm_6x8__psimd_splat(benchmark::State& state, const char* net) {
737     GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8__psimd_splat, 6, 8, 1, 1);
738   }
739 
f32_gemm_4x8s4__psimd(benchmark::State & state,const char * net)740   static void f32_gemm_4x8s4__psimd(benchmark::State& state, const char* net) {
741     GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8s4__psimd, 4, 8, 1, 4);
742   }
743 
f32_gemm_6x8s4__psimd(benchmark::State & state,const char * net)744   static void f32_gemm_6x8s4__psimd(benchmark::State& state, const char* net) {
745     GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8s4__psimd, 6, 8, 1, 4);
746   }
747 
f32_ppmm_4x8_unipass__psimd(benchmark::State & state,const char * net)748   static void f32_ppmm_4x8_unipass__psimd(benchmark::State& state, const char* net) {
749     PPMM1PBenchmark(state, xnn_f32_ppmm_ukernel_4x8__psimd, xnn_x32_packx_ukernel_4x__psimd, 4, 8);
750   }
751 
f32_ppmm_4x8_twopass__psimd(benchmark::State & state,const char * net)752   static void f32_ppmm_4x8_twopass__psimd(benchmark::State& state, const char* net) {
753     PPMM2PBenchmark(state, xnn_f32_ppmm_ukernel_4x8__psimd, xnn_x32_packx_ukernel_4x__psimd, 4, 8);
754   }
755 
  // Register the psimd GEMM and PPMM benchmark wrappers defined above.
  BENCHMARK_GEMM(f32_gemm_4x8__psimd_loadsplat)
  BENCHMARK_GEMM(f32_gemm_6x8__psimd_loadsplat)
  BENCHMARK_GEMM(f32_gemm_4x8__psimd_splat)
  BENCHMARK_GEMM(f32_gemm_6x8__psimd_splat)
  BENCHMARK_GEMM(f32_gemm_4x8s4__psimd)
  BENCHMARK_GEMM(f32_gemm_6x8s4__psimd)
  BENCHMARK_GEMM(f32_ppmm_4x8_unipass__psimd)
  BENCHMARK_GEMM(f32_ppmm_4x8_twopass__psimd)
#endif  // !XNN_ARCH_WASM && !XNN_ARCH_ASMJS
765 
766 static void f32_gemm_1x4__scalar(benchmark::State& state, const char* net) {
767   GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x4__scalar, 1, 4, 1, 1);
768 }
769 
f32_gemm_2x4__scalar(benchmark::State & state,const char * net)770 static void f32_gemm_2x4__scalar(benchmark::State& state, const char* net) {
771   GEMMBenchmark(state, xnn_f32_gemm_ukernel_2x4__scalar, 2, 4, 1, 1);
772 }
773 
f32_gemm_4x4__scalar(benchmark::State & state,const char * net)774 static void f32_gemm_4x4__scalar(benchmark::State& state, const char* net) {
775   GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x4__scalar, 4, 4, 1, 1);
776 }
777 
f32_ppmm_2x4_unipass__scalar(benchmark::State & state,const char * net)778 static void f32_ppmm_2x4_unipass__scalar(benchmark::State& state, const char* net) {
779   PPMM1PBenchmark(state, xnn_f32_ppmm_ukernel_2x4__scalar, xnn_x32_packx_ukernel_2x__scalar, 2, 4);
780 }
781 
f32_ppmm_4x2_unipass__scalar(benchmark::State & state,const char * net)782 static void f32_ppmm_4x2_unipass__scalar(benchmark::State& state, const char* net) {
783   PPMM1PBenchmark(state, xnn_f32_ppmm_ukernel_4x2__scalar, xnn_x32_packx_ukernel_4x__scalar, 4, 2);
784 }
785 
f32_ppmm_4x4_unipass__scalar(benchmark::State & state,const char * net)786 static void f32_ppmm_4x4_unipass__scalar(benchmark::State& state, const char* net) {
787   PPMM1PBenchmark(state, xnn_f32_ppmm_ukernel_4x4__scalar, xnn_x32_packx_ukernel_4x__scalar, 4, 4);
788 }
789 
f32_ppmm_3x3_unipass__scalar(benchmark::State & state,const char * net)790 static void f32_ppmm_3x3_unipass__scalar(benchmark::State& state, const char* net) {
791   PPMM1PBenchmark(state, xnn_f32_ppmm_ukernel_3x3__scalar, xnn_x32_packx_ukernel_3x__scalar, 3, 3);
792 }
793 
f32_ppmm_2x4_twopass__scalar(benchmark::State & state,const char * net)794 static void f32_ppmm_2x4_twopass__scalar(benchmark::State& state, const char* net) {
795   PPMM2PBenchmark(state, xnn_f32_ppmm_ukernel_2x4__scalar, xnn_x32_packx_ukernel_2x__scalar, 2, 4);
796 }
797 
f32_ppmm_4x2_twopass__scalar(benchmark::State & state,const char * net)798 static void f32_ppmm_4x2_twopass__scalar(benchmark::State& state, const char* net) {
799   PPMM2PBenchmark(state, xnn_f32_ppmm_ukernel_4x2__scalar, xnn_x32_packx_ukernel_4x__scalar, 4, 2);
800 }
801 
f32_ppmm_4x4_twopass__scalar(benchmark::State & state,const char * net)802 static void f32_ppmm_4x4_twopass__scalar(benchmark::State& state, const char* net) {
803   PPMM2PBenchmark(state, xnn_f32_ppmm_ukernel_4x4__scalar, xnn_x32_packx_ukernel_4x__scalar, 4, 4);
804 }
805 
f32_ppmm_3x3_twopass__scalar(benchmark::State & state,const char * net)806 static void f32_ppmm_3x3_twopass__scalar(benchmark::State& state, const char* net) {
807   PPMM2PBenchmark(state, xnn_f32_ppmm_ukernel_3x3__scalar, xnn_x32_packx_ukernel_3x__scalar, 3, 3);
808 }
809 
// Register the scalar GEMM and PPMM benchmark wrappers defined above.
BENCHMARK_GEMM(f32_gemm_1x4__scalar)
BENCHMARK_GEMM(f32_gemm_2x4__scalar)
BENCHMARK_GEMM(f32_gemm_4x4__scalar)

BENCHMARK_GEMM(f32_ppmm_2x4_unipass__scalar)
BENCHMARK_GEMM(f32_ppmm_4x2_unipass__scalar)
BENCHMARK_GEMM(f32_ppmm_4x4_unipass__scalar)
BENCHMARK_GEMM(f32_ppmm_3x3_unipass__scalar)

BENCHMARK_GEMM(f32_ppmm_2x4_twopass__scalar)
BENCHMARK_GEMM(f32_ppmm_4x2_twopass__scalar)
BENCHMARK_GEMM(f32_ppmm_4x4_twopass__scalar)
BENCHMARK_GEMM(f32_ppmm_3x3_twopass__scalar)

// Reference single-threaded Ruy benchmark, only when built with Ruy support.
#ifdef BENCHMARK_RUY
BENCHMARK_GEMM(ruy_st)
#endif  // BENCHMARK_RUY

// Benchmark entry point; suppressed when this file is linked into a binary
// that provides its own main().
#ifndef XNNPACK_BENCHMARK_NO_MAIN
BENCHMARK_MAIN();
#endif
831