• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 
#include <unistd.h>
#ifdef __APPLE__
#include <sys/time.h>
#endif

#include <algorithm>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <ctime>
#include <iomanip>
#include <iostream>
#include <map>
#include <vector>

#include "../eight_bit_int_gemm/eight_bit_int_gemm.h"
#include "test.h"

#if defined(__arm__) && !defined(GEMMLOWP_NEON)
#warning "Building without NEON support on ARM, check your compiler setup!"
#endif
34 
// Returns a timestamp in seconds, for measuring elapsed time.
//
// Only the difference between two calls is meaningful; the epoch is
// unspecified. A monotonic clock (std::chrono::steady_clock) is used instead
// of the wall clock (CLOCK_REALTIME / gettimeofday). The wall clock can be
// stepped forwards or backwards, for example by NTP, in the middle of a timed
// region, which would corrupt the measured latencies.
double time() {
  const std::chrono::steady_clock::duration since_epoch =
      std::chrono::steady_clock::now().time_since_epoch();
  return std::chrono::duration_cast<std::chrono::duration<double>>(since_epoch)
      .count();
}
46 
// Lower bound, in bytes, on the combined size of all working sets that
// Shape::init allocates for a single shape (2 MiB).
constexpr std::int32_t MIN_WORKING_SET_SIZE = 2 << 20;
// Lower bound on n * m * k accumulated over the repetitions that Shape::init
// schedules for one timed batch (one billion).
constexpr double MIN_OPS = 1e9;
49 
// One set of GEMM input/output buffers for a single problem shape.
//
// The buffers are raw `new[]` allocations that nothing in this file frees;
// they stay allocated until the process exits. Copying a WorkingSet copies
// the pointers, so copies share the same buffers.
struct WorkingSet {
  WorkingSet() : lhs(nullptr), rhs(nullptr), result(nullptr) {}

  // Allocates the buffers: lhs holds n * k bytes, rhs holds k * m bytes and
  // result holds m * n bytes.
  //
  // The buffers are zero-filled. The GEMM reads lhs and rhs, and reading
  // never-initialized memory is undefined behaviour (and is flagged by
  // MemorySanitizer).
  void init(std::int32_t n, std::int32_t m, std::int32_t k) {
    // Widen before multiplying so that large dimensions cannot overflow int32.
    const std::size_t rows = static_cast<std::size_t>(n);
    const std::size_t cols = static_cast<std::size_t>(m);
    const std::size_t depth = static_cast<std::size_t>(k);
    // The trailing "()" value-initializes, i.e. zeroes, every element.
    lhs = new std::uint8_t[rows * depth]();
    rhs = new std::uint8_t[depth * cols]();
    result = new std::uint8_t[cols * rows]();
  }

  std::uint8_t* lhs;
  std::uint8_t* rhs;
  std::uint8_t* result;
};
63 
64 struct Shape {
65   std::int32_t n;
66   std::int32_t m;
67   std::int32_t k;
68 
69   std::int32_t repetitions;
70   std::int32_t current_set;
71   std::vector<WorkingSet> working_sets;
72 
ShapeShape73   Shape(std::int32_t n, std::int32_t m, std::int32_t k)
74       : n(n), m(m), k(k), repetitions(1), current_set(0), working_sets() {}
75 
initShape76   void init() {
77     const std::int32_t size = n * k + k * m + n * m;
78     const std::int32_t count = MIN_WORKING_SET_SIZE / size + 1;
79     const double ops = static_cast<double>(n) * static_cast<double>(m) *
80                        static_cast<double>(k);
81     for (int i = 0; i < count; ++i) {
82       working_sets.push_back(WorkingSet());
83       working_sets.back().init(n, m, k);
84     }
85     current_set = 0;
86     repetitions = MIN_OPS / ops + 20;
87   }
88 
working_setShape89   WorkingSet& working_set() { return working_sets[current_set]; }
90 
next_working_setShape91   void next_working_set() {
92     current_set = (current_set + 1) % working_sets.size();
93   }
94 };
95 
run_gemm(std::int32_t n,std::int32_t m,std::int32_t k,std::uint8_t * lhs,std::uint8_t * rhs,std::uint8_t * result)96 double run_gemm(std::int32_t n, std::int32_t m, std::int32_t k,
97                 std::uint8_t* lhs, std::uint8_t* rhs, std::uint8_t* result) {
98   gemmlowp::eight_bit_int_gemm::EightBitIntGemm(
99       true, false, false, m, n, k, rhs, -100, k, lhs, -100, k, result, 10000,
100       10, 3, m, gemmlowp::eight_bit_int_gemm::BitDepthSetting::A8B8);
101   return static_cast<double>(n * m * k * 2);
102 }
103 
run_gemms(std::vector<Shape> * shapes)104 double run_gemms(std::vector<Shape>* shapes) {
105   double ops = 0.0;
106   for (auto& shape : *shapes) {
107     ops += run_gemm(shape.n, shape.m, shape.k, shape.working_set().lhs,
108                     shape.working_set().rhs, shape.working_set().result);
109   }
110   return ops;
111 }
112 
// Sorts *times ascending in place and prints latency statistics.
//
// When `full` is true, prints a multi-line report in seconds:
//   - best and worst sample;
//   - mean;
//   - mean after dropping the trim_ratio fraction from each end;
//   - mean of the fastest best_ratio fraction (at least one sample).
// Otherwise prints only the mean, multiplied by 1000 (milliseconds if the
// samples are in seconds), on a single line.
//
// An empty vector prints a "no samples" note instead of dereferencing
// front()/back() of an empty vector or dividing 0 by 0.
void print_summary(std::vector<double>* times, bool full) {
  if (times->empty()) {
    if (full) {
      std::cout << "Graph latency: no iterations recorded." << std::endl;
    } else {
      std::cout << "n/a" << std::endl;
    }
    return;
  }

  std::sort(times->begin(), times->end());

  const float trim_ratio = 0.25;
  const float best_ratio = 0.1;
  const std::size_t size = times->size();
  const std::size_t count_trimmed = size * trim_ratio;
  // With fewer than 10 samples, size * best_ratio truncates to 0. That would
  // make the "best" mean 0 / 0 (NaN), so always keep the fastest sample.
  std::size_t count_best = size * best_ratio;
  if (count_best == 0) {
    count_best = 1;
  }

  double sum_times = 0.0;
  double sum_times_trimmed = 0.0;
  double sum_times_best = 0.0;
  for (std::size_t i = 0; i < size; ++i) {
    const double sample = (*times)[i];
    sum_times += sample;
    if (i >= count_trimmed && i < size - count_trimmed) {
      sum_times_trimmed += sample;
    }
    if (i < count_best) {
      sum_times_best += sample;
    }
  }
  // count_trimmed <= size / 4, so at least half of the samples remain.
  const std::size_t count_times_trimmed = size - 2 * count_trimmed;

  const double min_latency = times->front();
  const double max_latency = times->back();
  const double mean_latency = sum_times / size;
  const double trimmed_mean_latency = sum_times_trimmed / count_times_trimmed;
  const double best_mean_latency = sum_times_best / count_best;

  if (full) {
    std::cout << "Graph latency (over " << size
              << " iterations):" << std::endl;
    std::cout << "  Best:             " << min_latency << "s" << std::endl;
    std::cout << "  Worst:            " << max_latency << "s" << std::endl;
    std::cout << "  Mean:             " << mean_latency << "s" << std::endl;
    std::cout << "  " << 100 * trim_ratio
              << "% trimmed mean: " << trimmed_mean_latency << "s" << std::endl;
    std::cout << "  Mean of " << 100 * best_ratio
              << "% best: " << best_mean_latency << "s" << std::endl;
  } else {
    std::cout << (mean_latency * 1000.0) << std::endl;
  }
}
158 
time_all(std::vector<Shape> * shapes,std::int32_t repetitions,double max_time)159 void time_all(std::vector<Shape>* shapes, std::int32_t repetitions,
160               double max_time) {
161   std::vector<double> times;
162   double ops = 0.0;
163   double sum_time = 0.0;
164 
165   while (sum_time < max_time) {
166     double start = time();
167 
168     for (int i = 0; i < repetitions; ++i) {
169       ops += run_gemms(shapes);
170     }
171     double delta_time = (time() - start);
172     times.push_back(delta_time / repetitions);
173     sum_time += delta_time;
174   }
175 
176   print_summary(&times, true);
177 }
178 
time_one(Shape * shape,double max_time)179 void time_one(Shape* shape, double max_time) {
180   std::vector<double> times;
181   double ops = 0.0;
182   double sum_time = 0.0;
183 
184   std::cout << std::setprecision(6) << std::fixed << shape->n << ", "
185             << shape->m << ", " << shape->k << ", " << std::flush;
186 
187   while (sum_time < max_time) {
188     double start = time();
189 
190     for (int i = 0; i < shape->repetitions; ++i) {
191       ops += run_gemm(shape->n, shape->m, shape->k, shape->working_set().lhs,
192                       shape->working_set().rhs, shape->working_set().result);
193       shape->next_working_set();
194     }
195     double delta_time = (time() - start);
196     times.push_back(delta_time / shape->repetitions);
197     sum_time += delta_time;
198   }
199 
200   print_summary(&times, false);
201 }
202 
main()203 int main() {
204   std::vector<Shape> googlenet_gemms;
205   googlenet_gemms.push_back(Shape(12544, 64, 147));
206   googlenet_gemms.push_back(Shape(3136, 64, 64));
207   googlenet_gemms.push_back(Shape(3136, 192, 576));
208   googlenet_gemms.push_back(Shape(784, 64, 192));
209   googlenet_gemms.push_back(Shape(784, 96, 192));
210   googlenet_gemms.push_back(Shape(784, 128, 864));
211   googlenet_gemms.push_back(Shape(784, 16, 192));
212   googlenet_gemms.push_back(Shape(784, 32, 400));
213   googlenet_gemms.push_back(Shape(784, 32, 192));
214   googlenet_gemms.push_back(Shape(784, 128, 256));
215   googlenet_gemms.push_back(Shape(784, 128, 256));
216   googlenet_gemms.push_back(Shape(784, 192, 1152));
217   googlenet_gemms.push_back(Shape(784, 32, 256));
218   googlenet_gemms.push_back(Shape(784, 96, 800));
219   googlenet_gemms.push_back(Shape(784, 64, 256));
220   googlenet_gemms.push_back(Shape(196, 192, 480));
221   googlenet_gemms.push_back(Shape(196, 96, 480));
222   googlenet_gemms.push_back(Shape(196, 204, 864));
223   googlenet_gemms.push_back(Shape(196, 16, 480));
224   googlenet_gemms.push_back(Shape(196, 48, 400));
225   googlenet_gemms.push_back(Shape(196, 64, 480));
226   googlenet_gemms.push_back(Shape(196, 160, 508));
227   googlenet_gemms.push_back(Shape(196, 112, 508));
228   googlenet_gemms.push_back(Shape(196, 224, 1008));
229   googlenet_gemms.push_back(Shape(196, 24, 508));
230   googlenet_gemms.push_back(Shape(196, 64, 600));
231   googlenet_gemms.push_back(Shape(196, 64, 508));
232   googlenet_gemms.push_back(Shape(196, 128, 512));
233   googlenet_gemms.push_back(Shape(196, 128, 512));
234   googlenet_gemms.push_back(Shape(196, 256, 1152));
235   googlenet_gemms.push_back(Shape(196, 24, 512));
236   googlenet_gemms.push_back(Shape(196, 64, 600));
237   googlenet_gemms.push_back(Shape(196, 64, 512));
238   googlenet_gemms.push_back(Shape(196, 112, 512));
239   googlenet_gemms.push_back(Shape(196, 144, 512));
240   googlenet_gemms.push_back(Shape(196, 288, 1296));
241   googlenet_gemms.push_back(Shape(196, 32, 512));
242   googlenet_gemms.push_back(Shape(196, 64, 800));
243   googlenet_gemms.push_back(Shape(196, 64, 512));
244   googlenet_gemms.push_back(Shape(196, 256, 528));
245   googlenet_gemms.push_back(Shape(196, 160, 528));
246   googlenet_gemms.push_back(Shape(196, 320, 1440));
247   googlenet_gemms.push_back(Shape(196, 32, 528));
248   googlenet_gemms.push_back(Shape(196, 128, 800));
249   googlenet_gemms.push_back(Shape(196, 128, 528));
250   googlenet_gemms.push_back(Shape(49, 256, 832));
251   googlenet_gemms.push_back(Shape(49, 160, 832));
252   googlenet_gemms.push_back(Shape(49, 320, 1440));
253   googlenet_gemms.push_back(Shape(49, 48, 832));
254   googlenet_gemms.push_back(Shape(49, 128, 1200));
255   googlenet_gemms.push_back(Shape(49, 128, 832));
256   googlenet_gemms.push_back(Shape(49, 384, 832));
257   googlenet_gemms.push_back(Shape(49, 192, 832));
258   googlenet_gemms.push_back(Shape(49, 384, 1728));
259   googlenet_gemms.push_back(Shape(49, 48, 832));
260   googlenet_gemms.push_back(Shape(49, 128, 1200));
261   googlenet_gemms.push_back(Shape(49, 128, 832));
262   googlenet_gemms.push_back(Shape(16, 128, 508));
263   googlenet_gemms.push_back(Shape(1, 1024, 2048));
264   googlenet_gemms.push_back(Shape(1, 1008, 1024));
265   googlenet_gemms.push_back(Shape(16, 128, 528));
266   googlenet_gemms.push_back(Shape(1, 1024, 2048));
267   googlenet_gemms.push_back(Shape(1, 1008, 1024));
268   googlenet_gemms.push_back(Shape(1, 1008, 1024));
269 
270   for (auto& shape : googlenet_gemms) {
271     shape.init();
272   }
273 
274   std::vector<Shape> small_gemms;
275   small_gemms.push_back(Shape(29232, 16, 25));
276   small_gemms.push_back(Shape(7308, 6, 400));
277   small_gemms.push_back(Shape(203, 3002, 216));
278 
279   for (auto& shape : small_gemms) {
280     shape.init();
281   }
282 
283   std::vector<Shape> others;
284   others.push_back(Shape(100, 100, 100));
285   others.push_back(Shape(1000, 1000, 1000));
286   others.push_back(Shape(2000, 1000, 1000));
287 
288   for (auto& shape : others) {
289     shape.init();
290   }
291 
292   std::vector<Shape> lstm;
293   lstm.push_back(Shape(1, 500, 320));
294   lstm.push_back(Shape(1, 100, 500));
295   lstm.push_back(Shape(1, 500, 500));
296   lstm.push_back(Shape(1, 500, 100));
297   lstm.push_back(Shape(1, 2000, 100));
298 
299   for (auto& shape : lstm) {
300     shape.init();
301   }
302 
303   gemmlowp::eight_bit_int_gemm::SetMaxNumThreads(4);
304 
305   std::cout << "Warmup run." << std::endl;
306   time_all(&googlenet_gemms, 10, 1.0);
307   time_all(&small_gemms, 50, 1.0);
308 
309   std::cout << "Timing all." << std::endl;
310   time_all(&googlenet_gemms, 10, 10.0);
311   time_all(&small_gemms, 50, 10.0);
312 
313   std::cout << "Timing separate." << std::endl;
314 
315   for (auto& shape : googlenet_gemms) {
316     time_one(&shape, 0.10);
317   }
318 
319   for (auto& shape : small_gemms) {
320     time_one(&shape, 0.10);
321   }
322 
323   for (auto& shape : others) {
324     time_one(&shape, 0.10);
325   }
326 
327   for (auto& shape : lstm) {
328     time_one(&shape, 0.10);
329   }
330 
331   return 0;
332 }
333