• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
17 
18 #include <algorithm>
19 #include <map>
20 #include <string>
21 
22 #include "absl/strings/ascii.h"
23 
24 namespace tflite {
25 namespace gpu {
26 namespace {
27 
GetGpuVendor(const std::string & gpu_description)28 GpuVendor GetGpuVendor(const std::string& gpu_description) {
29   const std::map<std::string, GpuVendor> kMapping = {
30       {"adreno", GpuVendor::kQualcomm},
31       {"apple", GpuVendor::kApple},
32       {"qualcomm", GpuVendor::kQualcomm},
33       {"mali", GpuVendor::kMali},
34       {"powervr", GpuVendor::kPowerVR},
35       {"advanced micro devices", GpuVendor::kAMD},
36       {"intel", GpuVendor::kIntel},
37       {"nvidia", GpuVendor::kNvidia},
38       {"amd", GpuVendor::kAMD},
39       {"power", GpuVendor::kPowerVR},
40   };
41   for (const auto& v : kMapping) {
42     if (gpu_description.find(v.first) != std::string::npos) {
43       return v.second;
44     }
45   }
46   return GpuVendor::kUnknown;
47 }
48 
GetAdrenoGpuVersion(const std::string & gpu_description)49 AdrenoGpu GetAdrenoGpuVersion(const std::string& gpu_description) {
50   const std::map<std::string, AdrenoGpu> kMapping = {
51       // Adreno 6xx series
52       {"685", AdrenoGpu::kAdreno685},
53       {"680", AdrenoGpu::kAdreno680},
54       {"675", AdrenoGpu::kAdreno675},
55       {"660", AdrenoGpu::kAdreno660},
56       {"650", AdrenoGpu::kAdreno650},
57       {"640", AdrenoGpu::kAdreno640},
58       {"630", AdrenoGpu::kAdreno630},
59       {"620", AdrenoGpu::kAdreno620},
60       {"618", AdrenoGpu::kAdreno618},
61       {"616", AdrenoGpu::kAdreno616},
62       {"615", AdrenoGpu::kAdreno615},
63       {"612", AdrenoGpu::kAdreno612},
64       {"610", AdrenoGpu::kAdreno610},
65       {"605", AdrenoGpu::kAdreno605},
66       // Adreno 5xx series
67       {"540", AdrenoGpu::kAdreno540},
68       {"530", AdrenoGpu::kAdreno530},
69       {"512", AdrenoGpu::kAdreno512},
70       {"510", AdrenoGpu::kAdreno510},
71       {"509", AdrenoGpu::kAdreno509},
72       {"508", AdrenoGpu::kAdreno508},
73       {"506", AdrenoGpu::kAdreno506},
74       {"505", AdrenoGpu::kAdreno505},
75       {"504", AdrenoGpu::kAdreno504},
76       // Adreno 4xx series
77       {"430", AdrenoGpu::kAdreno430},
78       {"420", AdrenoGpu::kAdreno420},
79       {"418", AdrenoGpu::kAdreno418},
80       {"405", AdrenoGpu::kAdreno405},
81       // Adreno 3xx series
82       {"330", AdrenoGpu::kAdreno330},
83       {"320", AdrenoGpu::kAdreno320},
84       {"308", AdrenoGpu::kAdreno308},
85       {"306", AdrenoGpu::kAdreno306},
86       {"305", AdrenoGpu::kAdreno305},
87       {"304", AdrenoGpu::kAdreno304},
88       // Adreno 2xx series
89       {"225", AdrenoGpu::kAdreno225},
90       {"220", AdrenoGpu::kAdreno220},
91       {"205", AdrenoGpu::kAdreno205},
92       {"203", AdrenoGpu::kAdreno203},
93       {"200", AdrenoGpu::kAdreno200},
94       // Adreno 1xx series
95       {"130", AdrenoGpu::kAdreno130},
96       {"120", AdrenoGpu::kAdreno120},
97   };
98 
99   for (const auto& v : kMapping) {
100     if (gpu_description.find(v.first) != std::string::npos) {
101       return v.second;
102     }
103   }
104   return AdrenoGpu::kUnknown;
105 }
106 
GetMaliGpuVersion(const std::string & gpu_description)107 MaliGpu GetMaliGpuVersion(const std::string& gpu_description) {
108   const std::map<std::string, MaliGpu> kMapping = {
109       {"t604", MaliGpu::kT604}, {"t622", MaliGpu::kT622},
110       {"t624", MaliGpu::kT624}, {"t628", MaliGpu::kT628},
111       {"t658", MaliGpu::kT658}, {"t678", MaliGpu::kT678},
112       {"t720", MaliGpu::kT720}, {"t760", MaliGpu::kT760},
113       {"t820", MaliGpu::kT820}, {"t830", MaliGpu::kT830},
114       {"t860", MaliGpu::kT860}, {"t880", MaliGpu::kT880},
115       {"g31", MaliGpu::kG31},   {"g51", MaliGpu::kG51},
116       {"g71", MaliGpu::kG71},   {"g52", MaliGpu::kG52},
117       {"g72", MaliGpu::kG72},   {"g76", MaliGpu::kG76},
118       {"g57", MaliGpu::kG57},   {"g77", MaliGpu::kG77},
119       {"g68", MaliGpu::kG68},   {"g78", MaliGpu::kG78},
120   };
121   for (const auto& v : kMapping) {
122     if (gpu_description.find(v.first) != std::string::npos) {
123       return v.second;
124     }
125   }
126   return MaliGpu::kUnknown;
127 }
128 
129 }  // namespace
130 
AdrenoInfo(const std::string & device_version)131 AdrenoInfo::AdrenoInfo(const std::string& device_version)
132     : adreno_gpu(GetAdrenoGpuVersion(device_version)) {}
133 
IsAdreno1xx() const134 bool AdrenoInfo::IsAdreno1xx() const {
135   return adreno_gpu == AdrenoGpu::kAdreno120 ||
136          adreno_gpu == AdrenoGpu::kAdreno130;
137 }
138 
IsAdreno2xx() const139 bool AdrenoInfo::IsAdreno2xx() const {
140   return adreno_gpu == AdrenoGpu::kAdreno200 ||
141          adreno_gpu == AdrenoGpu::kAdreno203 ||
142          adreno_gpu == AdrenoGpu::kAdreno205 ||
143          adreno_gpu == AdrenoGpu::kAdreno220 ||
144          adreno_gpu == AdrenoGpu::kAdreno225;
145 }
146 
IsAdreno3xx() const147 bool AdrenoInfo::IsAdreno3xx() const {
148   return adreno_gpu == AdrenoGpu::kAdreno304 ||
149          adreno_gpu == AdrenoGpu::kAdreno305 ||
150          adreno_gpu == AdrenoGpu::kAdreno306 ||
151          adreno_gpu == AdrenoGpu::kAdreno308 ||
152          adreno_gpu == AdrenoGpu::kAdreno320 ||
153          adreno_gpu == AdrenoGpu::kAdreno330;
154 }
155 
IsAdreno4xx() const156 bool AdrenoInfo::IsAdreno4xx() const {
157   return adreno_gpu == AdrenoGpu::kAdreno405 ||
158          adreno_gpu == AdrenoGpu::kAdreno418 ||
159          adreno_gpu == AdrenoGpu::kAdreno420 ||
160          adreno_gpu == AdrenoGpu::kAdreno430;
161 }
162 
IsAdreno5xx() const163 bool AdrenoInfo::IsAdreno5xx() const {
164   return adreno_gpu == AdrenoGpu::kAdreno504 ||
165          adreno_gpu == AdrenoGpu::kAdreno505 ||
166          adreno_gpu == AdrenoGpu::kAdreno506 ||
167          adreno_gpu == AdrenoGpu::kAdreno508 ||
168          adreno_gpu == AdrenoGpu::kAdreno509 ||
169          adreno_gpu == AdrenoGpu::kAdreno510 ||
170          adreno_gpu == AdrenoGpu::kAdreno512 ||
171          adreno_gpu == AdrenoGpu::kAdreno530 ||
172          adreno_gpu == AdrenoGpu::kAdreno540;
173 }
174 
IsAdreno6xx() const175 bool AdrenoInfo::IsAdreno6xx() const {
176   return adreno_gpu == AdrenoGpu::kAdreno605 ||
177          adreno_gpu == AdrenoGpu::kAdreno610 ||
178          adreno_gpu == AdrenoGpu::kAdreno612 ||
179          adreno_gpu == AdrenoGpu::kAdreno615 ||
180          adreno_gpu == AdrenoGpu::kAdreno616 ||
181          adreno_gpu == AdrenoGpu::kAdreno618 ||
182          adreno_gpu == AdrenoGpu::kAdreno620 ||
183          adreno_gpu == AdrenoGpu::kAdreno630 ||
184          adreno_gpu == AdrenoGpu::kAdreno640 ||
185          adreno_gpu == AdrenoGpu::kAdreno650 ||
186          adreno_gpu == AdrenoGpu::kAdreno660 ||
187          adreno_gpu == AdrenoGpu::kAdreno675 ||
188          adreno_gpu == AdrenoGpu::kAdreno680 ||
189          adreno_gpu == AdrenoGpu::kAdreno685;
190 }
191 
IsAdreno6xxOrHigher() const192 bool AdrenoInfo::IsAdreno6xxOrHigher() const {
193   return !compiler_bugs_in_a6xx && IsAdreno6xx();
194 }
195 
GetMaximumWavesCount() const196 int AdrenoInfo::GetMaximumWavesCount() const {
197   if (IsAdreno6xx()) {
198     if (adreno_gpu == AdrenoGpu::kAdreno640) {
199       return 30;
200     } else {
201       return 16;
202     }
203   } else {
204     // all other versions not supported
205     return 1;
206   }
207 }
208 
GetRegisterMemorySizePerComputeUnit() const209 int AdrenoInfo::GetRegisterMemorySizePerComputeUnit() const {
210   if (IsAdreno6xx()) {
211     if (adreno_gpu == AdrenoGpu::kAdreno640) {
212       return 128 * 144 * 16;
213     } else if (adreno_gpu == AdrenoGpu::kAdreno620 ||
214                adreno_gpu == AdrenoGpu::kAdreno650 ||
215                adreno_gpu == AdrenoGpu::kAdreno660) {
216       return 128 * 64 * 16;
217     } else {
218       return 128 * 96 * 16;
219     }
220   } else {
221     // all other versions not supported
222     return 1;
223   }
224 }
225 
GetMaximumWavesCount(int register_footprint_per_tread,bool full_wave) const226 int AdrenoInfo::GetMaximumWavesCount(int register_footprint_per_tread,
227                                      bool full_wave) const {
228   const int register_usage_per_wave =
229       GetWaveSize(full_wave) * register_footprint_per_tread;
230   const int possible_waves_count =
231       GetRegisterMemorySizePerComputeUnit() / register_usage_per_wave;
232   return std::min(possible_waves_count, GetMaximumWavesCount());
233 }
234 
GetWaveSize(bool full_wave) const235 int AdrenoInfo::GetWaveSize(bool full_wave) const {
236   if (IsAdreno6xx()) {
237     return full_wave ? 128 : 64;
238   } else if (IsAdreno5xx() || IsAdreno4xx()) {
239     return full_wave ? 64 : 32;
240   } else {
241     // all other versions not supported
242     return 1;
243   }
244 }
245 
AppleInfo(const std::string & gpu_description)246 AppleInfo::AppleInfo(const std::string& gpu_description) {
247   const std::map<std::string, AppleGpu> kMapping = {
248       {"apple a7 gpu", AppleGpu::kA7},     {"apple a8 gpu", AppleGpu::kA8},
249       {"apple a8x gpu", AppleGpu::kA8X},   {"apple a9 gpu", AppleGpu::kA9},
250       {"apple a9x gpu", AppleGpu::kA9X},   {"apple a10 gpu", AppleGpu::kA10},
251       {"apple a10x gpu", AppleGpu::kA10X}, {"apple a11 gpu", AppleGpu::kA11},
252       {"apple a12 gpu", AppleGpu::kA12},   {"apple a12x gpu", AppleGpu::kA12X},
253       {"apple a12z gpu", AppleGpu::kA12Z}, {"apple a13 gpu", AppleGpu::kA13},
254       {"apple a14 gpu", AppleGpu::kA14},
255   };
256   auto it = kMapping.find(gpu_description);
257   if (it != kMapping.end()) {
258     gpu_type = it->second;
259   } else {
260     gpu_type = AppleGpu::kUnknown;
261   }
262 }
263 
IsLocalMemoryPreferredOverGlobal() const264 bool AppleInfo::IsLocalMemoryPreferredOverGlobal() const {
265   return gpu_type == AppleGpu::kA7 || gpu_type == AppleGpu::kA8 ||
266          gpu_type == AppleGpu::kA8X;
267 }
268 
IsBionic() const269 bool AppleInfo::IsBionic() const {
270   return gpu_type == AppleGpu::kA11 || gpu_type == AppleGpu::kA12 ||
271          gpu_type == AppleGpu::kA12X || gpu_type == AppleGpu::kA12Z ||
272          gpu_type == AppleGpu::kA13 || gpu_type == AppleGpu::kA14;
273 }
274 
IsRoundToNearestSupported() const275 bool AppleInfo::IsRoundToNearestSupported() const { return IsBionic(); }
276 
GetComputeUnitsCount() const277 int AppleInfo::GetComputeUnitsCount() const {
278   switch (gpu_type) {
279     case AppleGpu::kA7:
280       return 4;
281     case AppleGpu::kA8:
282       return 4;
283     case AppleGpu::kA8X:
284       return 8;
285     case AppleGpu::kA9:
286       return 6;
287     case AppleGpu::kA9X:
288       return 12;
289     case AppleGpu::kA10:
290       return 6;
291     case AppleGpu::kA10X:
292       return 12;
293     case AppleGpu::kA11:
294       return 3;
295     case AppleGpu::kA12:
296       return 4;
297     case AppleGpu::kA12X:
298       return 7;
299     case AppleGpu::kA12Z:
300       return 8;
301     case AppleGpu::kA13:
302       return 4;
303     case AppleGpu::kA14:
304       return 4;
305     case AppleGpu::kUnknown:
306       return 1;
307   }
308 }
309 
MaliInfo(const std::string & gpu_description)310 MaliInfo::MaliInfo(const std::string& gpu_description)
311     : gpu_version(GetMaliGpuVersion(gpu_description)) {}
312 
IsMaliT6xx() const313 bool MaliInfo::IsMaliT6xx() const {
314   return gpu_version == MaliGpu::kT604 || gpu_version == MaliGpu::kT622 ||
315          gpu_version == MaliGpu::kT624 || gpu_version == MaliGpu::kT628 ||
316          gpu_version == MaliGpu::kT658 || gpu_version == MaliGpu::kT678;
317 }
318 
IsMaliT7xx() const319 bool MaliInfo::IsMaliT7xx() const {
320   return gpu_version == MaliGpu::kT720 || gpu_version == MaliGpu::kT760;
321 }
322 
IsMaliT8xx() const323 bool MaliInfo::IsMaliT8xx() const {
324   return gpu_version == MaliGpu::kT820 || gpu_version == MaliGpu::kT830 ||
325          gpu_version == MaliGpu::kT860 || gpu_version == MaliGpu::kT880;
326 }
327 
IsMidgard() const328 bool MaliInfo::IsMidgard() const {
329   return IsMaliT6xx() || IsMaliT7xx() || IsMaliT8xx();
330 }
331 
IsBifrostGen1() const332 bool MaliInfo::IsBifrostGen1() const {
333   return gpu_version == MaliGpu::kG31 || gpu_version == MaliGpu::kG51 ||
334          gpu_version == MaliGpu::kG71;
335 }
336 
IsBifrostGen2() const337 bool MaliInfo::IsBifrostGen2() const {
338   return gpu_version == MaliGpu::kG52 || gpu_version == MaliGpu::kG72;
339 }
340 
IsBifrostGen3() const341 bool MaliInfo::IsBifrostGen3() const { return gpu_version == MaliGpu::kG76; }
342 
IsBifrost() const343 bool MaliInfo::IsBifrost() const {
344   return IsBifrostGen1() || IsBifrostGen2() || IsBifrostGen3();
345 }
346 
IsValhall() const347 bool MaliInfo::IsValhall() const {
348   return gpu_version == MaliGpu::kG57 || gpu_version == MaliGpu::kG77 ||
349          gpu_version == MaliGpu::kG68 || gpu_version == MaliGpu::kG78;
350 }
351 
GetGpuInfoFromDeviceDescription(const std::string & gpu_description,GpuApi gpu_api,GpuInfo * gpu_info)352 void GetGpuInfoFromDeviceDescription(const std::string& gpu_description,
353                                      GpuApi gpu_api, GpuInfo* gpu_info) {
354   gpu_info->gpu_api = gpu_api;
355   std::string lowered = gpu_description;
356   absl::AsciiStrToLower(&lowered);
357   gpu_info->vendor = GetGpuVendor(lowered);
358   if (gpu_info->IsAdreno()) {
359     gpu_info->adreno_info = AdrenoInfo(lowered);
360   } else if (gpu_info->IsApple()) {
361     gpu_info->apple_info = AppleInfo(lowered);
362     gpu_info->supported_subgroup_sizes = {32};
363   } else if (gpu_info->IsMali()) {
364     gpu_info->mali_info = MaliInfo(lowered);
365   }
366 }
367 
OpenClVersionToString(OpenClVersion version)368 std::string OpenClVersionToString(OpenClVersion version) {
369   switch (version) {
370     case OpenClVersion::kCl1_0:
371       return "1.0";
372     case OpenClVersion::kCl1_1:
373       return "1.1";
374     case OpenClVersion::kCl1_2:
375       return "1.2";
376     case OpenClVersion::kCl2_0:
377       return "2.0";
378     case OpenClVersion::kCl2_1:
379       return "2.1";
380     case OpenClVersion::kCl2_2:
381       return "2.2";
382     case OpenClVersion::kCl3_0:
383       return "3.0";
384     default:
385       return "Unknown OpenCL version";
386   }
387 }
388 
IsImage2dFromBufferSupported() const389 bool OpenClInfo::IsImage2dFromBufferSupported() const {
390   if (image_pitch_alignment == 0) {
391     return false;
392   }
393   if (cl_version == OpenClVersion::kCl2_0 ||
394       cl_version == OpenClVersion::kCl2_1 ||
395       cl_version == OpenClVersion::kCl2_2) {
396     return true;
397   }
398   for (const auto& ext : extensions) {
399     if (ext == "cl_khr_image2d_from_buffer") {
400       return true;
401     }
402   }
403   return false;
404 }
405 
IsAdreno() const406 bool GpuInfo::IsAdreno() const { return vendor == GpuVendor::kQualcomm; }
407 
IsApple() const408 bool GpuInfo::IsApple() const { return vendor == GpuVendor::kApple; }
409 
IsMali() const410 bool GpuInfo::IsMali() const { return vendor == GpuVendor::kMali; }
411 
IsPowerVR() const412 bool GpuInfo::IsPowerVR() const { return vendor == GpuVendor::kPowerVR; }
413 
IsNvidia() const414 bool GpuInfo::IsNvidia() const { return vendor == GpuVendor::kNvidia; }
415 
IsAMD() const416 bool GpuInfo::IsAMD() const { return vendor == GpuVendor::kAMD; }
417 
IsIntel() const418 bool GpuInfo::IsIntel() const { return vendor == GpuVendor::kIntel; }
419 
IsRoundToNearestSupported() const420 bool GpuInfo::IsRoundToNearestSupported() const {
421   if (IsApiOpenCl()) {
422     return opencl_info.supports_fp16_rtn || opencl_info.supports_fp32_rtn;
423   }
424   if (IsApple()) {
425     return apple_info.IsRoundToNearestSupported();
426   }
427   if (IsAdreno()) {
428     if (adreno_info.IsAdreno1xx() || adreno_info.IsAdreno2xx() ||
429         adreno_info.IsAdreno3xx()) {
430       return false;
431     }
432   }
433   if (IsPowerVR()) {
434     return false;
435   }
436   return true;
437 }
438 
SupportsFP16() const439 bool GpuInfo::SupportsFP16() const {
440   if (IsApiOpenCl()) {
441     return opencl_info.supports_fp16;
442   }
443   return true;
444 }
445 
SupportsTextureArray() const446 bool GpuInfo::SupportsTextureArray() const {
447   if (!SupportsImages()) {
448     return false;
449   }
450   if (IsApiOpenCl()) {
451     return opencl_info.cl_version >= OpenClVersion::kCl1_2;
452   }
453   return true;
454 }
455 
SupportsImageBuffer() const456 bool GpuInfo::SupportsImageBuffer() const {
457   if (!SupportsImages()) {
458     return false;
459   }
460   if (IsApiOpenCl()) {
461     return opencl_info.cl_version >= OpenClVersion::kCl1_2;
462   }
463   return true;
464 }
465 
SupportsImage3D() const466 bool GpuInfo::SupportsImage3D() const {
467   if (!SupportsImages()) {
468     return false;
469   }
470   if (IsApiOpenCl()) {
471     if (IsMali() && mali_info.IsMidgard()) {
472       // On Mali T880 read_imageh doesn't compile with image3d_t
473       return false;
474     }
475     return opencl_info.supports_image3d_writes;
476   }
477   return true;
478 }
479 
SupportsImages() const480 bool GpuInfo::SupportsImages() const {
481   if (IsApiOpenCl()) {
482     return opencl_info.supports_images;
483   }
484   return true;
485 }
486 
SupportsPointersInKernels() const487 bool GpuInfo::SupportsPointersInKernels() const {
488   return IsApiOpenCl() || IsApiMetal();
489 }
490 
IsWaveSizeEqualTo32() const491 bool GpuInfo::IsWaveSizeEqualTo32() const {
492   return supported_subgroup_sizes.size() == 1 &&
493          supported_subgroup_sizes[0] == 32;
494 }
495 
SupportsExtension(const std::string & extension) const496 bool GpuInfo::SupportsExtension(const std::string& extension) const {
497   const std::vector<std::string>* extensions = nullptr;
498   if (IsApiOpenGl()) {
499     extensions = &opengl_info.extensions;
500   } else if (IsApiVulkan()) {
501     extensions = &vulkan_info.extensions;
502   } else if (IsApiOpenCl()) {
503     extensions = &opencl_info.extensions;
504   }
505   if (!extensions) {
506     return false;
507   }
508   for (const auto& ext : *extensions) {
509     if (ext == extension) {
510       return true;
511     }
512   }
513   return false;
514 }
515 
SupportsSubGroupWithSize(int sub_group_size) const516 bool GpuInfo::SupportsSubGroupWithSize(int sub_group_size) const {
517   for (auto subgroup_size : supported_subgroup_sizes) {
518     if (sub_group_size == subgroup_size) {
519       return true;
520     }
521   }
522   return false;
523 }
524 
SupportsFloatImage2D(DataType data_type,int channels) const525 bool GpuInfo::SupportsFloatImage2D(DataType data_type, int channels) const {
526   if (IsApiOpenCl()) {
527     if (channels == 1) {
528       return data_type == DataType::FLOAT32 ? opencl_info.supports_r_f32_tex2d
529                                             : opencl_info.supports_r_f16_tex2d;
530     } else if (channels == 2) {
531       return data_type == DataType::FLOAT32 ? opencl_info.supports_rg_f32_tex2d
532                                             : opencl_info.supports_rg_f16_tex2d;
533     } else if (channels == 3) {
534       return data_type == DataType::FLOAT32
535                  ? opencl_info.supports_rgb_f32_tex2d
536                  : opencl_info.supports_rgb_f16_tex2d;
537     } else if (channels == 4) {
538       return data_type == DataType::FLOAT32
539                  ? opencl_info.supports_rgba_f32_tex2d
540                  : opencl_info.supports_rgba_f16_tex2d;
541     } else {
542       return false;
543     }
544   }
545   return false;
546 }
547 
GetComputeUnitsCount() const548 int GpuInfo::GetComputeUnitsCount() const {
549   if (IsApiOpenCl()) {
550     return opencl_info.compute_units_count;
551   }
552   if (IsApple()) {
553     return apple_info.GetComputeUnitsCount();
554   }
555   if (IsAMD() && IsApiVulkan()) {
556     return amd_info.GetComputeUnitsCount();
557   }
558   return 1;
559 }
560 
GetMaxWorkGroupSizeForX() const561 int GpuInfo::GetMaxWorkGroupSizeForX() const {
562   if (IsApiOpenGl()) {
563     return opengl_info.max_compute_work_group_size_x;
564   }
565   if (IsApiVulkan()) {
566     return vulkan_info.max_compute_work_group_size_x;
567   }
568   if (IsApiOpenCl()) {
569     return opencl_info.max_work_group_size_x;
570   }
571   if (IsApiMetal()) {
572     return metal_info.max_work_group_size_x;
573   }
574   return 256;
575 }
576 
GetMaxWorkGroupSizeForY() const577 int GpuInfo::GetMaxWorkGroupSizeForY() const {
578   if (IsApiOpenGl()) {
579     return opengl_info.max_compute_work_group_size_y;
580   }
581   if (IsApiVulkan()) {
582     return vulkan_info.max_compute_work_group_size_y;
583   }
584   if (IsApiOpenCl()) {
585     return opencl_info.max_work_group_size_y;
586   }
587   if (IsApiMetal()) {
588     return metal_info.max_work_group_size_y;
589   }
590   return 256;
591 }
592 
GetMaxWorkGroupSizeForZ() const593 int GpuInfo::GetMaxWorkGroupSizeForZ() const {
594   if (IsApiOpenGl()) {
595     return opengl_info.max_compute_work_group_size_z;
596   }
597   if (IsApiVulkan()) {
598     return vulkan_info.max_compute_work_group_size_z;
599   }
600   if (IsApiOpenCl()) {
601     return opencl_info.max_work_group_size_z;
602   }
603   if (IsApiMetal()) {
604     return metal_info.max_work_group_size_z;
605   }
606   return 64;
607 }
608 
GetMaxWorkGroupTotalSize() const609 int GpuInfo::GetMaxWorkGroupTotalSize() const {
610   if (IsApiOpenGl()) {
611     return opengl_info.max_work_group_invocations;
612   }
613   if (IsApiVulkan()) {
614     return vulkan_info.max_compute_work_group_invocations;
615   }
616   if (IsApiOpenCl()) {
617     return opencl_info.max_work_group_total_size;
618   }
619   if (IsApiMetal()) {
620     int max_size = metal_info.max_work_group_size_x;
621     max_size = std::max(max_size, metal_info.max_work_group_size_y);
622     max_size = std::max(max_size, metal_info.max_work_group_size_z);
623     return max_size;
624   }
625   return 256;
626 }
627 
GetMaxImage2DWidth() const628 uint64_t GpuInfo::GetMaxImage2DWidth() const {
629   if (IsApiOpenGl()) {
630     return opengl_info.max_texture_size;
631   }
632   if (IsApiVulkan()) {
633     return vulkan_info.max_image_dimension_2d;
634   }
635   if (IsApiOpenCl()) {
636     return opencl_info.image2d_max_width;
637   }
638   return 2048;
639 }
640 
GetMaxImage2DHeight() const641 uint64_t GpuInfo::GetMaxImage2DHeight() const {
642   if (IsApiOpenGl()) {
643     return opengl_info.max_texture_size;
644   }
645   if (IsApiVulkan()) {
646     return vulkan_info.max_image_dimension_2d;
647   }
648   if (IsApiOpenCl()) {
649     return opencl_info.image2d_max_height;
650   }
651   return 2048;
652 }
653 
GetMaxImage2DArrayLayers() const654 uint64_t GpuInfo::GetMaxImage2DArrayLayers() const {
655   if (IsApiOpenGl()) {
656     return opengl_info.max_array_texture_layers;
657   }
658   if (IsApiVulkan()) {
659     return vulkan_info.max_image_array_layers;
660   }
661   if (IsApiOpenCl()) {
662     return opencl_info.image_array_max_layers;
663   }
664   return 256;
665 }
666 
GetMaxImage3DWidth() const667 uint64_t GpuInfo::GetMaxImage3DWidth() const {
668   if (IsApiOpenCl()) {
669     return opencl_info.image3d_max_width;
670   }
671   return 256;
672 }
673 
GetMaxImage3DHeight() const674 uint64_t GpuInfo::GetMaxImage3DHeight() const {
675   if (IsApiOpenCl()) {
676     return opencl_info.image3d_max_height;
677   }
678   return 256;
679 }
680 
GetMaxImage3DDepth() const681 uint64_t GpuInfo::GetMaxImage3DDepth() const {
682   if (IsApiOpenCl()) {
683     return opencl_info.image3d_max_depth;
684   }
685   return 256;
686 }
687 
GetMaxBufferSize() const688 uint64_t GpuInfo::GetMaxBufferSize() const {
689   if (IsApiOpenCl()) {
690     return opencl_info.buffer_max_size;
691   } else if (IsApiMetal()) {
692     return metal_info.buffer_max_size;
693   }
694   return 128 * 1024 * 1024;
695 }
696 
GetMaxMemoryAllocationSize() const697 uint64_t GpuInfo::GetMaxMemoryAllocationSize() const {
698   if (IsApiOpenCl()) {
699     return opencl_info.max_allocation_size;
700   } else if (IsApiMetal()) {
701     return metal_info.buffer_max_size;
702   }
703   return 128 * 1024 * 1024;
704 }
705 
GetMaxImageBufferWidth() const706 uint64_t GpuInfo::GetMaxImageBufferWidth() const {
707   if (IsApiOpenCl()) {
708     return opencl_info.image_buffer_max_size;
709   }
710   return 64 * 1024;
711 }
712 
GetMaxImageArguments() const713 int GpuInfo::GetMaxImageArguments() const {
714   if (IsApiOpenGl()) {
715     return opengl_info.max_image_units;
716   }
717   if (IsApiVulkan()) {
718     return vulkan_info.max_per_stage_descriptor_sampled_images;
719   }
720   if (IsApiMetal()) {
721     return 32;
722   }
723   if (IsApiOpenCl()) {
724     return 128;
725   }
726   return 1;
727 }
728 
IsApiOpenGl() const729 bool GpuInfo::IsApiOpenGl() const { return gpu_api == GpuApi::kOpenGl; }
730 
IsApiOpenGl31OrAbove() const731 bool GpuInfo::IsApiOpenGl31OrAbove() const {
732   if (!IsApiOpenGl()) {
733     return false;
734   }
735   return (opengl_info.major_version == 3 && opengl_info.minor_version >= 1) ||
736          opengl_info.major_version > 3;
737 }
738 
IsApiVulkan() const739 bool GpuInfo::IsApiVulkan() const { return gpu_api == GpuApi::kVulkan; }
740 
IsApiMetal() const741 bool GpuInfo::IsApiMetal() const { return gpu_api == GpuApi::kMetal; }
742 
IsApiOpenCl() const743 bool GpuInfo::IsApiOpenCl() const { return gpu_api == GpuApi::kOpenCl; }
744 
IsGlsl() const745 bool GpuInfo::IsGlsl() const { return IsApiOpenGl() || IsApiVulkan(); }
746 
IsCL11OrHigher() const747 bool GpuInfo::IsCL11OrHigher() const {
748   if (!IsApiOpenCl()) {
749     return false;
750   }
751   return opencl_info.cl_version != OpenClVersion::kCl1_0;
752 }
753 
IsCL20OrHigher() const754 bool GpuInfo::IsCL20OrHigher() const {
755   if (!IsApiOpenCl()) {
756     return false;
757   }
758   return opencl_info.cl_version != OpenClVersion::kCl1_0 &&
759          opencl_info.cl_version != OpenClVersion::kCl1_1 &&
760          opencl_info.cl_version != OpenClVersion::kCl1_2;
761 }
762 
IsCL30OrHigher() const763 bool GpuInfo::IsCL30OrHigher() const {
764   if (!IsApiOpenCl()) {
765     return false;
766   }
767   return IsCL20OrHigher() && opencl_info.cl_version != OpenClVersion::kCl2_0 &&
768          opencl_info.cl_version != OpenClVersion::kCl2_1 &&
769          opencl_info.cl_version != OpenClVersion::kCl2_2;
770 }
771 
772 }  // namespace gpu
773 }  // namespace tflite
774