1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8
9 #pragma once
10
11 #include <benchmark/benchmark.h>
12
// Registers `gemm_fn` once for every network model whose GEMM shapes are
// listed in this header. Each registration is a BENCHMARK_CAPTURE that passes
// the human-readable model name to `gemm_fn` as its extra argument, applies
// the model's shape generator, and reports real (wall-clock) time.
//
// BENCHMARK_GEMM_MODEL_ is an implementation detail of BENCHMARK_GEMM: it
// expands to exactly one registration statement, including the trailing ';'.
#define BENCHMARK_GEMM_MODEL_(gemm_fn, id, label, shapes) \
  BENCHMARK_CAPTURE(gemm_fn, id, label)->Apply(shapes)->UseRealTime();

#define BENCHMARK_GEMM(gemm_fn) \
  BENCHMARK_GEMM_MODEL_(gemm_fn, mobilenet_v1, "MobileNet v1", MobileNetV1GemmArguments) \
  BENCHMARK_GEMM_MODEL_(gemm_fn, mobilenet_v2, "MobileNet v2", MobileNetV2GemmArguments) \
  BENCHMARK_GEMM_MODEL_(gemm_fn, mobilenet_v3_small, "MobileNet v3 Small", MobileNetV3SmallGemmArguments) \
  BENCHMARK_GEMM_MODEL_(gemm_fn, mobilenet_v3_large, "MobileNet v3 Large", MobileNetV3LargeGemmArguments) \
  BENCHMARK_GEMM_MODEL_(gemm_fn, shufflenet_v1_g1, "ShuffleNet v1 (1 group)", ShuffleNetV1G1GemmArguments) \
  BENCHMARK_GEMM_MODEL_(gemm_fn, shufflenet_v1_g2, "ShuffleNet v1 (2 groups)", ShuffleNetV1G2GemmArguments) \
  BENCHMARK_GEMM_MODEL_(gemm_fn, shufflenet_v1_g3, "ShuffleNet v1 (3 groups)", ShuffleNetV1G3GemmArguments) \
  BENCHMARK_GEMM_MODEL_(gemm_fn, shufflenet_v1_g4, "ShuffleNet v1 (4 groups)", ShuffleNetV1G4GemmArguments) \
  BENCHMARK_GEMM_MODEL_(gemm_fn, shufflenet_v1_g8, "ShuffleNet v1 (8 groups)", ShuffleNetV1G8GemmArguments) \
  BENCHMARK_GEMM_MODEL_(gemm_fn, shufflenet_v2_x05, "ShuffleNet v2 0.5X", ShuffleNetV2X05GemmArguments) \
  BENCHMARK_GEMM_MODEL_(gemm_fn, shufflenet_v2_x10, "ShuffleNet v2 1.0X", ShuffleNetV2X10GemmArguments) \
  BENCHMARK_GEMM_MODEL_(gemm_fn, shufflenet_v2_x15, "ShuffleNet v2 1.5X", ShuffleNetV2X15GemmArguments) \
  BENCHMARK_GEMM_MODEL_(gemm_fn, shufflenet_v2_x20, "ShuffleNet v2 2.0X", ShuffleNetV2X20GemmArguments) \
  BENCHMARK_GEMM_MODEL_(gemm_fn, inception_v3, "Inception v3", InceptionV3GemmArguments) \
  BENCHMARK_GEMM_MODEL_(gemm_fn, resnet18, "ResNet-18", ResNet18GemmArguments) \
  BENCHMARK_GEMM_MODEL_(gemm_fn, resnet50, "ResNet-50", ResNet50GemmArguments) \
  BENCHMARK_GEMM_MODEL_(gemm_fn, squeezenet_v10, "SqueezeNet 1.0", SqueezeNetV10GemmArguments) \
  BENCHMARK_GEMM_MODEL_(gemm_fn, squeezenet_v11, "SqueezeNet 1.1", SqueezeNetV11GemmArguments) \
  BENCHMARK_GEMM_MODEL_(gemm_fn, vgg, "VGG", VGGGemmArguments) \
  BENCHMARK_GEMM_MODEL_(gemm_fn, srcnn915, "SRCNN (9-1-5)", SRCNN915GemmArguments) \
  BENCHMARK_GEMM_MODEL_(gemm_fn, srcnn935, "SRCNN (9-3-5)", SRCNN935GemmArguments)
35
// SRCNN (9-5-5) is intentionally not registered by BENCHMARK_GEMM: it was
// removed because it crashed with an out-of-memory segfault on 32-bit ARM.
// BENCHMARK_CAPTURE(gemm_fn, srcnn955, "SRCNN (9-5-5)")->Apply(SRCNN955GemmArguments)->UseRealTime();
38
39
40 // ShuffleNet v1 with 1 group.
ShuffleNetV1G1GemmArguments(benchmark::internal::Benchmark * b)41 static void ShuffleNetV1G1GemmArguments(benchmark::internal::Benchmark* b) {
42 b->ArgNames({"M", "N", "K"});
43
44 /* M N K */
45 b->Args({112 * 112, 24, 3 * 3 * 3});
46 b->Args({ 56 * 56, 36, 24 * 1 * 1});
47 b->Args({ 28 * 28, 120, 36 * 1 * 1});
48 b->Args({ 28 * 28, 36, 144 * 1 * 1});
49 b->Args({ 28 * 28, 144, 36 * 1 * 1});
50 b->Args({ 28 * 28, 72, 144 * 1 * 1});
51 b->Args({ 14 * 14, 144, 72 * 1 * 1});
52 b->Args({ 14 * 14, 72, 288 * 1 * 1});
53 b->Args({ 14 * 14, 288, 72 * 1 * 1});
54 b->Args({ 14 * 14, 144, 288 * 1 * 1});
55 b->Args({ 7 * 7, 288, 144 * 1 * 1});
56 b->Args({ 7 * 7, 144, 576 * 1 * 1});
57 b->Args({ 7 * 7, 576, 144 * 1 * 1});
58 }
59
60 // ShuffleNet v1 with 2 groups.
ShuffleNetV1G2GemmArguments(benchmark::internal::Benchmark * b)61 static void ShuffleNetV1G2GemmArguments(benchmark::internal::Benchmark* b) {
62 b->ArgNames({"M", "N", "K"});
63
64 /* M N K */
65 b->Args({112 * 112, 24, 3 * 3 * 3});
66 b->Args({ 56 * 56, 50, 24 * 1 * 1});
67 b->Args({ 28 * 28, 88, 25 * 1 * 1});
68 b->Args({ 28 * 28, 25, 100 * 1 * 1});
69 b->Args({ 28 * 28, 100, 25 * 1 * 1});
70 b->Args({ 28 * 28, 50, 100 * 1 * 1});
71 b->Args({ 14 * 14, 100, 50 * 1 * 1});
72 b->Args({ 14 * 14, 50, 200 * 1 * 1});
73 b->Args({ 14 * 14, 200, 50 * 1 * 1});
74 b->Args({ 14 * 14, 100, 200 * 1 * 1});
75 b->Args({ 7 * 7, 200, 100 * 1 * 1});
76 b->Args({ 7 * 7, 100, 400 * 1 * 1});
77 b->Args({ 7 * 7, 400, 100 * 1 * 1});
78 }
79
80 // ShuffleNet v1 with 3 groups.
ShuffleNetV1G3GemmArguments(benchmark::internal::Benchmark * b)81 static void ShuffleNetV1G3GemmArguments(benchmark::internal::Benchmark* b) {
82 b->ArgNames({"M", "N", "K"});
83
84 /* M N K */
85 b->Args({112 * 112, 24, 3 * 3 * 3});
86 b->Args({ 56 * 56, 60, 24 * 1 * 1});
87 b->Args({ 28 * 28, 72, 20 * 1 * 1});
88 b->Args({ 28 * 28, 20, 80 * 1 * 1});
89 b->Args({ 28 * 28, 80, 20 * 1 * 1});
90 b->Args({ 28 * 28, 40, 80 * 1 * 1});
91 b->Args({ 14 * 14, 80, 40 * 1 * 1});
92 b->Args({ 14 * 14, 40, 160 * 1 * 1});
93 b->Args({ 14 * 14, 160, 40 * 1 * 1});
94 b->Args({ 14 * 14, 80, 160 * 1 * 1});
95 b->Args({ 7 * 7, 160, 80 * 1 * 1});
96 b->Args({ 7 * 7, 80, 320 * 1 * 1});
97 b->Args({ 7 * 7, 320, 80 * 1 * 1});
98 }
99
100 // ShuffleNet v1 with 4 groups.
ShuffleNetV1G4GemmArguments(benchmark::internal::Benchmark * b)101 static void ShuffleNetV1G4GemmArguments(benchmark::internal::Benchmark* b) {
102 b->ArgNames({"M", "N", "K"});
103
104 /* M N K */
105 b->Args({112 * 112, 24, 3 * 3 * 3});
106 b->Args({ 56 * 56, 68, 24 * 1 * 1});
107 b->Args({ 28 * 28, 62, 17 * 1 * 1});
108 b->Args({ 28 * 28, 17, 68 * 1 * 1});
109 b->Args({ 28 * 28, 68, 17 * 1 * 1});
110 b->Args({ 28 * 28, 34, 68 * 1 * 1});
111 b->Args({ 14 * 14, 68, 34 * 1 * 1});
112 b->Args({ 14 * 14, 34, 136 * 1 * 1});
113 b->Args({ 14 * 14, 136, 34 * 1 * 1});
114 b->Args({ 14 * 14, 68, 136 * 1 * 1});
115 b->Args({ 7 * 7, 136, 68 * 1 * 1});
116 b->Args({ 7 * 7, 68, 272 * 1 * 1});
117 b->Args({ 7 * 7, 272, 68 * 1 * 1});
118 }
119
120 // ShuffleNet v1 with 8 groups.
ShuffleNetV1G8GemmArguments(benchmark::internal::Benchmark * b)121 static void ShuffleNetV1G8GemmArguments(benchmark::internal::Benchmark* b) {
122 b->ArgNames({"M", "N", "K"});
123
124 /* M N K */
125 b->Args({112 * 112, 24, 3 * 3 * 3});
126 b->Args({ 56 * 56, 96, 24 * 1 * 1});
127 b->Args({ 28 * 28, 45, 12 * 1 * 1});
128 b->Args({ 28 * 28, 12, 48 * 1 * 1});
129 b->Args({ 28 * 28, 48, 12 * 1 * 1});
130 b->Args({ 28 * 28, 24, 48 * 1 * 1});
131 b->Args({ 14 * 14, 48, 24 * 1 * 1});
132 b->Args({ 14 * 14, 24, 96 * 1 * 1});
133 b->Args({ 14 * 14, 96, 24 * 1 * 1});
134 b->Args({ 14 * 14, 48, 96 * 1 * 1});
135 b->Args({ 7 * 7, 96, 48 * 1 * 1});
136 b->Args({ 7 * 7, 48, 192 * 1 * 1});
137 b->Args({ 7 * 7, 192, 48 * 1 * 1});
138 }
139
140 // ShuffleNet v2 (0.5X scale)
ShuffleNetV2X05GemmArguments(benchmark::internal::Benchmark * b)141 static void ShuffleNetV2X05GemmArguments(benchmark::internal::Benchmark* b) {
142 b->ArgNames({"M", "N", "K"});
143
144 /* M N K */
145 b->Args({112 * 112, 24, 3 * 3 * 3});
146 b->Args({ 56 * 56, 24, 24 * 1 * 1});
147 b->Args({ 28 * 28, 24, 24 * 1 * 1});
148 b->Args({ 28 * 28, 48, 48 * 1 * 1});
149 b->Args({ 14 * 14, 48, 48 * 1 * 1});
150 b->Args({ 14 * 14, 96, 96 * 1 * 1});
151 b->Args({ 7 * 7, 96, 96 * 1 * 1});
152 b->Args({ 7 * 7, 1024, 192 * 1 * 1});
153 }
154
155 // ShuffleNet v2 (1.0X scale)
ShuffleNetV2X10GemmArguments(benchmark::internal::Benchmark * b)156 static void ShuffleNetV2X10GemmArguments(benchmark::internal::Benchmark* b) {
157 b->ArgNames({"M", "N", "K"});
158
159 /* M N K */
160 b->Args({112 * 112, 24, 3 * 3 * 3});
161 b->Args({ 56 * 56, 58, 24 * 1 * 1});
162 b->Args({ 28 * 28, 58, 24 * 1 * 1});
163 b->Args({ 28 * 28, 58, 58 * 1 * 1});
164 b->Args({ 14 * 14, 116, 116 * 1 * 1});
165 b->Args({ 14 * 14, 116, 116 * 1 * 1});
166 b->Args({ 14 * 14, 232, 232 * 1 * 1});
167 b->Args({ 7 * 7, 232, 232 * 1 * 1});
168 b->Args({ 7 * 7, 1024, 464 * 1 * 1});
169 }
170
171 // ShuffleNet v2 (1.5X scale)
ShuffleNetV2X15GemmArguments(benchmark::internal::Benchmark * b)172 static void ShuffleNetV2X15GemmArguments(benchmark::internal::Benchmark* b) {
173 b->ArgNames({"M", "N", "K"});
174
175 /* M N K */
176 b->Args({112 * 112, 24, 3 * 3 * 3});
177 b->Args({ 56 * 56, 88, 24 * 1 * 1});
178 b->Args({ 28 * 28, 88, 24 * 1 * 1});
179 b->Args({ 28 * 28, 88, 88 * 1 * 1});
180 b->Args({ 28 * 28, 176, 176 * 1 * 1});
181 b->Args({ 14 * 14, 176, 176 * 1 * 1});
182 b->Args({ 14 * 14, 352, 352 * 1 * 1});
183 b->Args({ 7 * 7, 352, 352 * 1 * 1});
184 b->Args({ 7 * 7, 1024, 704 * 1 * 1});
185 }
186
187 // ShuffleNet v2 (2.0X scale)
ShuffleNetV2X20GemmArguments(benchmark::internal::Benchmark * b)188 static void ShuffleNetV2X20GemmArguments(benchmark::internal::Benchmark* b) {
189 b->ArgNames({"M", "N", "K"});
190
191 /* M N K */
192 b->Args({112 * 112, 24, 3 * 3 * 3});
193 b->Args({ 56 * 56, 122, 24 * 1 * 1});
194 b->Args({ 28 * 28, 122, 24 * 1 * 1});
195 b->Args({ 28 * 28, 122, 122 * 1 * 1});
196 b->Args({ 28 * 28, 244, 244 * 1 * 1});
197 b->Args({ 14 * 14, 244, 244 * 1 * 1});
198 b->Args({ 14 * 14, 488, 488 * 1 * 1});
199 b->Args({ 7 * 7, 488, 488 * 1 * 1});
200 b->Args({ 7 * 7, 2048, 976 * 1 * 1});
201 }
202
MobileNetV1GemmArguments(benchmark::internal::Benchmark * b)203 static void MobileNetV1GemmArguments(benchmark::internal::Benchmark* b) {
204 b->ArgNames({"M", "N", "K"});
205
206 /* M N K */
207 b->Args({112 * 112, 32, 3 * 3 * 3});
208 b->Args({112 * 112, 64, 32 * 1 * 1});
209 b->Args({ 56 * 56, 128, 64 * 1 * 1});
210 b->Args({ 56 * 56, 128, 128 * 1 * 1});
211 b->Args({ 28 * 28, 256, 128 * 1 * 1});
212 b->Args({ 28 * 28, 256, 256 * 1 * 1});
213 b->Args({ 14 * 14, 512, 256 * 1 * 1});
214 b->Args({ 14 * 14, 512, 512 * 1 * 1});
215 b->Args({ 7 * 7, 1024, 512 * 1 * 1});
216 b->Args({ 7 * 7, 1024, 1024 * 1 * 1});
217 }
218
MobileNetV2GemmArguments(benchmark::internal::Benchmark * b)219 static void MobileNetV2GemmArguments(benchmark::internal::Benchmark* b) {
220 b->ArgNames({"M", "N", "K"});
221
222 /*********** Initial Stage ************/
223 /* M N K */
224 b->Args({112 * 112, 32, 3 * 3 * 3});
225 /************ Bottleneck 1 ************/
226 /* M N K */
227 b->Args({112 * 112, 16, 32 * 1 * 1});
228 /************ Bottleneck 2 ************/
229 /* M N K */
230 b->Args({112 * 112, 96, 16 * 1 * 1});
231 b->Args({ 56 * 56, 24, 96 * 1 * 1});
232 b->Args({ 56 * 56, 144, 24 * 1 * 1});
233 b->Args({ 56 * 56, 24, 144 * 1 * 1});
234 /************ Bottleneck 3 ************/
235 /* M N K */
236 b->Args({ 28 * 28, 32, 144 * 1 * 1});
237 b->Args({ 28 * 28, 192, 32 * 1 * 1});
238 b->Args({ 28 * 28, 32, 192 * 1 * 1});
239 /************ Bottleneck 4 ************/
240 /* M N K */
241 b->Args({ 14 * 14, 64, 192 * 1 * 1});
242 b->Args({ 14 * 14, 384, 64 * 1 * 1});
243 b->Args({ 14 * 14, 64, 384 * 1 * 1});
244 /************ Bottleneck 5 ************/
245 /* M N K */
246 b->Args({ 14 * 14, 96, 384 * 1 * 1});
247 b->Args({ 14 * 14, 576, 96 * 1 * 1});
248 b->Args({ 14 * 14, 96, 576 * 1 * 1});
249 /************ Bottleneck 6 ************/
250 /* M N K */
251 b->Args({ 7 * 7, 160, 576 * 1 * 1});
252 b->Args({ 7 * 7, 960, 160 * 1 * 1});
253 b->Args({ 7 * 7, 160, 960 * 1 * 1});
254 /************ Bottleneck 7 ************/
255 /* M N K */
256 b->Args({ 7 * 7, 320, 960 * 1 * 1});
257 /********* Pre-pooling Conv2D *********/
258 /* M N K */
259 b->Args({ 7 * 7, 1280, 320 * 1 * 1});
260 /******** Post-pooling Conv2D *********/
261 /* M N K */
262 b->Args({ 1 * 1, 1000, 1280 * 1 * 1});
263 }
264
MobileNetV3SmallGemmArguments(benchmark::internal::Benchmark * b)265 static void MobileNetV3SmallGemmArguments(benchmark::internal::Benchmark* b) {
266 b->ArgNames({"M", "N", "K"});
267
268 /************ Initial Stage ************/
269 /* M N K */
270 b->Args({112 * 112, 16, 3 * 3 * 3});
271 /************* Bottleneck 1 ************/
272 /* M N K */
273 b->Args({ 1 * 1, 8, 16 * 1 * 1});
274 b->Args({ 1 * 1, 16, 8 * 1 * 1});
275 b->Args({ 56 * 56, 16, 16 * 1 * 1});
276 /************* Bottleneck 2 ************/
277 /* M N K */
278 b->Args({ 56 * 56, 72, 16 * 1 * 1});
279 b->Args({ 28 * 28, 24, 72 * 1 * 1});
280 /************* Bottleneck 3 ************/
281 /* M N K */
282 b->Args({ 28 * 28, 88, 24 * 1 * 1});
283 b->Args({ 28 * 28, 24, 88 * 1 * 1});
284 /************* Bottleneck 4 ************/
285 /* M N K */
286 b->Args({ 28 * 28, 96, 24 * 1 * 1});
287 b->Args({ 1 * 1, 24, 96 * 1 * 1});
288 b->Args({ 1 * 1, 96, 24 * 1 * 1});
289 b->Args({ 14 * 14, 40, 96 * 1 * 1});
290 /************* Bottleneck 5 ************/
291 /* M N K */
292 b->Args({ 14 * 14, 240, 40 * 1 * 1});
293 b->Args({ 1 * 1, 64, 240 * 1 * 1});
294 b->Args({ 1 * 1, 240, 64 * 1 * 1});
295 b->Args({ 14 * 14, 40, 240 * 1 * 1});
296 /************* Bottleneck 6 ************/
297 /* M N K */
298 //b->Args({ 14 * 14, 240, 40 * 1 * 1});
299 //b->Args({ 1 * 1, 64, 240 * 1 * 1});
300 //b->Args({ 1 * 1, 240, 64 * 1 * 1});
301 //b->Args({ 14 * 14, 40, 240 * 1 * 1});
302 /************* Bottleneck 7 ************/
303 /* M N K */
304 b->Args({ 14 * 14, 120, 40 * 1 * 1});
305 b->Args({ 1 * 1, 32, 120 * 1 * 1});
306 b->Args({ 1 * 1, 120, 32 * 1 * 1});
307 b->Args({ 14 * 14, 48, 120 * 1 * 1});
308 /************* Bottleneck 8 ************/
309 /* M N K */
310 b->Args({ 14 * 14, 144, 48 * 1 * 1});
311 b->Args({ 1 * 1, 40, 144 * 1 * 1});
312 b->Args({ 1 * 1, 144, 40 * 1 * 1});
313 b->Args({ 14 * 14, 48, 144 * 1 * 1});
314 /************* Bottleneck 9 ************/
315 /* M N K */
316 b->Args({ 14 * 14, 288, 48 * 1 * 1});
317 b->Args({ 1 * 1, 72, 288 * 1 * 1});
318 b->Args({ 1 * 1, 288, 72 * 1 * 1});
319 b->Args({ 7 * 7, 96, 288 * 1 * 1});
320 /************ Bottleneck 10 ************/
321 /* M N K */
322 b->Args({ 7 * 7, 576, 96 * 1 * 1});
323 b->Args({ 1 * 1, 144, 576 * 1 * 1});
324 b->Args({ 1 * 1, 576, 144 * 1 * 1});
325 b->Args({ 7 * 7, 96, 576 * 1 * 1});
326 /************ Bottleneck 11 ************/
327 /* M N K */
328 //b->Args({ 7 * 7, 576, 96 * 1 * 1});
329 //b->Args({ 1 * 1, 144, 576 * 1 * 1});
330 //b->Args({ 1 * 1, 576, 144 * 1 * 1});
331 //b->Args({ 7 * 7, 96, 576 * 1 * 1});
332 /************* Last Stage *************/
333 /* M N K */
334 //b->Args({ 7 * 7, 576, 96 * 1 * 1});
335 b->Args({ 1 * 1, 1024, 576 * 1 * 1});
336 b->Args({ 1 * 1, 1001, 1024 * 1 * 1});
337 }
338
MobileNetV3LargeGemmArguments(benchmark::internal::Benchmark * b)339 static void MobileNetV3LargeGemmArguments(benchmark::internal::Benchmark* b) {
340 b->ArgNames({"M", "N", "K"});
341
342 /************ Initial Stage ************/
343 /* M N K */
344 b->Args({112 * 112, 16, 3 * 3 * 3});
345 /************* Bottleneck 1 ************/
346 /* M N K */
347 b->Args({112 * 112, 16, 16 * 1 * 1});
348 /************* Bottleneck 2 ************/
349 /* M N K */
350 b->Args({112 * 112, 64, 16 * 1 * 1});
351 b->Args({ 56 * 56, 24, 64 * 1 * 1});
352 /************* Bottleneck 3 ************/
353 /* M N K */
354 b->Args({ 56 * 56, 72, 24 * 1 * 1});
355 b->Args({ 56 * 56, 24, 72 * 1 * 1});
356 /************* Bottleneck 4 ************/
357 /* M N K */
358 //b->Args({ 56 * 56, 72, 24 * 1 * 1});
359 b->Args({ 1 * 1, 24, 72 * 1 * 1});
360 b->Args({ 1 * 1, 72, 24 * 1 * 1});
361 b->Args({ 28 * 28, 40, 72 * 1 * 1});
362 /************* Bottleneck 5 ************/
363 /* M N K */
364 b->Args({ 28 * 28, 120, 40 * 1 * 1});
365 b->Args({ 1 * 1, 32, 120 * 1 * 1});
366 b->Args({ 1 * 1, 120, 32 * 1 * 1});
367 b->Args({ 28 * 28, 40, 120 * 1 * 1});
368 /************* Bottleneck 6 ************/
369 /* M N K */
370 //b->Args({ 28 * 28, 120, 40 * 1 * 1});
371 //b->Args({ 1 * 1, 32, 120 * 1 * 1});
372 //b->Args({ 1 * 1, 120, 32 * 1 * 1});
373 //b->Args({ 28 * 28, 40, 120 * 1 * 1});
374 /************* Bottleneck 7 ************/
375 /* M N K */
376 b->Args({ 28 * 28, 240, 40 * 1 * 1});
377 b->Args({ 14 * 14, 80, 240 * 1 * 1});
378 /************* Bottleneck 8 ************/
379 /* M N K */
380 b->Args({ 14 * 14, 200, 80 * 1 * 1});
381 b->Args({ 14 * 14, 80, 200 * 1 * 1});
382 /************* Bottleneck 9 ************/
383 /* M N K */
384 b->Args({ 14 * 14, 184, 80 * 1 * 1});
385 b->Args({ 14 * 14, 80, 184 * 1 * 1});
386 /************ Bottleneck 10 ************/
387 /* M N K */
388 b->Args({ 14 * 14, 184, 80 * 1 * 1});
389 b->Args({ 14 * 14, 80, 184 * 1 * 1});
390 /************ Bottleneck 11 ************/
391 /* M N K */
392 b->Args({ 14 * 14, 480, 80 * 1 * 1});
393 b->Args({ 1 * 1, 120, 480 * 1 * 1});
394 b->Args({ 1 * 1, 480, 120 * 1 * 1});
395 b->Args({ 14 * 14, 112, 480 * 1 * 1});
396 /************ Bottleneck 12 ************/
397 /* M N K */
398 b->Args({ 14 * 14, 672, 112 * 1 * 1});
399 b->Args({ 1 * 1, 168, 672 * 1 * 1});
400 b->Args({ 1 * 1, 672, 168 * 1 * 1});
401 b->Args({ 14 * 14, 112, 672 * 1 * 1});
402 /************ Bottleneck 13 ************/
403 /* M N K */
404 //b->Args({ 14 * 14, 672, 112 * 1 * 1});
405 //b->Args({ 1 * 1, 168, 672 * 1 * 1});
406 //b->Args({ 1 * 1, 672, 168 * 1 * 1});
407 b->Args({ 7 * 7, 160, 672 * 1 * 1});
408 /************ Bottleneck 14 ************/
409 /* M N K */
410 b->Args({ 7 * 7, 960, 160 * 1 * 1});
411 b->Args({ 1 * 1, 240, 960 * 1 * 1});
412 b->Args({ 1 * 1, 960, 240 * 1 * 1});
413 b->Args({ 7 * 7, 160, 960 * 1 * 1});
414 /************ Bottleneck 15 ************/
415 /* M N K */
416 //b->Args({ 7 * 7, 960, 160 * 1 * 1});
417 //b->Args({ 1 * 1, 240, 960 * 1 * 1});
418 //b->Args({ 1 * 1, 960, 240 * 1 * 1});
419 //b->Args({ 7 * 7, 160, 960 * 1 * 1});
420 /************* Last Stage *************/
421 /* M N K */
422 //b->Args({ 7 * 7, 960, 160 * 1 * 1});
423 b->Args({ 1 * 1, 1280, 960 * 1 * 1});
424 b->Args({ 1 * 1, 1001, 1280 * 1 * 1});
425 }
426
427 // SqueezeNet 1.0
SqueezeNetV10GemmArguments(benchmark::internal::Benchmark * b)428 static void SqueezeNetV10GemmArguments(benchmark::internal::Benchmark* b) {
429 b->ArgNames({"M", "N", "K"});
430
431 /************** Conv 1 ***************/
432 /* M N K */
433 b->Args({111 * 111, 96, 3 * 7 * 7});
434 /************** Fire 2 ***************/
435 /* M N K */
436 b->Args({ 55 * 55, 16, 96 * 1 * 1});
437 b->Args({ 55 * 55, 64, 16 * 1 * 1});
438 b->Args({ 55 * 55, 64, 16 * 3 * 3});
439 /************** Fire 3 ***************/
440 /* M N K */
441 b->Args({ 55 * 55, 16, 128 * 1 * 1});
442 /************** Fire 4 ***************/
443 /* M N K */
444 b->Args({ 55 * 55, 32, 128 * 1 * 1});
445 b->Args({ 55 * 55, 128, 32 * 1 * 1});
446 b->Args({ 55 * 55, 128, 32 * 3 * 3});
447 /************** Fire 5 ***************/
448 /* M N K */
449 b->Args({ 27 * 27, 32, 256 * 1 * 1});
450 b->Args({ 27 * 27, 128, 32 * 1 * 1});
451 b->Args({ 27 * 27, 128, 32 * 3 * 3});
452 /************** Fire 6 ***************/
453 /* M N K */
454 b->Args({ 27 * 27, 48, 256 * 1 * 1});
455 b->Args({ 27 * 27, 192, 48 * 1 * 1});
456 b->Args({ 27 * 27, 192, 48 * 3 * 3});
457 /************** Fire 7 ***************/
458 /* M N K */
459 b->Args({ 27 * 27, 48, 384 * 1 * 1});
460 /************** Fire 8 ***************/
461 /* M N K */
462 b->Args({ 27 * 27, 64, 384 * 1 * 1});
463 b->Args({ 27 * 27, 256, 64 * 1 * 1});
464 b->Args({ 27 * 27, 256, 64 * 3 * 3});
465 /************** Fire 9 ***************/
466 /* M N K */
467 b->Args({ 13 * 13, 64, 512 * 1 * 1});
468 b->Args({ 13 * 13, 256, 64 * 1 * 1});
469 b->Args({ 13 * 13, 256, 64 * 3 * 3});
470 /************** Conv 10 **************/
471 /* M N K */
472 b->Args({ 13 * 13, 1000, 512 * 1 * 1});
473 }
474
475 // SqueezeNet 1.1
SqueezeNetV11GemmArguments(benchmark::internal::Benchmark * b)476 static void SqueezeNetV11GemmArguments(benchmark::internal::Benchmark* b) {
477 b->ArgNames({"M", "N", "K"});
478
479 /************** Conv 1 ***************/
480 /* M N K */
481 b->Args({111 * 111, 64, 3 * 3 * 3});
482 /************** Fire 2 ***************/
483 /* M N K */
484 b->Args({ 55 * 55, 16, 64 * 1 * 1});
485 b->Args({ 55 * 55, 64, 16 * 1 * 1});
486 b->Args({ 55 * 55, 64, 16 * 3 * 3});
487 /************** Fire 3 ***************/
488 /* M N K */
489 b->Args({ 55 * 55, 16, 128 * 1 * 1});
490 /************** Fire 4 ***************/
491 /* M N K */
492 b->Args({ 27 * 27, 32, 128 * 1 * 1});
493 b->Args({ 27 * 27, 128, 32 * 1 * 1});
494 b->Args({ 27 * 27, 128, 32 * 3 * 3});
495 /************** Fire 5 ***************/
496 /* M N K */
497 b->Args({ 27 * 27, 32, 256 * 1 * 1});
498 /************** Fire 6 ***************/
499 /* M N K */
500 b->Args({ 13 * 13, 48, 256 * 1 * 1});
501 b->Args({ 13 * 13, 192, 48 * 1 * 1});
502 b->Args({ 13 * 13, 192, 48 * 3 * 3});
503 /************** Fire 7 ***************/
504 /* M N K */
505 b->Args({ 13 * 13, 48, 384 * 1 * 1});
506 /************** Fire 8 ***************/
507 /* M N K */
508 b->Args({ 13 * 13, 64, 384 * 1 * 1});
509 b->Args({ 13 * 13, 256, 64 * 1 * 1});
510 b->Args({ 13 * 13, 256, 64 * 3 * 3});
511 /************** Fire 9 ***************/
512 /* M N K */
513 b->Args({ 13 * 13, 64, 512 * 1 * 1});
514 /************** Conv 10 **************/
515 /* M N K */
516 b->Args({ 13 * 13, 1000, 512 * 1 * 1});
517 }
518
InceptionV3GemmArguments(benchmark::internal::Benchmark * b)519 static void InceptionV3GemmArguments(benchmark::internal::Benchmark* b) {
520 /* M N K */
521 b->Args({150 * 150, 32, 3 * 3 * 3});
522 b->Args({149 * 149, 32, 32 * 3 * 3});
523 b->Args({149 * 149, 64, 32 * 3 * 3});
524 b->Args({ 75 * 75, 80, 64 * 1 * 1});
525 b->Args({ 73 * 73, 192, 80 * 3 * 3});
526 b->Args({ 37 * 37, 64, 192 * 1 * 1});
527 b->Args({ 37 * 37, 48, 192 * 1 * 1});
528 b->Args({ 37 * 37, 64, 48 * 5 * 5});
529 b->Args({ 37 * 37, 96, 64 * 3 * 3});
530 b->Args({ 37 * 37, 96, 96 * 3 * 3});
531 b->Args({ 37 * 37, 32, 192 * 1 * 1});
532 b->Args({ 37 * 37, 64, 256 * 1 * 1});
533 b->Args({ 37 * 37, 48, 256 * 1 * 1});
534 b->Args({ 37 * 37, 64, 288 * 1 * 1});
535 b->Args({ 37 * 37, 48, 288 * 1 * 1});
536 b->Args({ 18 * 18, 384, 288 * 3 * 3});
537 b->Args({ 18 * 18, 96, 96 * 3 * 3});
538 b->Args({ 19 * 19, 192, 768 * 1 * 1});
539 b->Args({ 19 * 19, 128, 768 * 1 * 1});
540 b->Args({ 19 * 19, 128, 128 * 1 * 7});
541 b->Args({ 19 * 19, 192, 128 * 7 * 1});
542 b->Args({ 19 * 19, 128, 128 * 7 * 1});
543 b->Args({ 19 * 19, 192, 128 * 1 * 7});
544 b->Args({ 19 * 19, 160, 768 * 1 * 1});
545 b->Args({ 19 * 19, 160, 160 * 1 * 7});
546 b->Args({ 19 * 19, 192, 160 * 7 * 1});
547 b->Args({ 19 * 19, 160, 160 * 7 * 1});
548 b->Args({ 19 * 19, 192, 160 * 1 * 7});
549 b->Args({ 19 * 19, 192, 192 * 1 * 7});
550 b->Args({ 19 * 19, 192, 192 * 7 * 1});
551 b->Args({ 9 * 9, 320, 192 * 3 * 3});
552 b->Args({ 9 * 9, 192, 192 * 3 * 3});
553 b->Args({ 10 * 10, 320, 1280 * 1 * 1});
554 b->Args({ 10 * 10, 384, 1280 * 1 * 1});
555 b->Args({ 10 * 10, 384, 384 * 1 * 3});
556 b->Args({ 10 * 10, 384, 384 * 3 * 1});
557 b->Args({ 10 * 10, 448, 1280 * 1 * 1});
558 b->Args({ 10 * 10, 384, 448 * 3 * 3});
559 b->Args({ 10 * 10, 192, 1280 * 1 * 1});
560 b->Args({ 10 * 10, 320, 2048 * 1 * 1});
561 b->Args({ 10 * 10, 384, 2048 * 1 * 1});
562 b->Args({ 10 * 10, 448, 2048 * 1 * 1});
563 b->Args({ 10 * 10, 192, 2048 * 1 * 1});
564 b->Args({ 3 * 3, 1001, 2048 * 1 * 1});
565 }
566
ResNet18GemmArguments(benchmark::internal::Benchmark * b)567 static void ResNet18GemmArguments(benchmark::internal::Benchmark* b) {
568 b->ArgNames({"M", "N", "K"});
569
570 /* M N K */
571 b->Args({112 * 112, 64, 3 * 7 * 7});
572 b->Args({ 56 * 56, 64, 64 * 3 * 3});
573 b->Args({ 28 * 28, 128, 64 * 3 * 3});
574 b->Args({ 28 * 28, 128, 128 * 3 * 3});
575 b->Args({ 28 * 28, 128, 64 * 1 * 1});
576 b->Args({ 14 * 14, 256, 128 * 3 * 3});
577 b->Args({ 14 * 14, 256, 256 * 3 * 3});
578 b->Args({ 14 * 14, 256, 128 * 1 * 1});
579 b->Args({ 7 * 7, 512, 256 * 3 * 3});
580 b->Args({ 7 * 7, 512, 512 * 3 * 3});
581 b->Args({ 7 * 7, 512, 256 * 1 * 1});
582 }
583
ResNet50GemmArguments(benchmark::internal::Benchmark * b)584 static void ResNet50GemmArguments(benchmark::internal::Benchmark* b) {
585 b->ArgNames({"M", "N", "K"});
586
587 /*************** Conv 1 ***************/
588 /* M N K */
589 b->Args({112 * 112, 64, 3 * 7 * 7});
590 /************** Conv 2.X **************/
591 /* M N K */
592 b->Args({ 56 * 56, 64, 64 * 1 * 1});
593 b->Args({ 56 * 56, 64, 64 * 3 * 3});
594 b->Args({ 56 * 56, 256, 64 * 1 * 1});
595 b->Args({ 56 * 56, 64, 256 * 1 * 1});
596 /************** Conv 3.X **************/
597 /* M N K */
598 b->Args({ 56 * 56, 128, 256 * 1 * 1});
599 b->Args({ 28 * 28, 128, 128 * 3 * 3});
600 b->Args({ 28 * 28, 512, 128 * 1 * 1});
601 b->Args({ 28 * 28, 512, 256 * 1 * 1});
602 b->Args({ 28 * 28, 128, 512 * 1 * 1});
603 /************** Conv 4.X **************/
604 /* M N K */
605 b->Args({ 28 * 28, 256, 512 * 1 * 1});
606 b->Args({ 14 * 14, 256, 256 * 3 * 3});
607 b->Args({ 14 * 14, 1024, 256 * 1 * 1});
608 b->Args({ 14 * 14, 1024, 512 * 1 * 1});
609 b->Args({ 14 * 14, 256, 1024 * 1 * 1});
610 /************** Conv 5.X **************/
611 /* M N K */
612 b->Args({ 14 * 14, 512, 1024 * 1 * 1});
613 b->Args({ 7 * 7, 512, 512 * 3 * 3});
614 b->Args({ 7 * 7, 2048, 512 * 1 * 1});
615 b->Args({ 7 * 7, 2048, 1024 * 1 * 1});
616 b->Args({ 7 * 7, 512, 2048 * 1 * 1});
617 }
618
VGGGemmArguments(benchmark::internal::Benchmark * b)619 static void VGGGemmArguments(benchmark::internal::Benchmark* b) {
620 b->ArgNames({"M", "N", "K"});
621
622 /************** Conv 1.1 *************/
623 /* M N K */
624 b->Args({224 * 224, 64, 3 * 3 * 3});
625 /************** Conv 1.2 *************/
626 /* M N K */
627 b->Args({224 * 224, 64, 64 * 3 * 3});
628 /************** Conv 2.1 *************/
629 /* M N K */
630 b->Args({112 * 112, 128, 64 * 3 * 3});
631 /************** Conv 2.2 *************/
632 /* M N K */
633 b->Args({112 * 112, 128, 128 * 3 * 3});
634 /************** Conv 3.1 *************/
635 /* M N K */
636 b->Args({ 56 * 56, 256, 128 * 3 * 3});
637 /************** Conv 3.3 *************/
638 /* M N K */
639 b->Args({ 56 * 56, 256, 256 * 1 * 1});
640 /************** Conv 4.1 *************/
641 /* M N K */
642 b->Args({ 28 * 28, 512, 256 * 3 * 3});
643 /************** Conv 4.2 *************/
644 /* M N K */
645 b->Args({ 28 * 28, 512, 512 * 3 * 3});
646 /************** Conv 4.3 *************/
647 /* M N K */
648 b->Args({ 28 * 28, 512, 512 * 1 * 1});
649 /************** Conv 5.X *************/
650 /* M N K */
651 b->Args({ 14 * 14, 512, 512 * 3 * 3});
652 /************** Conv 5.3 *************/
653 /* M N K */
654 b->Args({ 14 * 14, 512, 512 * 1 * 1});
655 }
656
657 // SRCNN (9-1-5)
SRCNN915GemmArguments(benchmark::internal::Benchmark * b)658 static void SRCNN915GemmArguments(benchmark::internal::Benchmark* b) {
659 b->ArgNames({"M", "N", "K"});
660
661 /* M N K */
662 b->Args({376 * 376, 64, 1 * 9 * 9});
663 b->Args({376 * 376, 32, 64 * 1 * 1});
664 b->Args({372 * 372, 1, 32 * 5 * 5});
665 }
666
667 // SRCNN (9-3-5)
SRCNN935GemmArguments(benchmark::internal::Benchmark * b)668 static void SRCNN935GemmArguments(benchmark::internal::Benchmark* b) {
669 b->ArgNames({"M", "N", "K"});
670
671 /* M N K */
672 b->Args({376 * 376, 64, 1 * 9 * 9});
673 b->Args({374 * 374, 32, 64 * 3 * 3});
674 b->Args({370 * 370, 1, 32 * 5 * 5});
675 }
676
677 // SRCNN (9-5-5)
SRCNN955GemmArguments(benchmark::internal::Benchmark * b)678 static void SRCNN955GemmArguments(benchmark::internal::Benchmark* b) {
679 b->ArgNames({"M", "N", "K"});
680
681 /* M N K */
682 b->Args({376 * 376, 64, 1 * 9 * 9});
683 b->Args({372 * 372, 32, 64 * 5 * 5});
684 b->Args({368 * 368, 1, 32 * 5 * 5});
685 }
686