• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8 
9 #pragma once
10 
11 #include <benchmark/benchmark.h>
12 
13 #define BENCHMARK_GEMM(gemm_fn) \
14   BENCHMARK_CAPTURE(gemm_fn, mobilenet_v1, "MobileNet v1")->Apply(MobileNetV1GemmArguments)->UseRealTime(); \
15   BENCHMARK_CAPTURE(gemm_fn, mobilenet_v2, "MobileNet v2")->Apply(MobileNetV2GemmArguments)->UseRealTime(); \
16   BENCHMARK_CAPTURE(gemm_fn, mobilenet_v3_small, "MobileNet v3 Small")->Apply(MobileNetV3SmallGemmArguments)->UseRealTime(); \
17   BENCHMARK_CAPTURE(gemm_fn, mobilenet_v3_large, "MobileNet v3 Large")->Apply(MobileNetV3LargeGemmArguments)->UseRealTime(); \
18   BENCHMARK_CAPTURE(gemm_fn, shufflenet_v1_g1, "ShuffleNet v1 (1 group)")->Apply(ShuffleNetV1G1GemmArguments)->UseRealTime(); \
19   BENCHMARK_CAPTURE(gemm_fn, shufflenet_v1_g2, "ShuffleNet v1 (2 groups)")->Apply(ShuffleNetV1G2GemmArguments)->UseRealTime(); \
20   BENCHMARK_CAPTURE(gemm_fn, shufflenet_v1_g3, "ShuffleNet v1 (3 groups)")->Apply(ShuffleNetV1G3GemmArguments)->UseRealTime(); \
21   BENCHMARK_CAPTURE(gemm_fn, shufflenet_v1_g4, "ShuffleNet v1 (4 groups)")->Apply(ShuffleNetV1G4GemmArguments)->UseRealTime(); \
22   BENCHMARK_CAPTURE(gemm_fn, shufflenet_v1_g8, "ShuffleNet v1 (8 groups)")->Apply(ShuffleNetV1G8GemmArguments)->UseRealTime(); \
23   BENCHMARK_CAPTURE(gemm_fn, shufflenet_v2_x05, "ShuffleNet v2 0.5X")->Apply(ShuffleNetV2X05GemmArguments)->UseRealTime(); \
24   BENCHMARK_CAPTURE(gemm_fn, shufflenet_v2_x10, "ShuffleNet v2 1.0X")->Apply(ShuffleNetV2X10GemmArguments)->UseRealTime(); \
25   BENCHMARK_CAPTURE(gemm_fn, shufflenet_v2_x15, "ShuffleNet v2 1.5X")->Apply(ShuffleNetV2X15GemmArguments)->UseRealTime(); \
26   BENCHMARK_CAPTURE(gemm_fn, shufflenet_v2_x20, "ShuffleNet v2 2.0X")->Apply(ShuffleNetV2X20GemmArguments)->UseRealTime(); \
27   BENCHMARK_CAPTURE(gemm_fn, inception_v3, "Inception v3")->Apply(InceptionV3GemmArguments)->UseRealTime(); \
28   BENCHMARK_CAPTURE(gemm_fn, resnet18, "ResNet-18")->Apply(ResNet18GemmArguments)->UseRealTime(); \
29   BENCHMARK_CAPTURE(gemm_fn, resnet50, "ResNet-50")->Apply(ResNet50GemmArguments)->UseRealTime(); \
30   BENCHMARK_CAPTURE(gemm_fn, squeezenet_v10, "SqueezeNet 1.0")->Apply(SqueezeNetV10GemmArguments)->UseRealTime(); \
31   BENCHMARK_CAPTURE(gemm_fn, squeezenet_v11, "SqueezeNet 1.1")->Apply(SqueezeNetV11GemmArguments)->UseRealTime(); \
32   BENCHMARK_CAPTURE(gemm_fn, vgg, "VGG")->Apply(VGGGemmArguments)->UseRealTime(); \
33   BENCHMARK_CAPTURE(gemm_fn, srcnn915, "SRCNN (9-1-5)")->Apply(SRCNN915GemmArguments)->UseRealTime(); \
34   BENCHMARK_CAPTURE(gemm_fn, srcnn935, "SRCNN (9-3-5)")->Apply(SRCNN935GemmArguments)->UseRealTime();
35 
36 // Removed due to OOM SEGFAULT on 32 bit ARM.
37 //  BENCHMARK_CAPTURE(gemm_fn, srcnn955, "SRCNN (9-5-5)")->Apply(SRCNN955GemmArguments)->UseRealTime();
38 
39 
40 // ShuffleNet v1 with 1 group.
ShuffleNetV1G1GemmArguments(benchmark::internal::Benchmark * b)41 static void ShuffleNetV1G1GemmArguments(benchmark::internal::Benchmark* b) {
42   b->ArgNames({"M", "N", "K"});
43 
44   /*           M       N         K    */
45   b->Args({112 * 112,  24,   3 * 3 * 3});
46   b->Args({ 56 *  56,  36,  24 * 1 * 1});
47   b->Args({ 28 *  28, 120,  36 * 1 * 1});
48   b->Args({ 28 *  28,  36, 144 * 1 * 1});
49   b->Args({ 28 *  28, 144,  36 * 1 * 1});
50   b->Args({ 28 *  28,  72, 144 * 1 * 1});
51   b->Args({ 14 *  14, 144,  72 * 1 * 1});
52   b->Args({ 14 *  14,  72, 288 * 1 * 1});
53   b->Args({ 14 *  14, 288,  72 * 1 * 1});
54   b->Args({ 14 *  14, 144, 288 * 1 * 1});
55   b->Args({  7 *   7, 288, 144 * 1 * 1});
56   b->Args({  7 *   7, 144, 576 * 1 * 1});
57   b->Args({  7 *   7, 576, 144 * 1 * 1});
58 }
59 
60 // ShuffleNet v1 with 2 groups.
ShuffleNetV1G2GemmArguments(benchmark::internal::Benchmark * b)61 static void ShuffleNetV1G2GemmArguments(benchmark::internal::Benchmark* b) {
62   b->ArgNames({"M", "N", "K"});
63 
64   /*           M       N         K    */
65   b->Args({112 * 112,  24,   3 * 3 * 3});
66   b->Args({ 56 *  56,  50,  24 * 1 * 1});
67   b->Args({ 28 *  28,  88,  25 * 1 * 1});
68   b->Args({ 28 *  28,  25, 100 * 1 * 1});
69   b->Args({ 28 *  28, 100,  25 * 1 * 1});
70   b->Args({ 28 *  28,  50, 100 * 1 * 1});
71   b->Args({ 14 *  14, 100,  50 * 1 * 1});
72   b->Args({ 14 *  14,  50, 200 * 1 * 1});
73   b->Args({ 14 *  14, 200,  50 * 1 * 1});
74   b->Args({ 14 *  14, 100, 200 * 1 * 1});
75   b->Args({  7 *   7, 200, 100 * 1 * 1});
76   b->Args({  7 *   7, 100, 400 * 1 * 1});
77   b->Args({  7 *   7, 400, 100 * 1 * 1});
78 }
79 
80 // ShuffleNet v1 with 3 groups.
ShuffleNetV1G3GemmArguments(benchmark::internal::Benchmark * b)81 static void ShuffleNetV1G3GemmArguments(benchmark::internal::Benchmark* b) {
82   b->ArgNames({"M", "N", "K"});
83 
84   /*           M       N         K    */
85   b->Args({112 * 112,  24,   3 * 3 * 3});
86   b->Args({ 56 *  56,  60,  24 * 1 * 1});
87   b->Args({ 28 *  28,  72,  20 * 1 * 1});
88   b->Args({ 28 *  28,  20,  80 * 1 * 1});
89   b->Args({ 28 *  28,  80,  20 * 1 * 1});
90   b->Args({ 28 *  28,  40,  80 * 1 * 1});
91   b->Args({ 14 *  14,  80,  40 * 1 * 1});
92   b->Args({ 14 *  14,  40, 160 * 1 * 1});
93   b->Args({ 14 *  14, 160,  40 * 1 * 1});
94   b->Args({ 14 *  14,  80, 160 * 1 * 1});
95   b->Args({  7 *   7, 160,  80 * 1 * 1});
96   b->Args({  7 *   7,  80, 320 * 1 * 1});
97   b->Args({  7 *   7, 320,  80 * 1 * 1});
98 }
99 
100 // ShuffleNet v1 with 4 groups.
ShuffleNetV1G4GemmArguments(benchmark::internal::Benchmark * b)101 static void ShuffleNetV1G4GemmArguments(benchmark::internal::Benchmark* b) {
102   b->ArgNames({"M", "N", "K"});
103 
104   /*           M       N         K    */
105   b->Args({112 * 112,  24,   3 * 3 * 3});
106   b->Args({ 56 *  56,  68,  24 * 1 * 1});
107   b->Args({ 28 *  28,  62,  17 * 1 * 1});
108   b->Args({ 28 *  28,  17,  68 * 1 * 1});
109   b->Args({ 28 *  28,  68,  17 * 1 * 1});
110   b->Args({ 28 *  28,  34,  68 * 1 * 1});
111   b->Args({ 14 *  14,  68,  34 * 1 * 1});
112   b->Args({ 14 *  14,  34, 136 * 1 * 1});
113   b->Args({ 14 *  14, 136,  34 * 1 * 1});
114   b->Args({ 14 *  14,  68, 136 * 1 * 1});
115   b->Args({  7 *   7, 136,  68 * 1 * 1});
116   b->Args({  7 *   7,  68, 272 * 1 * 1});
117   b->Args({  7 *   7, 272,  68 * 1 * 1});
118 }
119 
120 // ShuffleNet v1 with 8 groups.
ShuffleNetV1G8GemmArguments(benchmark::internal::Benchmark * b)121 static void ShuffleNetV1G8GemmArguments(benchmark::internal::Benchmark* b) {
122   b->ArgNames({"M", "N", "K"});
123 
124   /*           M       N         K    */
125   b->Args({112 * 112,  24,   3 * 3 * 3});
126   b->Args({ 56 *  56,  96,  24 * 1 * 1});
127   b->Args({ 28 *  28,  45,  12 * 1 * 1});
128   b->Args({ 28 *  28,  12,  48 * 1 * 1});
129   b->Args({ 28 *  28,  48,  12 * 1 * 1});
130   b->Args({ 28 *  28,  24,  48 * 1 * 1});
131   b->Args({ 14 *  14,  48,  24 * 1 * 1});
132   b->Args({ 14 *  14,  24,  96 * 1 * 1});
133   b->Args({ 14 *  14,  96,  24 * 1 * 1});
134   b->Args({ 14 *  14,  48,  96 * 1 * 1});
135   b->Args({  7 *   7,  96,  48 * 1 * 1});
136   b->Args({  7 *   7,  48, 192 * 1 * 1});
137   b->Args({  7 *   7, 192,  48 * 1 * 1});
138 }
139 
140 // ShuffleNet v2 (0.5X scale)
ShuffleNetV2X05GemmArguments(benchmark::internal::Benchmark * b)141 static void ShuffleNetV2X05GemmArguments(benchmark::internal::Benchmark* b) {
142   b->ArgNames({"M", "N", "K"});
143 
144   /*           M        N         K    */
145   b->Args({112 * 112,   24,   3 * 3 * 3});
146   b->Args({ 56 *  56,   24,  24 * 1 * 1});
147   b->Args({ 28 *  28,   24,  24 * 1 * 1});
148   b->Args({ 28 *  28,   48,  48 * 1 * 1});
149   b->Args({ 14 *  14,   48,  48 * 1 * 1});
150   b->Args({ 14 *  14,   96,  96 * 1 * 1});
151   b->Args({  7 *   7,   96,  96 * 1 * 1});
152   b->Args({  7 *   7, 1024, 192 * 1 * 1});
153 }
154 
155 // ShuffleNet v2 (1.0X scale)
ShuffleNetV2X10GemmArguments(benchmark::internal::Benchmark * b)156 static void ShuffleNetV2X10GemmArguments(benchmark::internal::Benchmark* b) {
157   b->ArgNames({"M", "N", "K"});
158 
159   /*           M        N         K    */
160   b->Args({112 * 112,   24,   3 * 3 * 3});
161   b->Args({ 56 *  56,   58,  24 * 1 * 1});
162   b->Args({ 28 *  28,   58,  24 * 1 * 1});
163   b->Args({ 28 *  28,   58,  58 * 1 * 1});
164   b->Args({ 14 *  14,  116, 116 * 1 * 1});
165   b->Args({ 14 *  14,  116, 116 * 1 * 1});
166   b->Args({ 14 *  14,  232, 232 * 1 * 1});
167   b->Args({  7 *   7,  232, 232 * 1 * 1});
168   b->Args({  7 *   7, 1024, 464 * 1 * 1});
169 }
170 
171 // ShuffleNet v2 (1.5X scale)
ShuffleNetV2X15GemmArguments(benchmark::internal::Benchmark * b)172 static void ShuffleNetV2X15GemmArguments(benchmark::internal::Benchmark* b) {
173   b->ArgNames({"M", "N", "K"});
174 
175   /*           M        N         K    */
176   b->Args({112 * 112,   24,   3 * 3 * 3});
177   b->Args({ 56 *  56,   88,  24 * 1 * 1});
178   b->Args({ 28 *  28,   88,  24 * 1 * 1});
179   b->Args({ 28 *  28,   88,  88 * 1 * 1});
180   b->Args({ 28 *  28,  176, 176 * 1 * 1});
181   b->Args({ 14 *  14,  176, 176 * 1 * 1});
182   b->Args({ 14 *  14,  352, 352 * 1 * 1});
183   b->Args({  7 *   7,  352, 352 * 1 * 1});
184   b->Args({  7 *   7, 1024, 704 * 1 * 1});
185 }
186 
187 // ShuffleNet v2 (2.0X scale)
ShuffleNetV2X20GemmArguments(benchmark::internal::Benchmark * b)188 static void ShuffleNetV2X20GemmArguments(benchmark::internal::Benchmark* b) {
189   b->ArgNames({"M", "N", "K"});
190 
191   /*           M        N         K    */
192   b->Args({112 * 112,   24,   3 * 3 * 3});
193   b->Args({ 56 *  56,  122,  24 * 1 * 1});
194   b->Args({ 28 *  28,  122,  24 * 1 * 1});
195   b->Args({ 28 *  28,  122, 122 * 1 * 1});
196   b->Args({ 28 *  28,  244, 244 * 1 * 1});
197   b->Args({ 14 *  14,  244, 244 * 1 * 1});
198   b->Args({ 14 *  14,  488, 488 * 1 * 1});
199   b->Args({  7 *   7,  488, 488 * 1 * 1});
200   b->Args({  7 *   7, 2048, 976 * 1 * 1});
201 }
202 
MobileNetV1GemmArguments(benchmark::internal::Benchmark * b)203 static void MobileNetV1GemmArguments(benchmark::internal::Benchmark* b) {
204   b->ArgNames({"M", "N", "K"});
205 
206   /*           M        N          K    */
207   b->Args({112 * 112,   32,    3 * 3 * 3});
208   b->Args({112 * 112,   64,   32 * 1 * 1});
209   b->Args({ 56 *  56,  128,   64 * 1 * 1});
210   b->Args({ 56 *  56,  128,  128 * 1 * 1});
211   b->Args({ 28 *  28,  256,  128 * 1 * 1});
212   b->Args({ 28 *  28,  256,  256 * 1 * 1});
213   b->Args({ 14 *  14,  512,  256 * 1 * 1});
214   b->Args({ 14 *  14,  512,  512 * 1 * 1});
215   b->Args({  7 *   7, 1024,  512 * 1 * 1});
216   b->Args({  7 *   7, 1024, 1024 * 1 * 1});
217 }
218 
MobileNetV2GemmArguments(benchmark::internal::Benchmark * b)219 static void MobileNetV2GemmArguments(benchmark::internal::Benchmark* b) {
220   b->ArgNames({"M", "N", "K"});
221 
222   /*********** Initial Stage ************/
223   /*           M        N          K    */
224   b->Args({112 * 112,   32,    3 * 3 * 3});
225   /************ Bottleneck 1 ************/
226   /*           M        N          K    */
227   b->Args({112 * 112,   16,   32 * 1 * 1});
228   /************ Bottleneck 2 ************/
229   /*           M        N          K    */
230   b->Args({112 * 112,   96,   16 * 1 * 1});
231   b->Args({ 56 *  56,   24,   96 * 1 * 1});
232   b->Args({ 56 *  56,  144,   24 * 1 * 1});
233   b->Args({ 56 *  56,   24,  144 * 1 * 1});
234   /************ Bottleneck 3 ************/
235   /*           M        N          K    */
236   b->Args({ 28 *  28,   32,  144 * 1 * 1});
237   b->Args({ 28 *  28,  192,   32 * 1 * 1});
238   b->Args({ 28 *  28,   32,  192 * 1 * 1});
239   /************ Bottleneck 4 ************/
240   /*           M        N          K    */
241   b->Args({ 14 *  14,   64,  192 * 1 * 1});
242   b->Args({ 14 *  14,  384,   64 * 1 * 1});
243   b->Args({ 14 *  14,   64,  384 * 1 * 1});
244   /************ Bottleneck 5 ************/
245   /*           M        N          K    */
246   b->Args({ 14 *  14,   96,  384 * 1 * 1});
247   b->Args({ 14 *  14,  576,   96 * 1 * 1});
248   b->Args({ 14 *  14,   96,  576 * 1 * 1});
249   /************ Bottleneck 6 ************/
250   /*           M        N          K    */
251   b->Args({  7 *   7,  160,  576 * 1 * 1});
252   b->Args({  7 *   7,  960,  160 * 1 * 1});
253   b->Args({  7 *   7,  160,  960 * 1 * 1});
254   /************ Bottleneck 7 ************/
255   /*           M        N          K    */
256   b->Args({  7 *   7,  320,  960 * 1 * 1});
257   /********* Pre-pooling Conv2D *********/
258   /*           M        N          K    */
259   b->Args({  7 *   7, 1280,  320 * 1 * 1});
260   /******** Post-pooling Conv2D *********/
261   /*           M        N          K    */
262   b->Args({  1 *   1, 1000, 1280 * 1 * 1});
263 }
264 
MobileNetV3SmallGemmArguments(benchmark::internal::Benchmark * b)265 static void MobileNetV3SmallGemmArguments(benchmark::internal::Benchmark* b) {
266   b->ArgNames({"M", "N", "K"});
267 
268   /************ Initial Stage ************/
269   /*           M        N          K     */
270   b->Args({112 * 112,   16,    3 * 3 * 3});
271   /************* Bottleneck 1 ************/
272   /*           M        N          K     */
273   b->Args({  1 *   1,    8,   16 * 1 * 1});
274   b->Args({  1 *   1,   16,    8 * 1 * 1});
275   b->Args({ 56 *  56,   16,   16 * 1 * 1});
276   /************* Bottleneck 2 ************/
277   /*           M        N          K     */
278   b->Args({ 56 *  56,   72,   16 * 1 * 1});
279   b->Args({ 28 *  28,   24,   72 * 1 * 1});
280   /************* Bottleneck 3 ************/
281   /*           M        N          K     */
282   b->Args({ 28 *  28,   88,   24 * 1 * 1});
283   b->Args({ 28 *  28,   24,   88 * 1 * 1});
284   /************* Bottleneck 4 ************/
285   /*           M        N          K     */
286   b->Args({ 28 *  28,   96,   24 * 1 * 1});
287   b->Args({  1 *   1,   24,   96 * 1 * 1});
288   b->Args({  1 *   1,   96,   24 * 1 * 1});
289   b->Args({ 14 *  14,   40,   96 * 1 * 1});
290   /************* Bottleneck 5 ************/
291   /*           M        N          K     */
292   b->Args({ 14 *  14,  240,   40 * 1 * 1});
293   b->Args({  1 *   1,   64,  240 * 1 * 1});
294   b->Args({  1 *   1,  240,   64 * 1 * 1});
295   b->Args({ 14 *  14,   40,  240 * 1 * 1});
296   /************* Bottleneck 6 ************/
297   /*           M        N          K     */
298 //b->Args({ 14 *  14,  240,   40 * 1 * 1});
299 //b->Args({  1 *   1,   64,  240 * 1 * 1});
300 //b->Args({  1 *   1,  240,   64 * 1 * 1});
301 //b->Args({ 14 *  14,   40,  240 * 1 * 1});
302   /************* Bottleneck 7 ************/
303   /*           M        N          K     */
304   b->Args({ 14 *  14,  120,   40 * 1 * 1});
305   b->Args({  1 *   1,   32,  120 * 1 * 1});
306   b->Args({  1 *   1,  120,   32 * 1 * 1});
307   b->Args({ 14 *  14,   48,  120 * 1 * 1});
308   /************* Bottleneck 8 ************/
309   /*           M        N          K     */
310   b->Args({ 14 *  14,  144,   48 * 1 * 1});
311   b->Args({  1 *   1,   40,  144 * 1 * 1});
312   b->Args({  1 *   1,  144,   40 * 1 * 1});
313   b->Args({ 14 *  14,   48,  144 * 1 * 1});
314   /************* Bottleneck 9 ************/
315   /*           M        N          K     */
316   b->Args({ 14 *  14,  288,   48 * 1 * 1});
317   b->Args({  1 *   1,   72,  288 * 1 * 1});
318   b->Args({  1 *   1,  288,   72 * 1 * 1});
319   b->Args({  7 *   7,   96,  288 * 1 * 1});
320   /************ Bottleneck 10 ************/
321   /*           M        N          K     */
322   b->Args({  7 *   7,  576,   96 * 1 * 1});
323   b->Args({  1 *   1,  144,  576 * 1 * 1});
324   b->Args({  1 *   1,  576,  144 * 1 * 1});
325   b->Args({  7 *   7,   96,  576 * 1 * 1});
326   /************ Bottleneck 11 ************/
327   /*           M        N          K     */
328 //b->Args({  7 *   7,  576,   96 * 1 * 1});
329 //b->Args({  1 *   1,  144,  576 * 1 * 1});
330 //b->Args({  1 *   1,  576,  144 * 1 * 1});
331 //b->Args({  7 *   7,   96,  576 * 1 * 1});
332   /************* Last Stage  *************/
333   /*           M        N          K     */
334 //b->Args({  7 *   7,  576,   96 * 1 * 1});
335   b->Args({  1 *   1, 1024,  576 * 1 * 1});
336   b->Args({  1 *   1, 1001, 1024 * 1 * 1});
337 }
338 
MobileNetV3LargeGemmArguments(benchmark::internal::Benchmark * b)339 static void MobileNetV3LargeGemmArguments(benchmark::internal::Benchmark* b) {
340   b->ArgNames({"M", "N", "K"});
341 
342   /************ Initial Stage ************/
343   /*           M        N          K     */
344   b->Args({112 * 112,   16,    3 * 3 * 3});
345   /************* Bottleneck 1 ************/
346   /*           M        N          K     */
347   b->Args({112 * 112,   16,   16 * 1 * 1});
348   /************* Bottleneck 2 ************/
349   /*           M        N          K     */
350   b->Args({112 * 112,   64,   16 * 1 * 1});
351   b->Args({ 56 *  56,   24,   64 * 1 * 1});
352   /************* Bottleneck 3 ************/
353   /*           M        N          K     */
354   b->Args({ 56 *  56,   72,   24 * 1 * 1});
355   b->Args({ 56 *  56,   24,   72 * 1 * 1});
356   /************* Bottleneck 4 ************/
357   /*           M        N          K     */
358 //b->Args({ 56 *  56,   72,   24 * 1 * 1});
359   b->Args({  1 *   1,   24,   72 * 1 * 1});
360   b->Args({  1 *   1,   72,   24 * 1 * 1});
361   b->Args({ 28 *  28,   40,   72 * 1 * 1});
362   /************* Bottleneck 5 ************/
363   /*           M        N          K     */
364   b->Args({ 28 *  28,  120,   40 * 1 * 1});
365   b->Args({  1 *   1,   32,  120 * 1 * 1});
366   b->Args({  1 *   1,  120,   32 * 1 * 1});
367   b->Args({ 28 *  28,   40,  120 * 1 * 1});
368   /************* Bottleneck 6 ************/
369   /*           M        N          K     */
370 //b->Args({ 28 *  28,  120,   40 * 1 * 1});
371 //b->Args({  1 *   1,   32,  120 * 1 * 1});
372 //b->Args({  1 *   1,  120,   32 * 1 * 1});
373 //b->Args({ 28 *  28,   40,  120 * 1 * 1});
374   /************* Bottleneck 7 ************/
375   /*           M        N          K     */
376   b->Args({ 28 *  28,  240,   40 * 1 * 1});
377   b->Args({ 14 *  14,   80,  240 * 1 * 1});
378   /************* Bottleneck 8 ************/
379   /*           M        N          K     */
380   b->Args({ 14 *  14,  200,   80 * 1 * 1});
381   b->Args({ 14 *  14,   80,  200 * 1 * 1});
382   /************* Bottleneck 9 ************/
383   /*           M        N          K     */
384   b->Args({ 14 *  14,  184,   80 * 1 * 1});
385   b->Args({ 14 *  14,   80,  184 * 1 * 1});
386   /************ Bottleneck 10 ************/
387   /*           M        N          K     */
388   b->Args({ 14 *  14,  184,   80 * 1 * 1});
389   b->Args({ 14 *  14,   80,  184 * 1 * 1});
390   /************ Bottleneck 11 ************/
391   /*           M        N          K     */
392   b->Args({ 14 *  14,  480,   80 * 1 * 1});
393   b->Args({  1 *   1,  120,  480 * 1 * 1});
394   b->Args({  1 *   1,  480,  120 * 1 * 1});
395   b->Args({ 14 *  14,  112,  480 * 1 * 1});
396   /************ Bottleneck 12 ************/
397   /*           M        N          K     */
398   b->Args({ 14 *  14,  672,  112 * 1 * 1});
399   b->Args({  1 *   1,  168,  672 * 1 * 1});
400   b->Args({  1 *   1,  672,  168 * 1 * 1});
401   b->Args({ 14 *  14,  112,  672 * 1 * 1});
402   /************ Bottleneck 13 ************/
403   /*           M        N          K     */
404 //b->Args({ 14 *  14,  672,  112 * 1 * 1});
405 //b->Args({  1 *   1,  168,  672 * 1 * 1});
406 //b->Args({  1 *   1,  672,  168 * 1 * 1});
407   b->Args({  7 *   7,  160,  672 * 1 * 1});
408   /************ Bottleneck 14 ************/
409   /*           M        N          K     */
410   b->Args({  7 *   7,  960,  160 * 1 * 1});
411   b->Args({  1 *   1,  240,  960 * 1 * 1});
412   b->Args({  1 *   1,  960,  240 * 1 * 1});
413   b->Args({  7 *   7,  160,  960 * 1 * 1});
414   /************ Bottleneck 15 ************/
415   /*           M        N          K     */
416 //b->Args({  7 *   7,  960,  160 * 1 * 1});
417 //b->Args({  1 *   1,  240,  960 * 1 * 1});
418 //b->Args({  1 *   1,  960,  240 * 1 * 1});
419 //b->Args({  7 *   7,  160,  960 * 1 * 1});
420   /************* Last Stage  *************/
421   /*           M        N          K     */
422 //b->Args({  7 *   7,  960,  160 * 1 * 1});
423   b->Args({  1 *   1, 1280,  960 * 1 * 1});
424   b->Args({  1 *   1, 1001, 1280 * 1 * 1});
425 }
426 
427 // SqueezeNet 1.0
SqueezeNetV10GemmArguments(benchmark::internal::Benchmark * b)428 static void SqueezeNetV10GemmArguments(benchmark::internal::Benchmark* b) {
429   b->ArgNames({"M", "N", "K"});
430 
431   /************** Conv 1 ***************/
432   /*           M        N         K    */
433   b->Args({111 * 111,   96,   3 * 7 * 7});
434   /************** Fire 2 ***************/
435   /*           M        N         K    */
436   b->Args({ 55 *  55,   16,  96 * 1 * 1});
437   b->Args({ 55 *  55,   64,  16 * 1 * 1});
438   b->Args({ 55 *  55,   64,  16 * 3 * 3});
439   /************** Fire 3 ***************/
440   /*           M        N         K    */
441   b->Args({ 55 *  55,   16, 128 * 1 * 1});
442   /************** Fire 4 ***************/
443   /*           M        N         K    */
444   b->Args({ 55 *  55,   32, 128 * 1 * 1});
445   b->Args({ 55 *  55,  128,  32 * 1 * 1});
446   b->Args({ 55 *  55,  128,  32 * 3 * 3});
447   /************** Fire 5 ***************/
448   /*           M        N         K    */
449   b->Args({ 27 *  27,   32, 256 * 1 * 1});
450   b->Args({ 27 *  27,  128,  32 * 1 * 1});
451   b->Args({ 27 *  27,  128,  32 * 3 * 3});
452   /************** Fire 6 ***************/
453   /*           M        N         K    */
454   b->Args({ 27 *  27,   48, 256 * 1 * 1});
455   b->Args({ 27 *  27,  192,  48 * 1 * 1});
456   b->Args({ 27 *  27,  192,  48 * 3 * 3});
457   /************** Fire 7 ***************/
458   /*           M        N         K    */
459   b->Args({ 27 *  27,   48, 384 * 1 * 1});
460   /************** Fire 8 ***************/
461   /*           M        N         K    */
462   b->Args({ 27 *  27,   64, 384 * 1 * 1});
463   b->Args({ 27 *  27,  256,  64 * 1 * 1});
464   b->Args({ 27 *  27,  256,  64 * 3 * 3});
465   /************** Fire 9 ***************/
466   /*           M        N         K    */
467   b->Args({ 13 *  13,   64, 512 * 1 * 1});
468   b->Args({ 13 *  13,  256,  64 * 1 * 1});
469   b->Args({ 13 *  13,  256,  64 * 3 * 3});
470   /************** Conv 10 **************/
471   /*           M        N         K    */
472   b->Args({ 13 *  13, 1000, 512 * 1 * 1});
473 }
474 
475 // SqueezeNet 1.1
SqueezeNetV11GemmArguments(benchmark::internal::Benchmark * b)476 static void SqueezeNetV11GemmArguments(benchmark::internal::Benchmark* b) {
477   b->ArgNames({"M", "N", "K"});
478 
479   /************** Conv 1 ***************/
480   /*           M        N         K    */
481   b->Args({111 * 111,   64,   3 * 3 * 3});
482   /************** Fire 2 ***************/
483   /*           M        N         K    */
484   b->Args({ 55 *  55,   16,  64 * 1 * 1});
485   b->Args({ 55 *  55,   64,  16 * 1 * 1});
486   b->Args({ 55 *  55,   64,  16 * 3 * 3});
487   /************** Fire 3 ***************/
488   /*           M        N         K    */
489   b->Args({ 55 *  55,   16, 128 * 1 * 1});
490   /************** Fire 4 ***************/
491   /*           M        N         K    */
492   b->Args({ 27 *  27,   32, 128 * 1 * 1});
493   b->Args({ 27 *  27,  128,  32 * 1 * 1});
494   b->Args({ 27 *  27,  128,  32 * 3 * 3});
495   /************** Fire 5 ***************/
496   /*           M        N         K    */
497   b->Args({ 27 *  27,   32, 256 * 1 * 1});
498   /************** Fire 6 ***************/
499   /*           M        N         K    */
500   b->Args({ 13 *  13,   48, 256 * 1 * 1});
501   b->Args({ 13 *  13,  192,  48 * 1 * 1});
502   b->Args({ 13 *  13,  192,  48 * 3 * 3});
503   /************** Fire 7 ***************/
504   /*           M        N         K    */
505   b->Args({ 13 *  13,   48, 384 * 1 * 1});
506   /************** Fire 8 ***************/
507   /*           M        N         K    */
508   b->Args({ 13 *  13,   64, 384 * 1 * 1});
509   b->Args({ 13 *  13,  256,  64 * 1 * 1});
510   b->Args({ 13 *  13,  256,  64 * 3 * 3});
511   /************** Fire 9 ***************/
512   /*           M        N         K    */
513   b->Args({ 13 *  13,   64, 512 * 1 * 1});
514   /************** Conv 10 **************/
515   /*           M        N         K    */
516   b->Args({ 13 *  13, 1000, 512 * 1 * 1});
517 }
518 
InceptionV3GemmArguments(benchmark::internal::Benchmark * b)519 static void InceptionV3GemmArguments(benchmark::internal::Benchmark* b) {
520   /*           M        N          K   */
521   b->Args({150 * 150,   32,    3 * 3 * 3});
522   b->Args({149 * 149,   32,   32 * 3 * 3});
523   b->Args({149 * 149,   64,   32 * 3 * 3});
524   b->Args({ 75 *  75,   80,   64 * 1 * 1});
525   b->Args({ 73 *  73,  192,   80 * 3 * 3});
526   b->Args({ 37 *  37,   64,  192 * 1 * 1});
527   b->Args({ 37 *  37,   48,  192 * 1 * 1});
528   b->Args({ 37 *  37,   64,   48 * 5 * 5});
529   b->Args({ 37 *  37,   96,   64 * 3 * 3});
530   b->Args({ 37 *  37,   96,   96 * 3 * 3});
531   b->Args({ 37 *  37,   32,  192 * 1 * 1});
532   b->Args({ 37 *  37,   64,  256 * 1 * 1});
533   b->Args({ 37 *  37,   48,  256 * 1 * 1});
534   b->Args({ 37 *  37,   64,  288 * 1 * 1});
535   b->Args({ 37 *  37,   48,  288 * 1 * 1});
536   b->Args({ 18 *  18,  384,  288 * 3 * 3});
537   b->Args({ 18 *  18,   96,   96 * 3 * 3});
538   b->Args({ 19 *  19,  192,  768 * 1 * 1});
539   b->Args({ 19 *  19,  128,  768 * 1 * 1});
540   b->Args({ 19 *  19,  128,  128 * 1 * 7});
541   b->Args({ 19 *  19,  192,  128 * 7 * 1});
542   b->Args({ 19 *  19,  128,  128 * 7 * 1});
543   b->Args({ 19 *  19,  192,  128 * 1 * 7});
544   b->Args({ 19 *  19,  160,  768 * 1 * 1});
545   b->Args({ 19 *  19,  160,  160 * 1 * 7});
546   b->Args({ 19 *  19,  192,  160 * 7 * 1});
547   b->Args({ 19 *  19,  160,  160 * 7 * 1});
548   b->Args({ 19 *  19,  192,  160 * 1 * 7});
549   b->Args({ 19 *  19,  192,  192 * 1 * 7});
550   b->Args({ 19 *  19,  192,  192 * 7 * 1});
551   b->Args({  9 *   9,  320,  192 * 3 * 3});
552   b->Args({  9 *   9,  192,  192 * 3 * 3});
553   b->Args({ 10 *  10,  320, 1280 * 1 * 1});
554   b->Args({ 10 *  10,  384, 1280 * 1 * 1});
555   b->Args({ 10 *  10,  384,  384 * 1 * 3});
556   b->Args({ 10 *  10,  384,  384 * 3 * 1});
557   b->Args({ 10 *  10,  448, 1280 * 1 * 1});
558   b->Args({ 10 *  10,  384,  448 * 3 * 3});
559   b->Args({ 10 *  10,  192, 1280 * 1 * 1});
560   b->Args({ 10 *  10,  320, 2048 * 1 * 1});
561   b->Args({ 10 *  10,  384, 2048 * 1 * 1});
562   b->Args({ 10 *  10,  448, 2048 * 1 * 1});
563   b->Args({ 10 *  10,  192, 2048 * 1 * 1});
564   b->Args({  3 *   3, 1001, 2048 * 1 * 1});
565 }
566 
ResNet18GemmArguments(benchmark::internal::Benchmark * b)567 static void ResNet18GemmArguments(benchmark::internal::Benchmark* b) {
568   b->ArgNames({"M", "N", "K"});
569 
570   /*           M       N         K    */
571   b->Args({112 * 112,  64,   3 * 7 * 7});
572   b->Args({ 56 *  56,  64,  64 * 3 * 3});
573   b->Args({ 28 *  28, 128,  64 * 3 * 3});
574   b->Args({ 28 *  28, 128, 128 * 3 * 3});
575   b->Args({ 28 *  28, 128,  64 * 1 * 1});
576   b->Args({ 14 *  14, 256, 128 * 3 * 3});
577   b->Args({ 14 *  14, 256, 256 * 3 * 3});
578   b->Args({ 14 *  14, 256, 128 * 1 * 1});
579   b->Args({  7 *   7, 512, 256 * 3 * 3});
580   b->Args({  7 *   7, 512, 512 * 3 * 3});
581   b->Args({  7 *   7, 512, 256 * 1 * 1});
582 }
583 
ResNet50GemmArguments(benchmark::internal::Benchmark * b)584 static void ResNet50GemmArguments(benchmark::internal::Benchmark* b) {
585   b->ArgNames({"M", "N", "K"});
586 
587   /*************** Conv 1 ***************/
588   /*           M        N          K    */
589   b->Args({112 * 112,   64,    3 * 7 * 7});
590   /************** Conv 2.X **************/
591   /*           M        N          K    */
592   b->Args({ 56 *  56,   64,   64 * 1 * 1});
593   b->Args({ 56 *  56,   64,   64 * 3 * 3});
594   b->Args({ 56 *  56,  256,   64 * 1 * 1});
595   b->Args({ 56 *  56,   64,  256 * 1 * 1});
596   /************** Conv 3.X **************/
597   /*           M        N          K    */
598   b->Args({ 56 *  56,  128,  256 * 1 * 1});
599   b->Args({ 28 *  28,  128,  128 * 3 * 3});
600   b->Args({ 28 *  28,  512,  128 * 1 * 1});
601   b->Args({ 28 *  28,  512,  256 * 1 * 1});
602   b->Args({ 28 *  28,  128,  512 * 1 * 1});
603   /************** Conv 4.X **************/
604   /*           M        N          K    */
605   b->Args({ 28 *  28,  256,  512 * 1 * 1});
606   b->Args({ 14 *  14,  256,  256 * 3 * 3});
607   b->Args({ 14 *  14, 1024,  256 * 1 * 1});
608   b->Args({ 14 *  14, 1024,  512 * 1 * 1});
609   b->Args({ 14 *  14,  256, 1024 * 1 * 1});
610   /************** Conv 5.X **************/
611   /*           M        N          K    */
612   b->Args({ 14 *  14,  512, 1024 * 1 * 1});
613   b->Args({  7 *   7,  512,  512 * 3 * 3});
614   b->Args({  7 *   7, 2048,  512 * 1 * 1});
615   b->Args({  7 *   7, 2048, 1024 * 1 * 1});
616   b->Args({  7 *   7,  512, 2048 * 1 * 1});
617 }
618 
VGGGemmArguments(benchmark::internal::Benchmark * b)619 static void VGGGemmArguments(benchmark::internal::Benchmark* b) {
620   b->ArgNames({"M", "N", "K"});
621 
622   /************** Conv 1.1 *************/
623   /*           M        N        K     */
624   b->Args({224 * 224,  64,   3 * 3 * 3});
625   /************** Conv 1.2 *************/
626   /*           M        N        K     */
627   b->Args({224 * 224,  64,  64 * 3 * 3});
628   /************** Conv 2.1 *************/
629   /*           M        N        K     */
630   b->Args({112 * 112, 128,  64 * 3 * 3});
631   /************** Conv 2.2 *************/
632   /*           M        N        K     */
633   b->Args({112 * 112, 128, 128 * 3 * 3});
634   /************** Conv 3.1 *************/
635   /*           M        N        K     */
636   b->Args({ 56 *  56, 256, 128 * 3 * 3});
637   /************** Conv 3.3 *************/
638   /*           M        N        K     */
639   b->Args({ 56 *  56, 256, 256 * 1 * 1});
640   /************** Conv 4.1 *************/
641   /*           M        N        K     */
642   b->Args({ 28 *  28, 512, 256 * 3 * 3});
643   /************** Conv 4.2 *************/
644   /*           M        N        K     */
645   b->Args({ 28 *  28, 512, 512 * 3 * 3});
646   /************** Conv 4.3 *************/
647   /*           M        N        K     */
648   b->Args({ 28 *  28, 512, 512 * 1 * 1});
649   /************** Conv 5.X *************/
650   /*           M        N        K     */
651   b->Args({ 14 *  14, 512, 512 * 3 * 3});
652   /************** Conv 5.3 *************/
653   /*           M        N        K     */
654   b->Args({ 14 *  14, 512, 512 * 1 * 1});
655 }
656 
657 // SRCNN (9-1-5)
SRCNN915GemmArguments(benchmark::internal::Benchmark * b)658 static void SRCNN915GemmArguments(benchmark::internal::Benchmark* b) {
659   b->ArgNames({"M", "N", "K"});
660 
661   /*           M       N       K    */
662   b->Args({376 * 376, 64,  1 * 9 * 9});
663   b->Args({376 * 376, 32, 64 * 1 * 1});
664   b->Args({372 * 372,  1, 32 * 5 * 5});
665 }
666 
667 // SRCNN (9-3-5)
SRCNN935GemmArguments(benchmark::internal::Benchmark * b)668 static void SRCNN935GemmArguments(benchmark::internal::Benchmark* b) {
669   b->ArgNames({"M", "N", "K"});
670 
671   /*           M       N       K    */
672   b->Args({376 * 376, 64,  1 * 9 * 9});
673   b->Args({374 * 374, 32, 64 * 3 * 3});
674   b->Args({370 * 370,  1, 32 * 5 * 5});
675 }
676 
677 // SRCNN (9-5-5)
SRCNN955GemmArguments(benchmark::internal::Benchmark * b)678 static void SRCNN955GemmArguments(benchmark::internal::Benchmark* b) {
679   b->ArgNames({"M", "N", "K"});
680 
681   /*           M       N       K    */
682   b->Args({376 * 376, 64,  1 * 9 * 9});
683   b->Args({372 * 372, 32, 64 * 5 * 5});
684   b->Args({368 * 368,  1, 32 * 5 * 5});
685 }
686