• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
17 
18 #include <algorithm>
19 #include <map>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/strings/ascii.h"
25 
26 namespace tflite {
27 namespace gpu {
28 namespace {
29 
GetGpuVendor(const std::string & gpu_description)30 GpuVendor GetGpuVendor(const std::string& gpu_description) {
31   const std::map<std::string, GpuVendor> kMapping = {
32       {"adreno", GpuVendor::kQualcomm},
33       {"apple", GpuVendor::kApple},
34       {"qualcomm", GpuVendor::kQualcomm},
35       {"mali", GpuVendor::kMali},
36       {"powervr", GpuVendor::kPowerVR},
37       {"advanced micro devices", GpuVendor::kAMD},
38       {"intel", GpuVendor::kIntel},
39       {"nvidia", GpuVendor::kNvidia},
40       {"amd", GpuVendor::kAMD},
41       {"radeon", GpuVendor::kAMD},
42       {"power", GpuVendor::kPowerVR},
43   };
44   for (const auto& v : kMapping) {
45     if (gpu_description.find(v.first) != std::string::npos) {
46       return v.second;
47     }
48   }
49   return GpuVendor::kUnknown;
50 }
51 
GetAdrenoGpuVersion(const std::string & gpu_description)52 AdrenoGpu GetAdrenoGpuVersion(const std::string& gpu_description) {
53   const std::map<std::string, AdrenoGpu> kMapping = {
54       // Adreno 7xx series
55       {"730", AdrenoGpu::kAdreno730},
56       // Adreno 6xx series
57       {"685", AdrenoGpu::kAdreno685},
58       {"680", AdrenoGpu::kAdreno680},
59       {"675", AdrenoGpu::kAdreno675},
60       {"660", AdrenoGpu::kAdreno660},
61       {"650", AdrenoGpu::kAdreno650},
62       {"640", AdrenoGpu::kAdreno640},
63       {"630", AdrenoGpu::kAdreno630},
64       {"620", AdrenoGpu::kAdreno620},
65       {"618", AdrenoGpu::kAdreno618},
66       {"616", AdrenoGpu::kAdreno616},
67       {"615", AdrenoGpu::kAdreno615},
68       {"612", AdrenoGpu::kAdreno612},
69       {"610", AdrenoGpu::kAdreno610},
70       {"605", AdrenoGpu::kAdreno605},
71       // Adreno 5xx series
72       {"540", AdrenoGpu::kAdreno540},
73       {"530", AdrenoGpu::kAdreno530},
74       {"512", AdrenoGpu::kAdreno512},
75       {"510", AdrenoGpu::kAdreno510},
76       {"509", AdrenoGpu::kAdreno509},
77       {"508", AdrenoGpu::kAdreno508},
78       {"506", AdrenoGpu::kAdreno506},
79       {"505", AdrenoGpu::kAdreno505},
80       {"504", AdrenoGpu::kAdreno504},
81       // Adreno 4xx series
82       {"430", AdrenoGpu::kAdreno430},
83       {"420", AdrenoGpu::kAdreno420},
84       {"418", AdrenoGpu::kAdreno418},
85       {"405", AdrenoGpu::kAdreno405},
86       // Adreno 3xx series
87       {"330", AdrenoGpu::kAdreno330},
88       {"320", AdrenoGpu::kAdreno320},
89       {"308", AdrenoGpu::kAdreno308},
90       {"306", AdrenoGpu::kAdreno306},
91       {"305", AdrenoGpu::kAdreno305},
92       {"304", AdrenoGpu::kAdreno304},
93       // Adreno 2xx series
94       {"225", AdrenoGpu::kAdreno225},
95       {"220", AdrenoGpu::kAdreno220},
96       {"205", AdrenoGpu::kAdreno205},
97       {"203", AdrenoGpu::kAdreno203},
98       {"200", AdrenoGpu::kAdreno200},
99       // Adreno 1xx series
100       {"130", AdrenoGpu::kAdreno130},
101       {"120", AdrenoGpu::kAdreno120},
102   };
103 
104   for (const auto& v : kMapping) {
105     if (gpu_description.find(v.first) != std::string::npos) {
106       return v.second;
107     }
108   }
109   return AdrenoGpu::kUnknown;
110 }
111 
GetMaliGpuVersion(const std::string & gpu_description)112 MaliGpu GetMaliGpuVersion(const std::string& gpu_description) {
113   // Order must be preserved
114   const std::vector<std::pair<std::string, MaliGpu>> kMapping = {
115       {"t604", MaliGpu::kT604}, {"t622", MaliGpu::kT622},
116       {"t624", MaliGpu::kT624}, {"t628", MaliGpu::kT628},
117       {"t658", MaliGpu::kT658}, {"t678", MaliGpu::kT678},
118       {"t720", MaliGpu::kT720}, {"t760", MaliGpu::kT760},
119       {"t820", MaliGpu::kT820}, {"t830", MaliGpu::kT830},
120       {"t860", MaliGpu::kT860}, {"t880", MaliGpu::kT880},
121       {"g310", MaliGpu::kG310}, {"g31", MaliGpu::kG31},
122       {"g510", MaliGpu::kG510}, {"g51", MaliGpu::kG51},
123       {"g52", MaliGpu::kG52},   {"g57", MaliGpu::kG57},
124       {"g610", MaliGpu::kG610}, {"g68", MaliGpu::kG68},
125       {"g710", MaliGpu::kG710}, {"g71", MaliGpu::kG71},
126       {"g72", MaliGpu::kG72},   {"g76", MaliGpu::kG76},
127       {"g77", MaliGpu::kG77},   {"g78", MaliGpu::kG78},
128   };
129   for (const auto& v : kMapping) {
130     if (gpu_description.find(v.first) != std::string::npos) {
131       return v.second;
132     }
133   }
134   return MaliGpu::kUnknown;
135 }
136 
137 }  // namespace
138 
AdrenoInfo(const std::string & device_version)139 AdrenoInfo::AdrenoInfo(const std::string& device_version)
140     : adreno_gpu(GetAdrenoGpuVersion(device_version)) {}
141 
IsAdreno1xx() const142 bool AdrenoInfo::IsAdreno1xx() const {
143   return adreno_gpu == AdrenoGpu::kAdreno120 ||
144          adreno_gpu == AdrenoGpu::kAdreno130;
145 }
146 
IsAdreno2xx() const147 bool AdrenoInfo::IsAdreno2xx() const {
148   return adreno_gpu == AdrenoGpu::kAdreno200 ||
149          adreno_gpu == AdrenoGpu::kAdreno203 ||
150          adreno_gpu == AdrenoGpu::kAdreno205 ||
151          adreno_gpu == AdrenoGpu::kAdreno220 ||
152          adreno_gpu == AdrenoGpu::kAdreno225;
153 }
154 
IsAdreno3xx() const155 bool AdrenoInfo::IsAdreno3xx() const {
156   return adreno_gpu == AdrenoGpu::kAdreno304 ||
157          adreno_gpu == AdrenoGpu::kAdreno305 ||
158          adreno_gpu == AdrenoGpu::kAdreno306 ||
159          adreno_gpu == AdrenoGpu::kAdreno308 ||
160          adreno_gpu == AdrenoGpu::kAdreno320 ||
161          adreno_gpu == AdrenoGpu::kAdreno330;
162 }
163 
IsAdreno4xx() const164 bool AdrenoInfo::IsAdreno4xx() const {
165   return adreno_gpu == AdrenoGpu::kAdreno405 ||
166          adreno_gpu == AdrenoGpu::kAdreno418 ||
167          adreno_gpu == AdrenoGpu::kAdreno420 ||
168          adreno_gpu == AdrenoGpu::kAdreno430;
169 }
170 
IsAdreno5xx() const171 bool AdrenoInfo::IsAdreno5xx() const {
172   return adreno_gpu == AdrenoGpu::kAdreno504 ||
173          adreno_gpu == AdrenoGpu::kAdreno505 ||
174          adreno_gpu == AdrenoGpu::kAdreno506 ||
175          adreno_gpu == AdrenoGpu::kAdreno508 ||
176          adreno_gpu == AdrenoGpu::kAdreno509 ||
177          adreno_gpu == AdrenoGpu::kAdreno510 ||
178          adreno_gpu == AdrenoGpu::kAdreno512 ||
179          adreno_gpu == AdrenoGpu::kAdreno530 ||
180          adreno_gpu == AdrenoGpu::kAdreno540;
181 }
182 
IsAdreno6xx() const183 bool AdrenoInfo::IsAdreno6xx() const {
184   return adreno_gpu == AdrenoGpu::kAdreno605 ||
185          adreno_gpu == AdrenoGpu::kAdreno610 ||
186          adreno_gpu == AdrenoGpu::kAdreno612 ||
187          adreno_gpu == AdrenoGpu::kAdreno615 ||
188          adreno_gpu == AdrenoGpu::kAdreno616 ||
189          adreno_gpu == AdrenoGpu::kAdreno618 ||
190          adreno_gpu == AdrenoGpu::kAdreno620 ||
191          adreno_gpu == AdrenoGpu::kAdreno630 ||
192          adreno_gpu == AdrenoGpu::kAdreno640 ||
193          adreno_gpu == AdrenoGpu::kAdreno650 ||
194          adreno_gpu == AdrenoGpu::kAdreno660 ||
195          adreno_gpu == AdrenoGpu::kAdreno675 ||
196          adreno_gpu == AdrenoGpu::kAdreno680 ||
197          adreno_gpu == AdrenoGpu::kAdreno685;
198 }
199 
IsAdreno7xx() const200 bool AdrenoInfo::IsAdreno7xx() const {
201   return adreno_gpu == AdrenoGpu::kAdreno730;
202 }
203 
IsAdreno6xxOrHigher() const204 bool AdrenoInfo::IsAdreno6xxOrHigher() const {
205   return (!compiler_bugs_in_a6xx && IsAdreno6xx()) || IsAdreno7xx();
206 }
207 
GetMaximumWavesCount() const208 int AdrenoInfo::GetMaximumWavesCount() const {
209   if (IsAdreno7xx()) {
210     return 16;
211   } else if (IsAdreno6xx()) {
212     if (adreno_gpu == AdrenoGpu::kAdreno640) {
213       return 30;
214     } else {
215       return 16;
216     }
217   } else {
218     // all other versions not supported
219     return 1;
220   }
221 }
222 
GetRegisterMemorySizePerComputeUnit() const223 int AdrenoInfo::GetRegisterMemorySizePerComputeUnit() const {
224   if (IsAdreno7xx()) {
225     return 128 * 96 * 16;
226   } else if (IsAdreno6xx()) {
227     if (adreno_gpu == AdrenoGpu::kAdreno640) {
228       return 128 * 144 * 16;
229     } else if (adreno_gpu == AdrenoGpu::kAdreno620 ||
230                adreno_gpu == AdrenoGpu::kAdreno650 ||
231                adreno_gpu == AdrenoGpu::kAdreno660) {
232       return 128 * 64 * 16;
233     } else {
234       return 128 * 96 * 16;
235     }
236   } else {
237     // all other versions not supported
238     return 1;
239   }
240 }
241 
GetMaximumWavesCount(int register_footprint_per_tread,bool full_wave) const242 int AdrenoInfo::GetMaximumWavesCount(int register_footprint_per_tread,
243                                      bool full_wave) const {
244   const int register_usage_per_wave =
245       GetWaveSize(full_wave) * register_footprint_per_tread;
246   const int possible_waves_count =
247       GetRegisterMemorySizePerComputeUnit() / register_usage_per_wave;
248   return std::min(possible_waves_count, GetMaximumWavesCount());
249 }
250 
GetWaveSize(bool full_wave) const251 int AdrenoInfo::GetWaveSize(bool full_wave) const {
252   if (IsAdreno7xx()) {
253     return full_wave ? 128 : 64;
254   } else if (IsAdreno6xx()) {
255     return full_wave ? 128 : 64;
256   } else if (IsAdreno5xx() || IsAdreno4xx()) {
257     return full_wave ? 64 : 32;
258   } else {
259     return full_wave ? 32 : 16;
260   }
261 }
262 
GetComputeUnitsCount() const263 int AdrenoInfo::GetComputeUnitsCount() const {
264   // can provide not correct numbers.
265   switch (adreno_gpu) {
266     // Adreno 7xx series
267     case AdrenoGpu::kAdreno730:
268       return 4;
269     // Adreno 6xx series
270     case AdrenoGpu::kAdreno685:
271       return 4;
272     case AdrenoGpu::kAdreno680:
273       return 4;
274     case AdrenoGpu::kAdreno675:
275       return 4;
276     case AdrenoGpu::kAdreno660:
277       return 3;
278     case AdrenoGpu::kAdreno650:
279       return 3;
280     case AdrenoGpu::kAdreno640:
281       return 2;
282     case AdrenoGpu::kAdreno630:
283       return 2;
284     case AdrenoGpu::kAdreno620:
285       return 1;
286     case AdrenoGpu::kAdreno618:
287       return 1;
288     case AdrenoGpu::kAdreno616:
289       return 1;
290     case AdrenoGpu::kAdreno615:
291       return 1;
292     case AdrenoGpu::kAdreno612:
293       return 1;
294     case AdrenoGpu::kAdreno610:
295       return 1;
296     case AdrenoGpu::kAdreno605:
297       return 1;
298     // Adreno 5xx series
299     case AdrenoGpu::kAdreno540:
300       return 4;
301     case AdrenoGpu::kAdreno530:
302       return 4;
303     case AdrenoGpu::kAdreno512:
304       return 2;
305     case AdrenoGpu::kAdreno510:
306       return 2;
307     case AdrenoGpu::kAdreno509:
308       return 2;
309     case AdrenoGpu::kAdreno508:
310       return 1;
311     case AdrenoGpu::kAdreno506:
312       return 1;
313     case AdrenoGpu::kAdreno505:
314       return 1;
315     case AdrenoGpu::kAdreno504:
316       return 1;
317     // Adreno 4xx series
318     case AdrenoGpu::kAdreno430:
319       return 4;
320     case AdrenoGpu::kAdreno420:
321       return 4;
322     case AdrenoGpu::kAdreno418:
323       return 2;
324     case AdrenoGpu::kAdreno405:
325       return 1;
326     // Adreno 3xx series
327     case AdrenoGpu::kAdreno330:
328       return 4;
329     case AdrenoGpu::kAdreno320:
330       return 2;
331     case AdrenoGpu::kAdreno308:
332       return 1;
333     case AdrenoGpu::kAdreno306:
334       return 1;
335     case AdrenoGpu::kAdreno305:
336       return 1;
337     case AdrenoGpu::kAdreno304:
338       return 1;
339     default:
340       return 1;
341   }
342 }
343 
AppleInfo(const std::string & gpu_description)344 AppleInfo::AppleInfo(const std::string& gpu_description) {
345   const std::map<std::string, AppleGpu> kMapping = {
346       {"apple a7 gpu", AppleGpu::kA7},
347       {"apple a8 gpu", AppleGpu::kA8},
348       {"apple a8x gpu", AppleGpu::kA8X},
349       {"apple a9 gpu", AppleGpu::kA9},
350       {"apple a9x gpu", AppleGpu::kA9X},
351       {"apple a10 gpu", AppleGpu::kA10},
352       {"apple a10x gpu", AppleGpu::kA10X},
353       {"apple a11 gpu", AppleGpu::kA11},
354       {"apple a12 gpu", AppleGpu::kA12},
355       {"apple a12x gpu", AppleGpu::kA12X},
356       {"apple a12z gpu", AppleGpu::kA12Z},
357       {"apple a13 gpu", AppleGpu::kA13},
358       {"apple a14 gpu", AppleGpu::kA14},
359       {"apple a15 gpu", AppleGpu::kA15},
360       // on tablets we have metal device name "apple m1 gpu"
361       // and on notebooks "apple m1"
362       {"apple m1 gpu", AppleGpu::kM1},
363       {"apple m1", AppleGpu::kM1},
364       {"apple m1 pro", AppleGpu::kM1Pro},
365       {"apple m1 max", AppleGpu::kM1Max},
366       {"apple m1 ultra", AppleGpu::kM1Ultra},
367       {"apple m2", AppleGpu::kM2},
368   };
369   auto it = kMapping.find(gpu_description);
370   if (it != kMapping.end()) {
371     gpu_type = it->second;
372   } else {
373     gpu_type = AppleGpu::kUnknown;
374   }
375 }
376 
IsA7GenerationGpu() const377 bool AppleInfo::IsA7GenerationGpu() const { return gpu_type == AppleGpu::kA7; }
IsA8GenerationGpu() const378 bool AppleInfo::IsA8GenerationGpu() const {
379   return gpu_type == AppleGpu::kA8 || gpu_type == AppleGpu::kA8X;
380 }
381 
IsLocalMemoryPreferredOverGlobal() const382 bool AppleInfo::IsLocalMemoryPreferredOverGlobal() const {
383   return IsA7GenerationGpu() || IsA8GenerationGpu();
384 }
385 
IsBionic() const386 bool AppleInfo::IsBionic() const {
387   return gpu_type == AppleGpu::kA11 || gpu_type == AppleGpu::kA12 ||
388          gpu_type == AppleGpu::kA12X || gpu_type == AppleGpu::kA12Z ||
389          gpu_type == AppleGpu::kA13 || gpu_type == AppleGpu::kA14 ||
390          gpu_type == AppleGpu::kA15 || gpu_type == AppleGpu::kM1 ||
391          gpu_type == AppleGpu::kM1Pro || gpu_type == AppleGpu::kM1Max ||
392          gpu_type == AppleGpu::kM1Ultra || gpu_type == AppleGpu::kM2;
393 }
394 
IsSIMDMatMulSupported() const395 bool AppleInfo::IsSIMDMatMulSupported() const {
396   return gpu_type == AppleGpu::kA14 || gpu_type == AppleGpu::kA15 ||
397          gpu_type == AppleGpu::kM1 || gpu_type == AppleGpu::kM1Pro ||
398          gpu_type == AppleGpu::kM1Max || gpu_type == AppleGpu::kM1Ultra ||
399          gpu_type == AppleGpu::kM2;
400 }
401 
IsSIMDMatMulFp32Perf2x() const402 bool AppleInfo::IsSIMDMatMulFp32Perf2x() const {
403   return gpu_type == AppleGpu::kA15 || gpu_type == AppleGpu::kM2;
404 }
405 
IsRoundToNearestSupported() const406 bool AppleInfo::IsRoundToNearestSupported() const { return IsBionic(); }
407 
GetComputeUnitsCount() const408 int AppleInfo::GetComputeUnitsCount() const {
409   switch (gpu_type) {
410     case AppleGpu::kA7:
411       return 4;
412     case AppleGpu::kA8:
413       return 4;
414     case AppleGpu::kA8X:
415       return 8;
416     case AppleGpu::kA9:
417       return 6;
418     case AppleGpu::kA9X:
419       return 12;
420     case AppleGpu::kA10:
421       return 6;
422     case AppleGpu::kA10X:
423       return 12;
424     case AppleGpu::kA11:
425       return 3;
426     case AppleGpu::kA12:
427       return 4;
428     case AppleGpu::kA12X:
429       return 7;
430     case AppleGpu::kA12Z:
431       return 8;
432     case AppleGpu::kA13:
433       return 4;
434     case AppleGpu::kA14:
435       return 4;
436     // For some apple GPUs we can not receive exact CU count from name.
437     // No official Metal API to receive this info.
438     case AppleGpu::kA15:
439       if (compute_units != -1) {
440         return compute_units;
441       }
442       return 5;
443     case AppleGpu::kM1:
444       // approximate, can be 7 or 8
445       return 8;
446     case AppleGpu::kM1Pro:
447       // approximate, can be 14 or 16
448       return 16;
449     case AppleGpu::kM1Max:
450       // approximate, can be 24 or 32
451       return 32;
452     case AppleGpu::kM1Ultra:
453       // approximate, 64 is max possible
454       return 64;
455     case AppleGpu::kM2:
456       // approximate, 10 is max possible
457       return 10;
458     case AppleGpu::kUnknown:
459       return 4;
460   }
461 }
462 
SetComputeUnits(int compute_units_count)463 void AppleInfo::SetComputeUnits(int compute_units_count) {
464   compute_units = compute_units_count;
465 }
466 
MaliInfo(const std::string & gpu_description)467 MaliInfo::MaliInfo(const std::string& gpu_description)
468     : gpu_version(GetMaliGpuVersion(gpu_description)) {}
469 
IsMaliT6xx() const470 bool MaliInfo::IsMaliT6xx() const {
471   return gpu_version == MaliGpu::kT604 || gpu_version == MaliGpu::kT622 ||
472          gpu_version == MaliGpu::kT624 || gpu_version == MaliGpu::kT628 ||
473          gpu_version == MaliGpu::kT658 || gpu_version == MaliGpu::kT678;
474 }
475 
IsMaliT7xx() const476 bool MaliInfo::IsMaliT7xx() const {
477   return gpu_version == MaliGpu::kT720 || gpu_version == MaliGpu::kT760;
478 }
479 
IsMaliT8xx() const480 bool MaliInfo::IsMaliT8xx() const {
481   return gpu_version == MaliGpu::kT820 || gpu_version == MaliGpu::kT830 ||
482          gpu_version == MaliGpu::kT860 || gpu_version == MaliGpu::kT880;
483 }
484 
IsMidgard() const485 bool MaliInfo::IsMidgard() const {
486   return IsMaliT6xx() || IsMaliT7xx() || IsMaliT8xx();
487 }
488 
IsBifrostGen1() const489 bool MaliInfo::IsBifrostGen1() const {
490   return gpu_version == MaliGpu::kG31 || gpu_version == MaliGpu::kG51 ||
491          gpu_version == MaliGpu::kG71;
492 }
493 
IsBifrostGen2() const494 bool MaliInfo::IsBifrostGen2() const {
495   return gpu_version == MaliGpu::kG52 || gpu_version == MaliGpu::kG72;
496 }
497 
IsBifrostGen3() const498 bool MaliInfo::IsBifrostGen3() const { return gpu_version == MaliGpu::kG76; }
499 
IsBifrost() const500 bool MaliInfo::IsBifrost() const {
501   return IsBifrostGen1() || IsBifrostGen2() || IsBifrostGen3();
502 }
503 
IsValhallGen1() const504 bool MaliInfo::IsValhallGen1() const {
505   return gpu_version == MaliGpu::kG57 || gpu_version == MaliGpu::kG77;
506 }
507 
IsValhallGen2() const508 bool MaliInfo::IsValhallGen2() const {
509   return gpu_version == MaliGpu::kG68 || gpu_version == MaliGpu::kG78;
510 }
511 
IsValhallGen3() const512 bool MaliInfo::IsValhallGen3() const {
513   return gpu_version == MaliGpu::kG310 || gpu_version == MaliGpu::kG510 ||
514          gpu_version == MaliGpu::kG610 || gpu_version == MaliGpu::kG710;
515 }
516 
IsValhall() const517 bool MaliInfo::IsValhall() const {
518   return IsValhallGen1() || IsValhallGen2() || IsValhallGen3();
519 }
520 
GetApproximateComputeUnitsCount() const521 int MaliInfo::GetApproximateComputeUnitsCount() const {
522   if (IsMidgard()) {
523     // Mali Midgard can have 1-16 cores
524     return 8;
525   } else if (IsBifrost()) {
526     // Mali Bifrost can have 1-32 cores
527     return 16;
528   } else if (IsValhall()) {
529     if (gpu_version == MaliGpu::kG57) {
530       return 6;  // Mali-G57 can have 1-6 cores
531     } else if (gpu_version == MaliGpu::kG77) {
532       return 16;  // Mali-G77 can have 7-16 cores
533     } else if (gpu_version == MaliGpu::kG68) {
534       return 6;  // Mali-G68 can have 4-6 cores
535     } else if (gpu_version == MaliGpu::kG78) {
536       return 16;  // Mali-G78 can have 7-24 cores
537     } else if (gpu_version == MaliGpu::kG310 || gpu_version == MaliGpu::kG510 ||
538                gpu_version == MaliGpu::kG610) {
539       return 6;  // Mali-G310/G510/G610 can have up to 6 cores
540     } else if (gpu_version == MaliGpu::kG710) {
541       return 10;  // Mali-G710 can have 7–16 cores
542     }
543   }
544   return 4;
545 }
546 
GetGpuInfoFromDeviceDescription(const std::string & gpu_description,GpuApi gpu_api,GpuInfo * gpu_info)547 void GetGpuInfoFromDeviceDescription(const std::string& gpu_description,
548                                      GpuApi gpu_api, GpuInfo* gpu_info) {
549   gpu_info->gpu_api = gpu_api;
550   std::string lowered = gpu_description;
551   absl::AsciiStrToLower(&lowered);
552   gpu_info->vendor = GetGpuVendor(lowered);
553   if (gpu_info->IsAdreno()) {
554     gpu_info->adreno_info = AdrenoInfo(lowered);
555   } else if (gpu_info->IsApple()) {
556     gpu_info->apple_info = AppleInfo(lowered);
557     gpu_info->supported_subgroup_sizes = {32};
558   } else if (gpu_info->IsMali()) {
559     gpu_info->mali_info = MaliInfo(lowered);
560   }
561 }
562 
OpenClVersionToString(OpenClVersion version)563 std::string OpenClVersionToString(OpenClVersion version) {
564   switch (version) {
565     case OpenClVersion::kCl1_0:
566       return "1.0";
567     case OpenClVersion::kCl1_1:
568       return "1.1";
569     case OpenClVersion::kCl1_2:
570       return "1.2";
571     case OpenClVersion::kCl2_0:
572       return "2.0";
573     case OpenClVersion::kCl2_1:
574       return "2.1";
575     case OpenClVersion::kCl2_2:
576       return "2.2";
577     case OpenClVersion::kCl3_0:
578       return "3.0";
579     default:
580       return "Unknown OpenCL version";
581   }
582 }
583 
SupportsExplicitFp16() const584 bool OpenGlInfo::SupportsExplicitFp16() const {
585   bool supports_f16_alu = false;
586   bool supports_f16_storage = false;
587   for (const auto& ext : extensions) {
588     if (ext == "GL_EXT_shader_explicit_arithmetic_types_float16") {
589       supports_f16_alu = true;
590     }
591     if (ext == "GL_EXT_shader_16bit_storage") {
592       supports_f16_storage = true;
593     }
594   }
595   return supports_f16_alu && supports_f16_storage;
596 }
597 
IsApiOpenGl31OrAbove() const598 bool OpenGlInfo::IsApiOpenGl31OrAbove() const {
599   return (major_version == 3 && minor_version >= 1) || major_version > 3;
600 }
601 
IsApiOpenGl32OrAbove() const602 bool OpenGlInfo::IsApiOpenGl32OrAbove() const {
603   return (major_version == 3 && minor_version >= 2) || major_version > 3;
604 }
605 
SupportsExplicitFp16() const606 bool VulkanInfo::SupportsExplicitFp16() const {
607   bool supports_f16_alu = false;
608   bool supports_f16_storage = false;
609   for (const auto& ext : extensions) {
610     if (ext == "VK_KHR_shader_float16_int8") {
611       supports_f16_alu = true;
612     }
613     if (ext == "VK_KHR_16bit_storage") {
614       supports_f16_storage = true;
615     }
616   }
617   return supports_f16_alu && supports_f16_storage;
618 }
619 
SupportsImage2D(DataType data_type,int channels) const620 bool OpenClInfo::SupportedImage2dTypes::SupportsImage2D(DataType data_type,
621                                                         int channels) const {
622   if (channels == 1) {
623     return r_layout.find(data_type) != r_layout.end();
624   } else if (channels == 2) {
625     return rg_layout.find(data_type) != rg_layout.end();
626   } else if (channels == 3) {
627     return rgb_layout.find(data_type) != rgb_layout.end();
628   } else if (channels == 4) {
629     return rgba_layout.find(data_type) != rgba_layout.end();
630   } else {
631     return false;
632   }
633 }
634 
IsImage2dFromBufferSupported() const635 bool OpenClInfo::IsImage2dFromBufferSupported() const {
636   if (image_pitch_alignment == 0) {
637     return false;
638   }
639   if (image_base_address_alignment == 0) {
640     return false;
641   }
642   if (cl_version == OpenClVersion::kCl2_0 ||
643       cl_version == OpenClVersion::kCl2_1 ||
644       cl_version == OpenClVersion::kCl2_2) {
645     return true;
646   }
647   for (const auto& ext : extensions) {
648     if (ext == "cl_khr_image2d_from_buffer") {
649       return true;
650     }
651   }
652   return false;
653 }
654 
IsSIMDMatMulSupported() const655 bool MetalInfo::IsSIMDMatMulSupported() const {
656   if (language_version == MetalLanguageVersion::kUnknown ||
657       language_version == MetalLanguageVersion::kMetal1_0 ||
658       language_version == MetalLanguageVersion::kMetal1_1 ||
659       language_version == MetalLanguageVersion::kMetal1_2 ||
660       language_version == MetalLanguageVersion::kMetal2_0 ||
661       language_version == MetalLanguageVersion::kMetal2_1 ||
662       language_version == MetalLanguageVersion::kMetal2_2) {
663     return false;
664   }
665   return true;
666 }
667 
IsMslVersionEqualOrHigher(int major,int minor) const668 bool MetalInfo::IsMslVersionEqualOrHigher(int major, int minor) const {
669   const std::map<MetalLanguageVersion, std::pair<int, int>> kMapping = {
670       {MetalLanguageVersion::kUnknown, {1, 0}},
671       {MetalLanguageVersion::kMetal1_0, {1, 0}},
672       {MetalLanguageVersion::kMetal1_1, {1, 1}},
673       {MetalLanguageVersion::kMetal1_2, {1, 2}},
674       {MetalLanguageVersion::kMetal2_0, {2, 0}},
675       {MetalLanguageVersion::kMetal2_1, {2, 1}},
676       {MetalLanguageVersion::kMetal2_2, {2, 2}},
677       {MetalLanguageVersion::kMetal2_3, {2, 3}},
678       {MetalLanguageVersion::kMetal2_4, {2, 4}},
679       {MetalLanguageVersion::kMetal3_0, {3, 0}}};
680   auto version = kMapping.at(language_version);
681   if (major > version.first) {
682     return true;
683   } else if (major == version.first && minor >= version.second) {
684     return true;
685   } else {
686     return false;
687   }
688 }
689 
IsAdreno() const690 bool GpuInfo::IsAdreno() const { return vendor == GpuVendor::kQualcomm; }
691 
IsApple() const692 bool GpuInfo::IsApple() const { return vendor == GpuVendor::kApple; }
693 
IsMali() const694 bool GpuInfo::IsMali() const { return vendor == GpuVendor::kMali; }
695 
IsPowerVR() const696 bool GpuInfo::IsPowerVR() const { return vendor == GpuVendor::kPowerVR; }
697 
IsNvidia() const698 bool GpuInfo::IsNvidia() const { return vendor == GpuVendor::kNvidia; }
699 
IsAMD() const700 bool GpuInfo::IsAMD() const { return vendor == GpuVendor::kAMD; }
701 
IsIntel() const702 bool GpuInfo::IsIntel() const { return vendor == GpuVendor::kIntel; }
703 
IsRoundToNearestSupported() const704 bool GpuInfo::IsRoundToNearestSupported() const {
705   if (IsApiOpenCl()) {
706     return opencl_info.supports_fp16_rtn || opencl_info.supports_fp32_rtn;
707   }
708   if (IsApple()) {
709     return apple_info.IsRoundToNearestSupported();
710   }
711   if (IsAdreno()) {
712     if (adreno_info.IsAdreno1xx() || adreno_info.IsAdreno2xx() ||
713         adreno_info.IsAdreno3xx()) {
714       return false;
715     }
716   }
717   if (IsPowerVR()) {
718     return false;
719   }
720   return true;
721 }
722 
SupportsFP16() const723 bool GpuInfo::SupportsFP16() const {
724   if (IsApiOpenCl()) {
725     return opencl_info.supports_fp16;
726   }
727   return true;
728 }
729 
SupportsTextureArray() const730 bool GpuInfo::SupportsTextureArray() const {
731   if (!SupportsImages()) {
732     return false;
733   }
734   if (IsApiOpenCl()) {
735     return opencl_info.cl_version >= OpenClVersion::kCl1_2;
736   }
737   return true;
738 }
739 
SupportsImageBuffer() const740 bool GpuInfo::SupportsImageBuffer() const {
741   if (!SupportsImages()) {
742     return false;
743   }
744   if (IsApiOpenCl()) {
745     return opencl_info.cl_version >= OpenClVersion::kCl1_2;
746   }
747   return true;
748 }
749 
SupportsImage3D() const750 bool GpuInfo::SupportsImage3D() const {
751   if (!SupportsImages()) {
752     return false;
753   }
754   if (IsApiOpenCl()) {
755     if (IsMali() && mali_info.IsMidgard()) {
756       // On Mali T880 read_imageh doesn't compile with image3d_t
757       return false;
758     }
759     return opencl_info.supports_image3d_writes;
760   }
761   return true;
762 }
763 
SupportsImages() const764 bool GpuInfo::SupportsImages() const {
765   if (IsApiOpenCl()) {
766     return opencl_info.supports_images;
767   }
768   return true;
769 }
770 
SupportsPointersInKernels() const771 bool GpuInfo::SupportsPointersInKernels() const {
772   return IsApiOpenCl() || IsApiMetal();
773 }
774 
SupportsZeroClampForImageBuffer() const775 bool GpuInfo::SupportsZeroClampForImageBuffer() const {
776   if (IsApiMetal() || IsApiOpenCl()) {
777     return true;
778   } else {
779     return false;
780   }
781 }
782 
SupportsZeroClampForImages() const783 bool GpuInfo::SupportsZeroClampForImages() const {
784   if (IsApiMetal()) {
785     return true;
786   } else if (IsApiOpenCl()) {
787     return true;
788   } else if (IsApiVulkan()) {
789     return false;
790   } else if (IsApiOpenGl()) {
791     return false;
792   } else {
793     return false;
794   }
795 }
796 
IsWaveSizeEqualTo32() const797 bool GpuInfo::IsWaveSizeEqualTo32() const {
798   return supported_subgroup_sizes.size() == 1 &&
799          supported_subgroup_sizes[0] == 32;
800 }
801 
SupportsExtension(const std::string & extension) const802 bool GpuInfo::SupportsExtension(const std::string& extension) const {
803   const std::vector<std::string>* extensions = nullptr;
804   if (IsApiOpenGl()) {
805     extensions = &opengl_info.extensions;
806   } else if (IsApiVulkan()) {
807     extensions = &vulkan_info.extensions;
808   } else if (IsApiOpenCl()) {
809     extensions = &opencl_info.extensions;
810   }
811   if (!extensions) {
812     return false;
813   }
814   for (const auto& ext : *extensions) {
815     if (ext == extension) {
816       return true;
817     }
818   }
819   return false;
820 }
821 
SupportsSubGroupWithSize(int sub_group_size) const822 bool GpuInfo::SupportsSubGroupWithSize(int sub_group_size) const {
823   for (auto subgroup_size : supported_subgroup_sizes) {
824     if (sub_group_size == subgroup_size) {
825       return true;
826     }
827   }
828   return false;
829 }
830 
SupportsFloatImage2D(DataType data_type,int channels) const831 bool GpuInfo::SupportsFloatImage2D(DataType data_type, int channels) const {
832   if (IsApiOpenCl()) {
833     return opencl_info.supported_images_2d.SupportsImage2D(data_type, channels);
834   }
835   return false;
836 }
837 
GetComputeUnitsCount() const838 int GpuInfo::GetComputeUnitsCount() const {
839   if (IsApiOpenCl()) {
840     return opencl_info.compute_units_count;
841   }
842   if (IsApple()) {
843     return apple_info.GetComputeUnitsCount();
844   }
845   if (IsAMD()) {
846     if (amd_info.GetComputeUnitsCount() != 0) {
847       return amd_info.GetComputeUnitsCount();
848     } else {
849       // approximate number
850       return 16;
851     }
852   }
853   if (IsAdreno()) {
854     return adreno_info.GetComputeUnitsCount();
855   }
856   if (IsMali()) {
857     mali_info.GetApproximateComputeUnitsCount();
858   }
859   return 4;
860 }
861 
GetMaxWorkGroupSizeForX() const862 int GpuInfo::GetMaxWorkGroupSizeForX() const {
863   if (IsApiOpenGl()) {
864     return opengl_info.max_compute_work_group_size_x;
865   }
866   if (IsApiVulkan()) {
867     return vulkan_info.max_compute_work_group_size_x;
868   }
869   if (IsApiOpenCl()) {
870     return opencl_info.max_work_group_size_x;
871   }
872   if (IsApiMetal()) {
873     return metal_info.max_work_group_size_x;
874   }
875   return 256;
876 }
877 
GetMaxWorkGroupSizeForY() const878 int GpuInfo::GetMaxWorkGroupSizeForY() const {
879   if (IsApiOpenGl()) {
880     return opengl_info.max_compute_work_group_size_y;
881   }
882   if (IsApiVulkan()) {
883     return vulkan_info.max_compute_work_group_size_y;
884   }
885   if (IsApiOpenCl()) {
886     return opencl_info.max_work_group_size_y;
887   }
888   if (IsApiMetal()) {
889     return metal_info.max_work_group_size_y;
890   }
891   return 256;
892 }
893 
GetMaxWorkGroupSizeForZ() const894 int GpuInfo::GetMaxWorkGroupSizeForZ() const {
895   if (IsApiOpenGl()) {
896     return opengl_info.max_compute_work_group_size_z;
897   }
898   if (IsApiVulkan()) {
899     return vulkan_info.max_compute_work_group_size_z;
900   }
901   if (IsApiOpenCl()) {
902     return opencl_info.max_work_group_size_z;
903   }
904   if (IsApiMetal()) {
905     return metal_info.max_work_group_size_z;
906   }
907   return 64;
908 }
909 
GetMaxWorkGroupTotalSize() const910 int GpuInfo::GetMaxWorkGroupTotalSize() const {
911   if (IsApiOpenGl()) {
912     return opengl_info.max_work_group_invocations;
913   }
914   if (IsApiVulkan()) {
915     return vulkan_info.max_compute_work_group_invocations;
916   }
917   if (IsApiOpenCl()) {
918     return opencl_info.max_work_group_total_size;
919   }
920   if (IsApiMetal()) {
921     int max_size = metal_info.max_work_group_size_x;
922     max_size = std::max(max_size, metal_info.max_work_group_size_y);
923     max_size = std::max(max_size, metal_info.max_work_group_size_z);
924     return max_size;
925   }
926   return 256;
927 }
928 
GetMaxImage2DWidth() const929 uint64_t GpuInfo::GetMaxImage2DWidth() const {
930   if (IsApiOpenGl()) {
931     return opengl_info.max_texture_size;
932   }
933   if (IsApiVulkan()) {
934     return vulkan_info.max_image_dimension_2d;
935   }
936   if (IsApiOpenCl()) {
937     return opencl_info.image2d_max_width;
938   }
939   if (IsApiMetal()) {
940     return metal_info.image2d_max_width;
941   }
942   return 2048;
943 }
944 
GetMaxImage2DHeight() const945 uint64_t GpuInfo::GetMaxImage2DHeight() const {
946   if (IsApiOpenGl()) {
947     return opengl_info.max_texture_size;
948   }
949   if (IsApiVulkan()) {
950     return vulkan_info.max_image_dimension_2d;
951   }
952   if (IsApiOpenCl()) {
953     return opencl_info.image2d_max_height;
954   }
955   if (IsApiMetal()) {
956     return metal_info.image2d_max_height;
957   }
958   return 2048;
959 }
960 
GetMaxImage2DArrayLayers() const961 uint64_t GpuInfo::GetMaxImage2DArrayLayers() const {
962   if (IsApiOpenGl()) {
963     return opengl_info.max_array_texture_layers;
964   }
965   if (IsApiVulkan()) {
966     return vulkan_info.max_image_array_layers;
967   }
968   if (IsApiOpenCl()) {
969     return opencl_info.image_array_max_layers;
970   }
971   if (IsApiMetal()) {
972     return metal_info.image_array_max_layers;
973   }
974   return 256;
975 }
976 
GetMaxImage3DWidth() const977 uint64_t GpuInfo::GetMaxImage3DWidth() const {
978   if (IsApiOpenCl()) {
979     return opencl_info.image3d_max_width;
980   } else if (IsApiMetal()) {
981     return metal_info.image3d_max_width;
982   } else if (IsApiVulkan()) {
983     return vulkan_info.max_image_dimension_3d;
984   }
985   return 256;
986 }
987 
GetMaxImage3DHeight() const988 uint64_t GpuInfo::GetMaxImage3DHeight() const {
989   if (IsApiOpenCl()) {
990     return opencl_info.image3d_max_height;
991   } else if (IsApiMetal()) {
992     return metal_info.image3d_max_height;
993   } else if (IsApiVulkan()) {
994     return vulkan_info.max_image_dimension_3d;
995   }
996   return 256;
997 }
998 
GetMaxImage3DDepth() const999 uint64_t GpuInfo::GetMaxImage3DDepth() const {
1000   if (IsApiOpenCl()) {
1001     return opencl_info.image3d_max_depth;
1002   } else if (IsApiMetal()) {
1003     return metal_info.image3d_max_depth;
1004   } else if (IsApiVulkan()) {
1005     return vulkan_info.max_image_dimension_3d;
1006   }
1007   return 256;
1008 }
1009 
GetMaxBufferSize() const1010 uint64_t GpuInfo::GetMaxBufferSize() const {
1011   if (IsApiOpenCl()) {
1012     return opencl_info.buffer_max_size;
1013   } else if (IsApiMetal()) {
1014     return metal_info.buffer_max_size;
1015   } else if (IsApiVulkan()) {
1016     return vulkan_info.max_storage_buffer_range;
1017   }
1018   return 128 * 1024 * 1024;
1019 }
1020 
GetMaxMemoryAllocationSize() const1021 uint64_t GpuInfo::GetMaxMemoryAllocationSize() const {
1022   if (IsApiOpenCl()) {
1023     return opencl_info.max_allocation_size;
1024   } else if (IsApiMetal()) {
1025     return metal_info.buffer_max_size;
1026   } else if (IsApiVulkan()) {
1027     return vulkan_info.max_storage_buffer_range;
1028   }
1029   return 128 * 1024 * 1024;
1030 }
1031 
GetMaxImageBufferWidth() const1032 uint64_t GpuInfo::GetMaxImageBufferWidth() const {
1033   if (IsApiOpenCl()) {
1034     return opencl_info.image_buffer_max_size;
1035   } else if (IsApiVulkan()) {
1036     return vulkan_info.max_texel_buffer_elements;
1037   }
1038   return 64 * 1024;
1039 }
1040 
GetMaxImageArguments() const1041 int GpuInfo::GetMaxImageArguments() const {
1042   if (IsApiOpenGl()) {
1043     return opengl_info.max_image_units;
1044   }
1045   if (IsApiVulkan()) {
1046     return vulkan_info.max_per_stage_descriptor_sampled_images;
1047   }
1048   if (IsApiMetal()) {
1049     return 32;
1050   }
1051   if (IsApiOpenCl()) {
1052     return 128;
1053   }
1054   return 1;
1055 }
1056 
IsApiOpenGl() const1057 bool GpuInfo::IsApiOpenGl() const { return gpu_api == GpuApi::kOpenGl; }
1058 
IsApiOpenGl31OrAbove() const1059 bool GpuInfo::IsApiOpenGl31OrAbove() const {
1060   if (!IsApiOpenGl()) {
1061     return false;
1062   }
1063   return opengl_info.IsApiOpenGl31OrAbove();
1064 }
1065 
IsApiVulkan() const1066 bool GpuInfo::IsApiVulkan() const { return gpu_api == GpuApi::kVulkan; }
1067 
IsApiMetal() const1068 bool GpuInfo::IsApiMetal() const { return gpu_api == GpuApi::kMetal; }
1069 
IsApiOpenCl() const1070 bool GpuInfo::IsApiOpenCl() const { return gpu_api == GpuApi::kOpenCl; }
1071 
IsGlsl() const1072 bool GpuInfo::IsGlsl() const { return IsApiOpenGl() || IsApiVulkan(); }
1073 
IsGlslSupportsExplicitFp16() const1074 bool GpuInfo::IsGlslSupportsExplicitFp16() const {
1075   if (IsApiOpenGl() && opengl_info.SupportsExplicitFp16()) {
1076     return true;
1077   }
1078   if (IsApiVulkan() && vulkan_info.SupportsExplicitFp16()) {
1079     return true;
1080   }
1081   return false;
1082 }
1083 
IsCL11OrHigher() const1084 bool GpuInfo::IsCL11OrHigher() const {
1085   if (!IsApiOpenCl()) {
1086     return false;
1087   }
1088   return opencl_info.cl_version != OpenClVersion::kCl1_0;
1089 }
1090 
IsCL20OrHigher() const1091 bool GpuInfo::IsCL20OrHigher() const {
1092   if (!IsApiOpenCl()) {
1093     return false;
1094   }
1095   return opencl_info.cl_version != OpenClVersion::kCl1_0 &&
1096          opencl_info.cl_version != OpenClVersion::kCl1_1 &&
1097          opencl_info.cl_version != OpenClVersion::kCl1_2;
1098 }
1099 
IsCL30OrHigher() const1100 bool GpuInfo::IsCL30OrHigher() const {
1101   if (!IsApiOpenCl()) {
1102     return false;
1103   }
1104   return IsCL20OrHigher() && opencl_info.cl_version != OpenClVersion::kCl2_0 &&
1105          opencl_info.cl_version != OpenClVersion::kCl2_1 &&
1106          opencl_info.cl_version != OpenClVersion::kCl2_2;
1107 }
1108 
1109 }  // namespace gpu
1110 }  // namespace tflite
1111