1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8
9 #include <algorithm>
10 #include <cfloat>
11 #include <chrono>
12 #include <cmath>
13 #include <functional>
14 #include <mutex>
15 #include <random>
16 #include <vector>
17
18 #include <cpuinfo.h>
19
20 #include <benchmark/benchmark.h>
21 #ifdef BENCHMARK_RUY
22 #include "tensorflow/lite/experimental/ruy/ruy.h"
23 #endif // BENCHMARK_RUY
24 #include "bench/gemm.h"
25 #include "bench/utils.h"
26 #include <xnnpack/AlignedAllocator.h>
27 #include <xnnpack/common.h>
28 #include <xnnpack/gemm.h>
29 #include <xnnpack/pack.h>
30 #include <xnnpack/packx.h>
31 #include <xnnpack/params-init.h>
32 #include <xnnpack/params.h>
33 #include <xnnpack/ppmm.h>
34
35
GEMMBenchmark(benchmark::State & state,xnn_f32_gemm_ukernel_function gemm,size_t mr,size_t nr,size_t kr,size_t sr,benchmark::utils::IsaCheckFunction isa_check=nullptr)36 static void GEMMBenchmark(benchmark::State& state,
37 xnn_f32_gemm_ukernel_function gemm,
38 size_t mr, size_t nr, size_t kr, size_t sr,
39 benchmark::utils::IsaCheckFunction isa_check = nullptr)
40 {
41 if (!cpuinfo_initialize()) {
42 state.SkipWithError("cpuinfo initialization failed");
43 return;
44 }
45 if (isa_check && !isa_check(state)) {
46 return;
47 }
48
49 const size_t mc = state.range(0);
50 const size_t nc = state.range(1);
51 const size_t kc = state.range(2);
52
53 const size_t nc_stride = benchmark::utils::RoundUp(nc, nr);
54 const size_t kc_stride = benchmark::utils::RoundUp(kc, kr);
55
56 std::random_device random_device;
57 auto rng = std::mt19937(random_device());
58 auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);
59
60 std::vector<float> a(mc * kc);
61 std::generate(a.begin(), a.end(), std::ref(f32rng));
62 std::vector<float> k(nc * kc);
63 std::generate(k.begin(), k.end(), std::ref(f32rng));
64 std::vector<float> b(nc);
65 std::generate(b.begin(), b.end(), std::ref(f32rng));
66
67 const size_t w_elements = nc_stride * kc_stride + nc_stride;
68 const size_t c_elements = mc * nc;
69 const size_t num_buffers = 1 +
70 benchmark::utils::DivideRoundUp<size_t>(benchmark::utils::GetMaxCacheSize(),
71 sizeof(float) * (w_elements + c_elements));
72
73 std::vector<float, AlignedAllocator<float, 32>> w(w_elements * num_buffers);
74 std::fill(w.begin(), w.end(), 0.0f);
75 xnn_pack_f32_gemm_goi_w(1 /* groups */, nc, kc, nr, kr, sr, k.data(), b.data(), w.data());
76 std::vector<float> c(c_elements * num_buffers);
77 std::fill(c.begin(), c.end(), std::nanf(""));
78
79 xnn_f32_output_params output_params =
80 xnn_init_f32_output_params(-std::numeric_limits<float>::infinity(), +std::numeric_limits<float>::infinity());
81
82 size_t buffer_index = 0;
83 for (auto _ : state) {
84 // Use circular buffers (exceeding cache size) and prefetch to control cache state:
85 // - A is always in L1 cache (if fits, otherwise L2, L3, etc)
86 // - W is not in cache (for any cache level)
87 // - C is not in cache (for any cache level)
88 state.PauseTiming();
89 benchmark::utils::PrefetchToL1(a.data(), a.size() * sizeof(float));
90 buffer_index = (buffer_index + 1) % num_buffers;
91 state.ResumeTiming();
92
93 for (uint32_t m = 0; m < mc; m += mr) {
94 const uint32_t mb = min(mc - m, mr);
95 gemm(
96 mb, nc, kc * sizeof(float),
97 a.data() + m * kc, kc * sizeof(float),
98 w.data() + buffer_index * nc_stride * (kc_stride + 1),
99 c.data() + (buffer_index * mc + m) * nc, nc * sizeof(float), nr * sizeof(float),
100 &output_params);
101 }
102 }
103
104 state.counters["Freq"] = benchmark::utils::GetCurrentCpuFrequency();
105 state.counters["FLOPS"] = benchmark::Counter(
106 uint64_t(state.iterations()) * 2 * mc * nc * kc, benchmark::Counter::kIsRate);
107 }
108
PPMM1PBenchmark(benchmark::State & state,xnn_f32_ppmm_ukernel_function ppmm,xnn_x32_packx_ukernel_function packx,size_t mr,size_t nr,benchmark::utils::IsaCheckFunction isa_check=nullptr)109 static void PPMM1PBenchmark(benchmark::State& state,
110 xnn_f32_ppmm_ukernel_function ppmm,
111 xnn_x32_packx_ukernel_function packx,
112 size_t mr, size_t nr,
113 benchmark::utils::IsaCheckFunction isa_check = nullptr)
114 {
115 if (!cpuinfo_initialize()) {
116 state.SkipWithError("cpuinfo initialization failed");
117 return;
118 }
119 if (isa_check && !isa_check(state)) {
120 return;
121 }
122
123 const size_t mc = state.range(0);
124 const size_t nc = state.range(1);
125 const size_t kc = state.range(2);
126
127 const size_t nc_stride = benchmark::utils::RoundUp(nc, nr);
128
129 std::random_device random_device;
130 auto rng = std::mt19937(random_device());
131 auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);
132
133 std::vector<float> a(mc * kc);
134 std::generate(a.begin(), a.end(), std::ref(f32rng));
135 std::vector<float> k(nc * kc);
136 std::generate(k.begin(), k.end(), std::ref(f32rng));
137 std::vector<float> b(nc);
138 std::generate(b.begin(), b.end(), std::ref(f32rng));
139
140 std::vector<uint32_t, AlignedAllocator<uint32_t, 32>> t(mr * kc);
141
142 const size_t w_elements = nc_stride * kc + nc_stride;
143 const size_t c_elements = mc * nc;
144 const size_t num_buffers = 1 +
145 benchmark::utils::DivideRoundUp<size_t>(benchmark::utils::GetMaxCacheSize(),
146 sizeof(float) * (w_elements + c_elements));
147
148 std::vector<float, AlignedAllocator<float, 32>> w(w_elements * num_buffers);
149 std::fill(w.begin(), w.end(), 0.0f);
150 xnn_pack_f32_gemm_goi_w(1 /* groups */, nc, kc, nr, 1 /* kr */, 1 /* sr */, k.data(), b.data(), w.data());
151 std::vector<float> c(c_elements * num_buffers);
152 std::fill(c.begin(), c.end(), std::nanf(""));
153
154 xnn_f32_output_params output_params =
155 xnn_init_f32_output_params(-std::numeric_limits<float>::infinity(), +std::numeric_limits<float>::infinity());
156
157 size_t buffer_index = 0;
158 for (auto _ : state) {
159 // Use circular buffers (exceeding cache size) and prefetch to control cache state:
160 // - A is always in L1 cache (if fits, otherwise L2, L3, etc)
161 // - W is not in cache (for any cache level)
162 // - C is not in cache (for any cache level)
163 state.PauseTiming();
164 benchmark::utils::PrefetchToL1(a.data(), a.size() * sizeof(float));
165 buffer_index = (buffer_index + 1) % num_buffers;
166 state.ResumeTiming();
167
168 for (uint32_t m = 0; m < mc; m += mr) {
169 const uint32_t mb = min(mc - m, mr);
170 packx(mb, kc, reinterpret_cast<const uint32_t*>(a.data() + m * kc), kc, t.data());
171 ppmm(
172 mb, nc, kc * sizeof(float),
173 reinterpret_cast<const float*>(t.data()),
174 w.data() + nc_stride * buffer_index * (kc + 1),
175 c.data() + (mc * buffer_index + m) * nc, nc * sizeof(float), nr * sizeof(float),
176 &output_params);
177 }
178 }
179
180 state.counters["Freq"] = benchmark::utils::GetCurrentCpuFrequency();
181 state.counters["FLOPS"] = benchmark::Counter(
182 uint64_t(state.iterations()) * 2 * mc * nc * kc, benchmark::Counter::kIsRate);
183 }
184
PPMM2PBenchmark(benchmark::State & state,xnn_f32_ppmm_ukernel_function ppmm,xnn_x32_packx_ukernel_function packx,size_t mr,size_t nr,benchmark::utils::IsaCheckFunction isa_check=nullptr)185 static void PPMM2PBenchmark(benchmark::State& state,
186 xnn_f32_ppmm_ukernel_function ppmm,
187 xnn_x32_packx_ukernel_function packx,
188 size_t mr, size_t nr,
189 benchmark::utils::IsaCheckFunction isa_check = nullptr)
190 {
191 if (!cpuinfo_initialize()) {
192 state.SkipWithError("cpuinfo initialization failed");
193 return;
194 }
195 if (isa_check && !isa_check(state)) {
196 return;
197 }
198
199 const size_t mc = state.range(0);
200 const size_t nc = state.range(1);
201 const size_t kc = state.range(2);
202
203 const size_t mc_stride = benchmark::utils::RoundUp(mc, mr);
204 const size_t nc_stride = benchmark::utils::RoundUp(nc, nr);
205
206 std::random_device random_device;
207 auto rng = std::mt19937(random_device());
208 auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);
209
210 std::vector<float> a(mc * kc);
211 std::generate(a.begin(), a.end(), std::ref(f32rng));
212 std::vector<float> k(nc * kc);
213 std::generate(k.begin(), k.end(), std::ref(f32rng));
214 std::vector<float> b(nc);
215 std::generate(b.begin(), b.end(), std::ref(f32rng));
216
217 std::vector<uint32_t, AlignedAllocator<uint32_t, 32>> t(mc_stride * kc);
218
219 const size_t w_elements = nc_stride * kc + nc_stride;
220 const size_t c_elements = mc * nc;
221 const size_t num_buffers = 1 +
222 benchmark::utils::DivideRoundUp<size_t>(benchmark::utils::GetMaxCacheSize(),
223 sizeof(float) * (w_elements + c_elements));
224
225 std::vector<float, AlignedAllocator<float, 32>> w(w_elements * num_buffers);
226 std::fill(w.begin(), w.end(), 0.0f);
227 xnn_pack_f32_gemm_goi_w(1 /* groups */, nc, kc, nr, 1 /* kr */, 1 /* sr */, k.data(), b.data(), w.data());
228 std::vector<float> c(c_elements * num_buffers);
229 std::fill(c.begin(), c.end(), std::nanf(""));
230
231 xnn_f32_output_params output_params =
232 xnn_init_f32_output_params(-std::numeric_limits<float>::infinity(), +std::numeric_limits<float>::infinity());
233
234 size_t buffer_index = 0;
235 for (auto _ : state) {
236 // Use circular buffers (exceeding cache size) and prefetch to control cache state:
237 // - A is always in L1 cache (if fits, otherwise L2, L3, etc)
238 // - W is not in cache (for any cache level)
239 // - C is not in cache (for any cache level)
240 state.PauseTiming();
241 benchmark::utils::PrefetchToL1(a.data(), a.size() * sizeof(float));
242 buffer_index = (buffer_index + 1) % num_buffers;
243 state.ResumeTiming();
244
245 for (uint32_t m = 0; m < mc; m += mr) {
246 const uint32_t mb = min(mc - m, mr);
247 packx(mb, kc, reinterpret_cast<const uint32_t*>(a.data() + m * kc), kc, t.data() + m * kc);
248 }
249 for (uint32_t m = 0; m < mc; m += mr) {
250 const uint32_t mb = min(mc - m, mr);
251 ppmm(
252 mb, nc, kc * sizeof(float),
253 reinterpret_cast<const float*>(t.data() + m * kc),
254 w.data() + nc_stride * buffer_index * (kc + 1),
255 c.data() + (mc * buffer_index + m) * nc, nc * sizeof(float), nr * sizeof(float),
256 &output_params);
257 }
258 }
259
260 state.counters["Freq"] = benchmark::utils::GetCurrentCpuFrequency();
261 state.counters["FLOPS"] = benchmark::Counter(
262 uint64_t(state.iterations()) * 2 * mc * nc * kc, benchmark::Counter::kIsRate);
263 }
264
265 #ifdef BENCHMARK_RUY
// Benchmarks ruy's f32 GEMM on the same MCxNCxKC problem sizes, as a baseline
// for the XNNPACK microkernel benchmarks above.
// NOTE(review): ruy is fed the transposed problem -- ruy_a is the NCxKC
// weight matrix (row-major), ruy_b is A viewed as KCxMC column-major, and the
// result ruy_c is NCxMC column-major. Presumably this is equivalent to the
// row-major MCxNC output convention used elsewhere in this file; confirm
// against ruy's layout documentation.
static void RuyBenchmark(benchmark::State& state, uint32_t threads)
{
  std::random_device random_device;
  auto rng = std::mt19937(random_device());
  auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);

  const size_t mc = state.range(0);
  const size_t nc = state.range(1);
  const size_t kc = state.range(2);

  // Enough K/B/C copies to exceed the last-level cache so each iteration sees
  // cold weights, bias, and output (A alone is kept warm via prefetch below).
  const size_t num_buffers = 1 +
    benchmark::utils::DivideRoundUp<size_t>(benchmark::utils::GetMaxCacheSize(),
      sizeof(float) * (nc * (mc + kc + 1)));

  std::vector<float> a(mc * kc);
  std::generate(a.begin(), a.end(), std::ref(f32rng));
  std::vector<float> k(num_buffers * nc * kc);
  std::generate(k.begin(), k.end(), std::ref(f32rng));
  std::vector<float> b(num_buffers * nc);
  std::generate(b.begin(), b.end(), std::ref(f32rng));
  std::vector<float> c(num_buffers * nc * mc);
  std::fill(c.begin(), c.end(), std::nanf(""));

  // Note: context must be static to avoid the cost of re-creating it for each benchmark.
  static ruy::Context context;
  context.max_num_threads = threads;

  ruy::Matrix<float> ruy_a;
  ruy::MakeSimpleLayout(nc, kc, ruy::Order::kRowMajor, &ruy_a.layout);
  ruy::Matrix<float> ruy_b;
  ruy::MakeSimpleLayout(kc, mc, ruy::Order::kColMajor, &ruy_b.layout);
  ruy_b.data = a.data();
  ruy::Matrix<float> ruy_c;
  ruy::MakeSimpleLayout(nc, mc, ruy::Order::kColMajor, &ruy_c.layout);

  ruy::BasicSpec<float, float> spec;

  // ruy::Context uses deferred initialization, which affects percieved GEMM performance. Initialization happens during
  // the first GEMM calls, and per Benoit Jacob it takes up to ~250 milliseconds for performance to stabilize.
  // Thus, on the first benchmark, we compute GEMM for 500 milliseconds (to be safe) without recording performance, and
  // keep the ruy::Context object initialized (by being static) between subsequent benchmarks.
  static std::once_flag warmup;
  std::call_once(warmup, [&](){
    auto start = std::chrono::steady_clock::now();
    do {
      ruy_a.data = k.data();
      ruy_c.data = c.data();
      spec.bias = b.data();

      ruy::Mul<ruy::kAllPaths>(ruy_a, ruy_b, spec, &context, &ruy_c);
    } while (std::chrono::duration<double>(std::chrono::steady_clock::now() - start).count() < 0.5);
  });

  size_t buffer_index = 0;
  for (auto _ : state) {
    // Use circular buffers (exceeding cache size) and prefetch to control cache state:
    // - A is always in L1 cache (if fits, otherwise L2, L3, etc)
    // - K is not in cache (for any cache level)
    // - B is not in cache (for any cache level)
    // - C is not in cache (for any cache level)
    state.PauseTiming();
    benchmark::utils::PrefetchToL1(a.data(), a.size() * sizeof(float));
    buffer_index = (buffer_index + 1) % num_buffers;
    state.ResumeTiming();

    // Rotate through the K/C/bias buffers so the timed Mul reads cold data.
    ruy_a.data = k.data() + buffer_index * nc * kc;
    ruy_c.data = c.data() + buffer_index * mc * nc;
    spec.bias = b.data() + buffer_index * nc;

    ruy::Mul<ruy::kAllPaths>(ruy_a, ruy_b, spec, &context, &ruy_c);
  }

  state.counters["Freq"] = benchmark::utils::GetCurrentCpuFrequency();
  state.counters["FLOPS"] = benchmark::Counter(
    uint64_t(state.iterations()) * 2 * mc * nc * kc, benchmark::Counter::kIsRate);
}
342
// Single-threaded ruy baseline: forwards to RuyBenchmark with 1 thread.
// The unused `net` parameter keeps the signature uniform with the other
// benchmark entry points registered via BENCHMARK_GEMM.
static void ruy_st(benchmark::State& state, const char* net)
{
  RuyBenchmark(state, 1);
}
347 #endif // BENCHMARK_RUY
348
349
350 #if XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY
// Benchmark entry points for AArch64 assembly f32 GEMM microkernels.
// Each wrapper binds GEMMBenchmark to one microkernel with its mr/nr/kr/sr
// tile parameters; `net` is the network name supplied by BENCHMARK_GEMM.
static void f32_gemm_1x12__aarch64_neonfma_cortex_a53(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x12__aarch64_neonfma_cortex_a53, 1, 12, 1, 1);
}
static void f32_gemm_1x8__aarch64_neonfma_cortex_a53(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a53, 1, 8, 1, 1);
}
static void f32_gemm_1x8__aarch64_neonfma_cortex_a57(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a57, 1, 8, 1, 1);
}
static void f32_gemm_1x8__aarch64_neonfma_cortex_a75(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a75, 1, 8, 1, 1);
}
static void f32_gemm_4x12__aarch64_neonfma_cortex_a53(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x12__aarch64_neonfma_cortex_a53, 4, 12, 1, 1);
}
static void f32_gemm_4x8__aarch64_neonfma_cortex_a53(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a53, 4, 8, 1, 1);
}
static void f32_gemm_4x8__aarch64_neonfma_cortex_a57(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a57, 4, 8, 1, 1);
}
static void f32_gemm_4x8__aarch64_neonfma_cortex_a75(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a75, 4, 8, 1, 1);
}
static void f32_gemm_4x8__aarch64_neonfma_ld64(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_ld64, 4, 8, 1, 1);
}
static void f32_gemm_4x8__aarch64_neonfma_ld128(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__aarch64_neonfma_ld128, 4, 8, 1, 1);
}
static void f32_gemm_5x8__aarch64_neonfma_cortex_a57(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_5x8__aarch64_neonfma_cortex_a57, 5, 8, 1, 1);
}
static void f32_gemm_5x8__aarch64_neonfma_cortex_a75(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_5x8__aarch64_neonfma_cortex_a75, 5, 8, 1, 1);
}
static void f32_gemm_6x8__aarch64_neonfma_ld64(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld64, 6, 8, 1, 1);
}
static void f32_gemm_6x8__aarch64_neonfma_ld128(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_ld128, 6, 8, 1, 1);
}
static void f32_gemm_6x8__aarch64_neonfma_cortex_a53(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a53, 6, 8, 1, 1);
}
static void f32_gemm_6x8__aarch64_neonfma_cortex_a73(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a73, 6, 8, 1, 1);
}
static void f32_gemm_6x8__aarch64_neonfma_cortex_a57(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a57, 6, 8, 1, 1);
}
static void f32_gemm_6x8__aarch64_neonfma_cortex_a75(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a75, 6, 8, 1, 1);
}
static void f32_gemm_1x8__neonfma_lane_ld64(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x8__neonfma_lane_ld64, 1, 8, 1, 1);
}
static void f32_gemm_4x8__neonfma_lane_ld64(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__neonfma_lane_ld64, 4, 8, 1, 1);
}
static void f32_gemm_4x8__neonfma_lane_ld128(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__neonfma_lane_ld128, 4, 8, 1, 1);
}
static void f32_gemm_5x8__neonfma_lane_ld64(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_5x8__neonfma_lane_ld64, 5, 8, 1, 1);
}
static void f32_gemm_6x8__neonfma_lane_ld64(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8__neonfma_lane_ld64, 6, 8, 1, 1);
}
static void f32_gemm_6x8__neonfma_lane_ld128(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8__neonfma_lane_ld128, 6, 8, 1, 1);
}
BENCHMARK_GEMM(f32_gemm_1x12__aarch64_neonfma_cortex_a53)
BENCHMARK_GEMM(f32_gemm_1x8__aarch64_neonfma_cortex_a53)
BENCHMARK_GEMM(f32_gemm_1x8__aarch64_neonfma_cortex_a57)
BENCHMARK_GEMM(f32_gemm_1x8__aarch64_neonfma_cortex_a75)
BENCHMARK_GEMM(f32_gemm_4x12__aarch64_neonfma_cortex_a53)
BENCHMARK_GEMM(f32_gemm_4x8__aarch64_neonfma_cortex_a53)
BENCHMARK_GEMM(f32_gemm_4x8__aarch64_neonfma_cortex_a57)
BENCHMARK_GEMM(f32_gemm_4x8__aarch64_neonfma_cortex_a75)
BENCHMARK_GEMM(f32_gemm_4x8__aarch64_neonfma_ld128)
BENCHMARK_GEMM(f32_gemm_4x8__aarch64_neonfma_ld64)
BENCHMARK_GEMM(f32_gemm_5x8__aarch64_neonfma_cortex_a57)
BENCHMARK_GEMM(f32_gemm_5x8__aarch64_neonfma_cortex_a75)
BENCHMARK_GEMM(f32_gemm_6x8__aarch64_neonfma_cortex_a53)
BENCHMARK_GEMM(f32_gemm_6x8__aarch64_neonfma_cortex_a73)
BENCHMARK_GEMM(f32_gemm_6x8__aarch64_neonfma_cortex_a57)
BENCHMARK_GEMM(f32_gemm_6x8__aarch64_neonfma_cortex_a75)
BENCHMARK_GEMM(f32_gemm_6x8__aarch64_neonfma_ld64)
BENCHMARK_GEMM(f32_gemm_6x8__aarch64_neonfma_ld128)
BENCHMARK_GEMM(f32_gemm_1x8__neonfma_lane_ld64)
BENCHMARK_GEMM(f32_gemm_4x8__neonfma_lane_ld64)
BENCHMARK_GEMM(f32_gemm_4x8__neonfma_lane_ld128)
BENCHMARK_GEMM(f32_gemm_5x8__neonfma_lane_ld64)
BENCHMARK_GEMM(f32_gemm_6x8__neonfma_lane_ld64)
BENCHMARK_GEMM(f32_gemm_6x8__neonfma_lane_ld128)
447 #endif // XNN_ARCH_ARM64
448
449 #if XNN_ARCH_ARM && XNN_ENABLE_ASSEMBLY
// Benchmark entry points for AArch32 assembly f32 GEMM microkernels.
// All require NEON, checked at runtime via benchmark::utils::CheckNEON.
static void f32_gemm_4x8__aarch32_neon_ld64(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__aarch32_neon_ld64, 4, 8, 1, 1, benchmark::utils::CheckNEON);
}
static void f32_gemm_4x8__aarch32_neon_cortex_a53(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__aarch32_neon_cortex_a53, 4, 8, 1, 1, benchmark::utils::CheckNEON);
}
static void f32_gemm_4x8__aarch32_neon_cortex_a75(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__aarch32_neon_cortex_a75, 4, 8, 1, 1, benchmark::utils::CheckNEON);
}
static void f32_gemm_4x8__aarch32_neon_pld_cortex_a75(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__aarch32_neon_pld_cortex_a75, 4, 8, 1, 1, benchmark::utils::CheckNEON);
}

BENCHMARK_GEMM(f32_gemm_4x8__aarch32_neon_ld64)
BENCHMARK_GEMM(f32_gemm_4x8__aarch32_neon_cortex_a53)
BENCHMARK_GEMM(f32_gemm_4x8__aarch32_neon_cortex_a75)
BENCHMARK_GEMM(f32_gemm_4x8__aarch32_neon_pld_cortex_a75)
467 #endif // XNN_ARCH_ARM
468
469 #if XNN_ARCH_ARM || XNN_ARCH_ARM64
// Benchmark entry points for NEON/NEONFMA intrinsics microkernels (shared
// between AArch32 and AArch64). The s4 variants use sr=4 shuffle packing;
// the PPMM wrappers pair each microkernel with its x32 packx routine.
static void f32_gemm_1x8__neon_lane_ld64(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x8__neon_lane_ld64, 1, 8, 1, 1, benchmark::utils::CheckNEON);
}
static void f32_gemm_4x8__neon_lane_ld64(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__neon_lane_ld64, 4, 8, 1, 1, benchmark::utils::CheckNEON);
}
static void f32_gemm_4x8__neon_lane_ld128(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__neon_lane_ld128, 4, 8, 1, 1, benchmark::utils::CheckNEON);
}
static void f32_gemm_5x8__neon_lane_ld64(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_5x8__neon_lane_ld64, 5, 8, 1, 1, benchmark::utils::CheckNEON);
}
static void f32_gemm_6x8__neon_lane_ld64(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8__neon_lane_ld64, 6, 8, 1, 1, benchmark::utils::CheckNEON);
}
static void f32_gemm_6x8__neon_lane_ld128(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8__neon_lane_ld128, 6, 8, 1, 1, benchmark::utils::CheckNEON);
}
static void f32_gemm_1x8__neonfma_dup_ld64(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x8__neonfma_dup_ld64, 1, 8, 1, 1, benchmark::utils::CheckNEONFMA);
}
static void f32_gemm_4x8__neonfma_dup_ld64(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__neonfma_dup_ld64, 4, 8, 1, 1, benchmark::utils::CheckNEONFMA);
}
static void f32_gemm_4x8__neonfma_dup_ld128(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__neonfma_dup_ld128, 4, 8, 1, 1, benchmark::utils::CheckNEONFMA);
}
static void f32_gemm_6x8__neonfma_dup_ld64(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8__neonfma_dup_ld64, 6, 8, 1, 1, benchmark::utils::CheckNEONFMA);
}
static void f32_gemm_6x8__neonfma_dup_ld128(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8__neonfma_dup_ld128, 6, 8, 1, 1, benchmark::utils::CheckNEONFMA);
}
static void f32_gemm_1x8s4__neon(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x8s4__neon, 1, 8, 1, 4, benchmark::utils::CheckNEON);
}
static void f32_gemm_1x8s4__neonfma(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x8s4__neonfma, 1, 8, 1, 4, benchmark::utils::CheckNEONFMA);
}
static void f32_gemm_4x8s4__neon(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8s4__neon, 4, 8, 1, 4, benchmark::utils::CheckNEON);
}
static void f32_gemm_4x8s4__neonfma(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8s4__neonfma, 4, 8, 1, 4, benchmark::utils::CheckNEONFMA);
}
static void f32_gemm_6x8s4__neon(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8s4__neon, 6, 8, 1, 4, benchmark::utils::CheckNEON);
}
static void f32_gemm_6x8s4__neonfma(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8s4__neonfma, 6, 8, 1, 4, benchmark::utils::CheckNEONFMA);
}
static void f32_gemm_8x8s4__neon(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_8x8s4__neon, 8, 8, 1, 4, benchmark::utils::CheckNEON);
}
static void f32_gemm_8x8s4__neonfma(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_8x8s4__neonfma, 8, 8, 1, 4, benchmark::utils::CheckNEONFMA);
}
static void f32_ppmm_4x8_unipass__neonfma(benchmark::State& state, const char* net) {
  PPMM1PBenchmark(state, xnn_f32_ppmm_ukernel_4x8__neonfma, xnn_x32_packx_ukernel_4x__neon_st4, 4, 8, benchmark::utils::CheckNEONFMA);
}
static void f32_ppmm_4x8_twopass__neonfma(benchmark::State& state, const char* net) {
  PPMM2PBenchmark(state, xnn_f32_ppmm_ukernel_4x8__neonfma, xnn_x32_packx_ukernel_4x__neon_st4, 4, 8, benchmark::utils::CheckNEONFMA);
}

BENCHMARK_GEMM(f32_gemm_1x8__neon_lane_ld64)
BENCHMARK_GEMM(f32_gemm_4x8__neon_lane_ld64)
BENCHMARK_GEMM(f32_gemm_4x8__neon_lane_ld128)
BENCHMARK_GEMM(f32_gemm_5x8__neon_lane_ld64)
BENCHMARK_GEMM(f32_gemm_6x8__neon_lane_ld64)
BENCHMARK_GEMM(f32_gemm_6x8__neon_lane_ld128)
BENCHMARK_GEMM(f32_gemm_1x8__neonfma_dup_ld64)
BENCHMARK_GEMM(f32_gemm_4x8__neonfma_dup_ld128)
BENCHMARK_GEMM(f32_gemm_4x8__neonfma_dup_ld64)
BENCHMARK_GEMM(f32_gemm_6x8__neonfma_dup_ld64)
BENCHMARK_GEMM(f32_gemm_6x8__neonfma_dup_ld128)
BENCHMARK_GEMM(f32_gemm_1x8s4__neon)
BENCHMARK_GEMM(f32_gemm_1x8s4__neonfma)
BENCHMARK_GEMM(f32_gemm_4x8s4__neon)
BENCHMARK_GEMM(f32_gemm_4x8s4__neonfma)
BENCHMARK_GEMM(f32_gemm_6x8s4__neon)
BENCHMARK_GEMM(f32_gemm_6x8s4__neonfma)
BENCHMARK_GEMM(f32_gemm_8x8s4__neon)
BENCHMARK_GEMM(f32_gemm_8x8s4__neonfma)
BENCHMARK_GEMM(f32_ppmm_4x8_unipass__neonfma)
BENCHMARK_GEMM(f32_ppmm_4x8_twopass__neonfma)
555 #endif // XNN_ARCH_ARM || XNN_ARCH_ARM64
556
557 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
// Benchmark wrappers for the SSE f32 GEMM/PPMM micro-kernels.
// GEMMBenchmark arguments (see definition above): kernel, mr (row tile),
// nr (column tile), kr, sr; mr/nr must match the tile encoded in the
// micro-kernel's name (e.g. 4x8 -> mr=4, nr=8).
static void f32_gemm_1x8__sse_load1(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x8__sse_load1, 1, 8, 1, 1);
}
static void f32_gemm_4x8__sse_load1(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__sse_load1, 4, 8, 1, 1);
}

static void f32_gemm_1x8__sse_dup(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x8__sse_dup, 1, 8, 1, 1);
}
static void f32_gemm_4x8__sse_dup(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__sse_dup, 4, 8, 1, 1);
}

// s4 variants use sr=4 (shuffle factor of 4 in the packed weights layout).
static void f32_gemm_1x8s4__sse(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x8s4__sse, 1, 8, 1, 4);
}
static void f32_gemm_4x8s4__sse(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8s4__sse, 4, 8, 1, 4);
}

// PPMM benchmarks additionally take an x32 packing micro-kernel; PPMM1P/PPMM2P
// are presumably the one-pass/two-pass drivers declared earlier in this file
// (definitions not visible in this chunk).
static void f32_ppmm_4x8_unipass__sse(benchmark::State& state, const char* net) {
  PPMM1PBenchmark(state, xnn_f32_ppmm_ukernel_4x8__sse, xnn_x32_packx_ukernel_4x__sse, 4, 8);
}
static void f32_ppmm_4x8_twopass__sse(benchmark::State& state, const char* net) {
  PPMM2PBenchmark(state, xnn_f32_ppmm_ukernel_4x8__sse, xnn_x32_packx_ukernel_4x__sse, 4, 8);
}
585
// AVX broadcast GEMM micro-kernels; CheckAVX skips the benchmark on CPUs
// without AVX support (isa_check parameter of GEMMBenchmark).
static void f32_gemm_1x8__avx_broadcast(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x8__avx_broadcast, 1, 8, 1, 1, benchmark::utils::CheckAVX);
}
static void f32_gemm_4x8__avx_broadcast(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__avx_broadcast, 4, 8, 1, 1, benchmark::utils::CheckAVX);
}
static void f32_gemm_5x8__avx_broadcast(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_5x8__avx_broadcast, 5, 8, 1, 1, benchmark::utils::CheckAVX);
}
static void f32_gemm_6x8__avx_broadcast(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8__avx_broadcast, 6, 8, 1, 1, benchmark::utils::CheckAVX);
}
static void f32_gemm_7x8__avx_broadcast(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_7x8__avx_broadcast, 7, 8, 1, 1, benchmark::utils::CheckAVX);
}
static void f32_gemm_1x16__avx_broadcast(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x16__avx_broadcast, 1, 16, 1, 1, benchmark::utils::CheckAVX);
}
f32_gemm_3x16__avx_broadcast(benchmark::State & state,const char * net)604 static void f32_gemm_3x16__avx_broadcast(benchmark::State& state, const char* net) {
605 GEMMBenchmark(state, xnn_f32_gemm_ukernel_3x16__avx_broadcast, 4, 16, 1, 1, benchmark::utils::CheckAVX);
606 }
// AVX broadcast 16-column GEMM micro-kernels (mr x nr matches kernel name).
static void f32_gemm_4x16__avx_broadcast(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x16__avx_broadcast, 4, 16, 1, 1, benchmark::utils::CheckAVX);
}
static void f32_gemm_5x16__avx_broadcast(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_5x16__avx_broadcast, 5, 16, 1, 1, benchmark::utils::CheckAVX);
}
613
// FMA3 broadcast GEMM micro-kernels; CheckFMA3 skips on CPUs without FMA3.
static void f32_gemm_1x8__fma3_broadcast(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x8__fma3_broadcast, 1, 8, 1, 1, benchmark::utils::CheckFMA3);
}
static void f32_gemm_4x8__fma3_broadcast(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__fma3_broadcast, 4, 8, 1, 1, benchmark::utils::CheckFMA3);
}
static void f32_gemm_5x8__fma3_broadcast(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_5x8__fma3_broadcast, 5, 8, 1, 1, benchmark::utils::CheckFMA3);
}
static void f32_gemm_6x8__fma3_broadcast(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8__fma3_broadcast, 6, 8, 1, 1, benchmark::utils::CheckFMA3);
}
static void f32_gemm_7x8__fma3_broadcast(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_7x8__fma3_broadcast, 7, 8, 1, 1, benchmark::utils::CheckFMA3);
}
static void f32_gemm_8x8__fma3_broadcast(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_8x8__fma3_broadcast, 8, 8, 1, 1, benchmark::utils::CheckFMA3);
}
static void f32_gemm_1x16__fma3_broadcast(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x16__fma3_broadcast, 1, 16, 1, 1, benchmark::utils::CheckFMA3);
}
f32_gemm_3x16__fma3_broadcast(benchmark::State & state,const char * net)635 static void f32_gemm_3x16__fma3_broadcast(benchmark::State& state, const char* net) {
636 GEMMBenchmark(state, xnn_f32_gemm_ukernel_3x16__fma3_broadcast, 4, 16, 1, 1, benchmark::utils::CheckFMA3);
637 }
// FMA3 broadcast 16-column GEMM micro-kernels.
static void f32_gemm_4x16__fma3_broadcast(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x16__fma3_broadcast, 4, 16, 1, 1, benchmark::utils::CheckFMA3);
}
static void f32_gemm_5x16__fma3_broadcast(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_5x16__fma3_broadcast, 5, 16, 1, 1, benchmark::utils::CheckFMA3);
}
644
// FMA3 shuffled (s4, sr=4) 16-column GEMM micro-kernel, single-row tile.
static void f32_gemm_1x16s4__fma3_broadcast(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x16s4__fma3_broadcast, 1, 16, 1, 4, benchmark::utils::CheckFMA3);
}
f32_gemm_3x16s4__fma3_broadcast(benchmark::State & state,const char * net)648 static void f32_gemm_3x16s4__fma3_broadcast(benchmark::State& state, const char* net) {
649 GEMMBenchmark(state, xnn_f32_gemm_ukernel_3x16s4__fma3_broadcast, 4, 16, 1, 4, benchmark::utils::CheckFMA3);
650 }
// FMA3 shuffled (s4, sr=4) 16-column GEMM micro-kernels.
static void f32_gemm_4x16s4__fma3_broadcast(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x16s4__fma3_broadcast, 4, 16, 1, 4, benchmark::utils::CheckFMA3);
}
static void f32_gemm_5x16s4__fma3_broadcast(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_5x16s4__fma3_broadcast, 5, 16, 1, 4, benchmark::utils::CheckFMA3);
}
657
// AVX512F broadcast GEMM micro-kernels; CheckAVX512F skips on CPUs
// without AVX-512 Foundation support.
static void f32_gemm_1x16__avx512f_broadcast(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x16__avx512f_broadcast, 1, 16, 1, 1, benchmark::utils::CheckAVX512F);
}
static void f32_gemm_4x16__avx512f_broadcast(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x16__avx512f_broadcast, 4, 16, 1, 1, benchmark::utils::CheckAVX512F);
}
static void f32_gemm_5x16__avx512f_broadcast(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_5x16__avx512f_broadcast, 5, 16, 1, 1, benchmark::utils::CheckAVX512F);
}
static void f32_gemm_6x16__avx512f_broadcast(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x16__avx512f_broadcast, 6, 16, 1, 1, benchmark::utils::CheckAVX512F);
}
static void f32_gemm_7x16__avx512f_broadcast(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_7x16__avx512f_broadcast, 7, 16, 1, 1, benchmark::utils::CheckAVX512F);
}
static void f32_gemm_8x16__avx512f_broadcast(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_8x16__avx512f_broadcast, 8, 16, 1, 1, benchmark::utils::CheckAVX512F);
}
676
// Register all x86/x86-64 GEMM and PPMM benchmarks defined above.
BENCHMARK_GEMM(f32_gemm_1x8__sse_load1)
BENCHMARK_GEMM(f32_gemm_4x8__sse_load1)

BENCHMARK_GEMM(f32_gemm_1x8__sse_dup)
BENCHMARK_GEMM(f32_gemm_4x8__sse_dup)

BENCHMARK_GEMM(f32_gemm_1x8s4__sse)
BENCHMARK_GEMM(f32_gemm_4x8s4__sse)

BENCHMARK_GEMM(f32_ppmm_4x8_unipass__sse)
BENCHMARK_GEMM(f32_ppmm_4x8_twopass__sse)

BENCHMARK_GEMM(f32_gemm_1x8__avx_broadcast)
BENCHMARK_GEMM(f32_gemm_4x8__avx_broadcast)
BENCHMARK_GEMM(f32_gemm_5x8__avx_broadcast)
BENCHMARK_GEMM(f32_gemm_6x8__avx_broadcast)
BENCHMARK_GEMM(f32_gemm_7x8__avx_broadcast)
BENCHMARK_GEMM(f32_gemm_1x16__avx_broadcast)
BENCHMARK_GEMM(f32_gemm_3x16__avx_broadcast)
BENCHMARK_GEMM(f32_gemm_4x16__avx_broadcast)
BENCHMARK_GEMM(f32_gemm_5x16__avx_broadcast)

BENCHMARK_GEMM(f32_gemm_1x8__fma3_broadcast)
BENCHMARK_GEMM(f32_gemm_4x8__fma3_broadcast)
BENCHMARK_GEMM(f32_gemm_5x8__fma3_broadcast)
BENCHMARK_GEMM(f32_gemm_6x8__fma3_broadcast)
BENCHMARK_GEMM(f32_gemm_7x8__fma3_broadcast)
BENCHMARK_GEMM(f32_gemm_8x8__fma3_broadcast)
BENCHMARK_GEMM(f32_gemm_1x16__fma3_broadcast)
BENCHMARK_GEMM(f32_gemm_3x16__fma3_broadcast)
BENCHMARK_GEMM(f32_gemm_4x16__fma3_broadcast)
BENCHMARK_GEMM(f32_gemm_5x16__fma3_broadcast)

BENCHMARK_GEMM(f32_gemm_1x16s4__fma3_broadcast)
BENCHMARK_GEMM(f32_gemm_3x16s4__fma3_broadcast)
BENCHMARK_GEMM(f32_gemm_4x16s4__fma3_broadcast)
BENCHMARK_GEMM(f32_gemm_5x16s4__fma3_broadcast)

BENCHMARK_GEMM(f32_gemm_1x16__avx512f_broadcast)
BENCHMARK_GEMM(f32_gemm_4x16__avx512f_broadcast)
BENCHMARK_GEMM(f32_gemm_5x16__avx512f_broadcast)
BENCHMARK_GEMM(f32_gemm_6x16__avx512f_broadcast)
BENCHMARK_GEMM(f32_gemm_7x16__avx512f_broadcast)
BENCHMARK_GEMM(f32_gemm_8x16__avx512f_broadcast)
#endif  // XNN_ARCH_X86 || XNN_ARCH_X86_64
722
723 #if !XNN_ARCH_WASM && !XNN_ARCH_ASMJS
// Portable-SIMD (psimd) GEMM/PPMM micro-kernels, built whenever WAsm/asm.js
// is not the target. No ISA check needed — psimd is always available here.
static void f32_gemm_4x8__psimd_loadsplat(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__psimd_loadsplat, 4, 8, 1, 1);
}

static void f32_gemm_6x8__psimd_loadsplat(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8__psimd_loadsplat, 6, 8, 1, 1);
}

static void f32_gemm_4x8__psimd_splat(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8__psimd_splat, 4, 8, 1, 1);
}

static void f32_gemm_6x8__psimd_splat(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8__psimd_splat, 6, 8, 1, 1);
}

// s4 variants use sr=4 (shuffled packed-weights layout).
static void f32_gemm_4x8s4__psimd(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x8s4__psimd, 4, 8, 1, 4);
}

static void f32_gemm_6x8s4__psimd(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_6x8s4__psimd, 6, 8, 1, 4);
}

static void f32_ppmm_4x8_unipass__psimd(benchmark::State& state, const char* net) {
  PPMM1PBenchmark(state, xnn_f32_ppmm_ukernel_4x8__psimd, xnn_x32_packx_ukernel_4x__psimd, 4, 8);
}

static void f32_ppmm_4x8_twopass__psimd(benchmark::State& state, const char* net) {
  PPMM2PBenchmark(state, xnn_f32_ppmm_ukernel_4x8__psimd, xnn_x32_packx_ukernel_4x__psimd, 4, 8);
}
755
// Register the psimd GEMM and PPMM benchmarks defined above.
BENCHMARK_GEMM(f32_gemm_4x8__psimd_loadsplat)
BENCHMARK_GEMM(f32_gemm_6x8__psimd_loadsplat)
BENCHMARK_GEMM(f32_gemm_4x8__psimd_splat)
BENCHMARK_GEMM(f32_gemm_6x8__psimd_splat)
BENCHMARK_GEMM(f32_gemm_4x8s4__psimd)
BENCHMARK_GEMM(f32_gemm_6x8s4__psimd)
BENCHMARK_GEMM(f32_ppmm_4x8_unipass__psimd)
BENCHMARK_GEMM(f32_ppmm_4x8_twopass__psimd)
#endif  // !XNN_ARCH_WASM && !XNN_ARCH_ASMJS
765
// Scalar (architecture-independent) GEMM/PPMM micro-kernels — always built,
// used as the portable baseline for comparison against the SIMD kernels.
static void f32_gemm_1x4__scalar(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_1x4__scalar, 1, 4, 1, 1);
}

static void f32_gemm_2x4__scalar(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_2x4__scalar, 2, 4, 1, 1);
}

static void f32_gemm_4x4__scalar(benchmark::State& state, const char* net) {
  GEMMBenchmark(state, xnn_f32_gemm_ukernel_4x4__scalar, 4, 4, 1, 1);
}

// PPMM variants pair the matrix-multiply kernel with the matching x32
// packing kernel; mr/nr follow the tile in each kernel's name.
static void f32_ppmm_2x4_unipass__scalar(benchmark::State& state, const char* net) {
  PPMM1PBenchmark(state, xnn_f32_ppmm_ukernel_2x4__scalar, xnn_x32_packx_ukernel_2x__scalar, 2, 4);
}

static void f32_ppmm_4x2_unipass__scalar(benchmark::State& state, const char* net) {
  PPMM1PBenchmark(state, xnn_f32_ppmm_ukernel_4x2__scalar, xnn_x32_packx_ukernel_4x__scalar, 4, 2);
}

static void f32_ppmm_4x4_unipass__scalar(benchmark::State& state, const char* net) {
  PPMM1PBenchmark(state, xnn_f32_ppmm_ukernel_4x4__scalar, xnn_x32_packx_ukernel_4x__scalar, 4, 4);
}

static void f32_ppmm_3x3_unipass__scalar(benchmark::State& state, const char* net) {
  PPMM1PBenchmark(state, xnn_f32_ppmm_ukernel_3x3__scalar, xnn_x32_packx_ukernel_3x__scalar, 3, 3);
}

static void f32_ppmm_2x4_twopass__scalar(benchmark::State& state, const char* net) {
  PPMM2PBenchmark(state, xnn_f32_ppmm_ukernel_2x4__scalar, xnn_x32_packx_ukernel_2x__scalar, 2, 4);
}

static void f32_ppmm_4x2_twopass__scalar(benchmark::State& state, const char* net) {
  PPMM2PBenchmark(state, xnn_f32_ppmm_ukernel_4x2__scalar, xnn_x32_packx_ukernel_4x__scalar, 4, 2);
}

static void f32_ppmm_4x4_twopass__scalar(benchmark::State& state, const char* net) {
  PPMM2PBenchmark(state, xnn_f32_ppmm_ukernel_4x4__scalar, xnn_x32_packx_ukernel_4x__scalar, 4, 4);
}

static void f32_ppmm_3x3_twopass__scalar(benchmark::State& state, const char* net) {
  PPMM2PBenchmark(state, xnn_f32_ppmm_ukernel_3x3__scalar, xnn_x32_packx_ukernel_3x__scalar, 3, 3);
}
809
// Register the scalar baseline benchmarks, plus the single-threaded ruy
// reference benchmark when the build opts into it (BENCHMARK_RUY; ruy_st is
// defined elsewhere in this file).
BENCHMARK_GEMM(f32_gemm_1x4__scalar)
BENCHMARK_GEMM(f32_gemm_2x4__scalar)
BENCHMARK_GEMM(f32_gemm_4x4__scalar)

BENCHMARK_GEMM(f32_ppmm_2x4_unipass__scalar)
BENCHMARK_GEMM(f32_ppmm_4x2_unipass__scalar)
BENCHMARK_GEMM(f32_ppmm_4x4_unipass__scalar)
BENCHMARK_GEMM(f32_ppmm_3x3_unipass__scalar)

BENCHMARK_GEMM(f32_ppmm_2x4_twopass__scalar)
BENCHMARK_GEMM(f32_ppmm_4x2_twopass__scalar)
BENCHMARK_GEMM(f32_ppmm_4x4_twopass__scalar)
BENCHMARK_GEMM(f32_ppmm_3x3_twopass__scalar)

#ifdef BENCHMARK_RUY
BENCHMARK_GEMM(ruy_st)
#endif  // BENCHMARK_RUY
827
// Emit the Google Benchmark main() unless the build embeds this file in a
// larger binary that provides its own entry point.
#ifndef XNNPACK_BENCHMARK_NO_MAIN
BENCHMARK_MAIN();
#endif
831