1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
17
18 #include <map>
19 #include <string>
20
21 #include "absl/strings/ascii.h"
22
23 namespace tflite {
24 namespace gpu {
25 namespace {
26
GetGpuVendor(const std::string & gpu_description)27 GpuVendor GetGpuVendor(const std::string& gpu_description) {
28 const std::map<std::string, GpuVendor> kMapping = {
29 {"adreno", GpuVendor::kQualcomm},
30 {"apple", GpuVendor::kApple},
31 {"qualcomm", GpuVendor::kQualcomm},
32 {"mali", GpuVendor::kMali},
33 {"powervr", GpuVendor::kPowerVR},
34 {"advanced micro devices", GpuVendor::kAMD},
35 {"intel", GpuVendor::kIntel},
36 {"nvidia", GpuVendor::kNvidia},
37 {"amd", GpuVendor::kAMD},
38 {"power", GpuVendor::kPowerVR},
39 };
40 for (const auto& v : kMapping) {
41 if (gpu_description.find(v.first) != std::string::npos) {
42 return v.second;
43 }
44 }
45 return GpuVendor::kUnknown;
46 }
47
GetAdrenoGpuVersion(const std::string & gpu_description)48 AdrenoGpu GetAdrenoGpuVersion(const std::string& gpu_description) {
49 const std::map<std::string, AdrenoGpu> kMapping = {
50 // Adreno 6xx series
51 {"685", AdrenoGpu::kAdreno685},
52 {"680", AdrenoGpu::kAdreno680},
53 {"675", AdrenoGpu::kAdreno675},
54 {"650", AdrenoGpu::kAdreno650},
55 {"640", AdrenoGpu::kAdreno640},
56 {"630", AdrenoGpu::kAdreno630},
57 {"620", AdrenoGpu::kAdreno620},
58 {"618", AdrenoGpu::kAdreno618},
59 {"616", AdrenoGpu::kAdreno616},
60 {"615", AdrenoGpu::kAdreno615},
61 {"612", AdrenoGpu::kAdreno612},
62 {"610", AdrenoGpu::kAdreno610},
63 {"605", AdrenoGpu::kAdreno605},
64 // Adreno 5xx series
65 {"540", AdrenoGpu::kAdreno540},
66 {"530", AdrenoGpu::kAdreno530},
67 {"512", AdrenoGpu::kAdreno512},
68 {"510", AdrenoGpu::kAdreno510},
69 {"509", AdrenoGpu::kAdreno509},
70 {"508", AdrenoGpu::kAdreno508},
71 {"506", AdrenoGpu::kAdreno506},
72 {"505", AdrenoGpu::kAdreno505},
73 {"504", AdrenoGpu::kAdreno504},
74 // Adreno 4xx series
75 {"430", AdrenoGpu::kAdreno430},
76 {"420", AdrenoGpu::kAdreno420},
77 {"418", AdrenoGpu::kAdreno418},
78 {"405", AdrenoGpu::kAdreno405},
79 // Adreno 3xx series
80 {"330", AdrenoGpu::kAdreno330},
81 {"320", AdrenoGpu::kAdreno320},
82 {"308", AdrenoGpu::kAdreno308},
83 {"306", AdrenoGpu::kAdreno306},
84 {"305", AdrenoGpu::kAdreno305},
85 {"304", AdrenoGpu::kAdreno304},
86 // Adreno 2xx series
87 {"225", AdrenoGpu::kAdreno225},
88 {"220", AdrenoGpu::kAdreno220},
89 {"205", AdrenoGpu::kAdreno205},
90 {"203", AdrenoGpu::kAdreno203},
91 {"200", AdrenoGpu::kAdreno200},
92 // Adreno 1xx series
93 {"130", AdrenoGpu::kAdreno130},
94 {"120", AdrenoGpu::kAdreno120},
95 };
96
97 for (const auto& v : kMapping) {
98 if (gpu_description.find(v.first) != std::string::npos) {
99 return v.second;
100 }
101 }
102 return AdrenoGpu::kUnknown;
103 }
104
GetMaliGpuVersion(const std::string & gpu_description)105 MaliGpu GetMaliGpuVersion(const std::string& gpu_description) {
106 const std::map<std::string, MaliGpu> kMapping = {
107 {"t604", MaliGpu::kT604}, {"t622", MaliGpu::kT622},
108 {"t624", MaliGpu::kT624}, {"t628", MaliGpu::kT628},
109 {"t658", MaliGpu::kT658}, {"t678", MaliGpu::kT678},
110 {"t720", MaliGpu::kT720}, {"t760", MaliGpu::kT760},
111 {"t820", MaliGpu::kT820}, {"t830", MaliGpu::kT830},
112 {"t860", MaliGpu::kT860}, {"t880", MaliGpu::kT880},
113 {"g31", MaliGpu::kG31}, {"g51", MaliGpu::kG51},
114 {"g71", MaliGpu::kG71}, {"g52", MaliGpu::kG52},
115 {"g72", MaliGpu::kG72}, {"g76", MaliGpu::kG76},
116 {"g57", MaliGpu::kG57}, {"g77", MaliGpu::kG77},
117 {"g68", MaliGpu::kG68}, {"g78", MaliGpu::kG78},
118 };
119 for (const auto& v : kMapping) {
120 if (gpu_description.find(v.first) != std::string::npos) {
121 return v.second;
122 }
123 }
124 return MaliGpu::kUnknown;
125 }
126
127 } // namespace
128
AdrenoInfo(const std::string & device_version)129 AdrenoInfo::AdrenoInfo(const std::string& device_version)
130 : adreno_gpu(GetAdrenoGpuVersion(device_version)) {}
131
IsAdreno1xx() const132 bool AdrenoInfo::IsAdreno1xx() const {
133 return adreno_gpu == AdrenoGpu::kAdreno120 ||
134 adreno_gpu == AdrenoGpu::kAdreno130;
135 }
136
IsAdreno2xx() const137 bool AdrenoInfo::IsAdreno2xx() const {
138 return adreno_gpu == AdrenoGpu::kAdreno200 ||
139 adreno_gpu == AdrenoGpu::kAdreno203 ||
140 adreno_gpu == AdrenoGpu::kAdreno205 ||
141 adreno_gpu == AdrenoGpu::kAdreno220 ||
142 adreno_gpu == AdrenoGpu::kAdreno225;
143 }
144
IsAdreno3xx() const145 bool AdrenoInfo::IsAdreno3xx() const {
146 return adreno_gpu == AdrenoGpu::kAdreno304 ||
147 adreno_gpu == AdrenoGpu::kAdreno305 ||
148 adreno_gpu == AdrenoGpu::kAdreno306 ||
149 adreno_gpu == AdrenoGpu::kAdreno308 ||
150 adreno_gpu == AdrenoGpu::kAdreno320 ||
151 adreno_gpu == AdrenoGpu::kAdreno330;
152 }
153
IsAdreno4xx() const154 bool AdrenoInfo::IsAdreno4xx() const {
155 return adreno_gpu == AdrenoGpu::kAdreno405 ||
156 adreno_gpu == AdrenoGpu::kAdreno418 ||
157 adreno_gpu == AdrenoGpu::kAdreno420 ||
158 adreno_gpu == AdrenoGpu::kAdreno430;
159 }
160
IsAdreno5xx() const161 bool AdrenoInfo::IsAdreno5xx() const {
162 return adreno_gpu == AdrenoGpu::kAdreno504 ||
163 adreno_gpu == AdrenoGpu::kAdreno505 ||
164 adreno_gpu == AdrenoGpu::kAdreno506 ||
165 adreno_gpu == AdrenoGpu::kAdreno508 ||
166 adreno_gpu == AdrenoGpu::kAdreno509 ||
167 adreno_gpu == AdrenoGpu::kAdreno510 ||
168 adreno_gpu == AdrenoGpu::kAdreno512 ||
169 adreno_gpu == AdrenoGpu::kAdreno530 ||
170 adreno_gpu == AdrenoGpu::kAdreno540;
171 }
172
IsAdreno6xx() const173 bool AdrenoInfo::IsAdreno6xx() const {
174 return adreno_gpu == AdrenoGpu::kAdreno605 ||
175 adreno_gpu == AdrenoGpu::kAdreno610 ||
176 adreno_gpu == AdrenoGpu::kAdreno612 ||
177 adreno_gpu == AdrenoGpu::kAdreno615 ||
178 adreno_gpu == AdrenoGpu::kAdreno616 ||
179 adreno_gpu == AdrenoGpu::kAdreno618 ||
180 adreno_gpu == AdrenoGpu::kAdreno620 ||
181 adreno_gpu == AdrenoGpu::kAdreno630 ||
182 adreno_gpu == AdrenoGpu::kAdreno640 ||
183 adreno_gpu == AdrenoGpu::kAdreno650 ||
184 adreno_gpu == AdrenoGpu::kAdreno675 ||
185 adreno_gpu == AdrenoGpu::kAdreno680 ||
186 adreno_gpu == AdrenoGpu::kAdreno685;
187 }
188
IsAdreno6xxOrHigher() const189 bool AdrenoInfo::IsAdreno6xxOrHigher() const {
190 return !compiler_bugs_in_a6xx && IsAdreno6xx();
191 }
192
GetMaximumWavesCount() const193 int AdrenoInfo::GetMaximumWavesCount() const {
194 if (IsAdreno6xx()) {
195 if (adreno_gpu == AdrenoGpu::kAdreno640) {
196 return 30;
197 } else {
198 return 16;
199 }
200 } else {
201 // all other versions not supported
202 return 1;
203 }
204 }
205
GetRegisterMemorySizePerComputeUnit() const206 int AdrenoInfo::GetRegisterMemorySizePerComputeUnit() const {
207 if (IsAdreno6xx()) {
208 if (adreno_gpu == AdrenoGpu::kAdreno640) {
209 return 128 * 144 * 16;
210 } else if (adreno_gpu == AdrenoGpu::kAdreno650 ||
211 adreno_gpu == AdrenoGpu::kAdreno620) {
212 return 128 * 64 * 16;
213 } else {
214 return 128 * 96 * 16;
215 }
216 } else {
217 // all other versions not supported
218 return 1;
219 }
220 }
221
GetMaximumWavesCount(int register_footprint_per_tread,bool full_wave) const222 int AdrenoInfo::GetMaximumWavesCount(int register_footprint_per_tread,
223 bool full_wave) const {
224 const int register_usage_per_wave =
225 GetWaveSize(full_wave) * register_footprint_per_tread;
226 const int possible_waves_count =
227 GetRegisterMemorySizePerComputeUnit() / register_usage_per_wave;
228 return std::min(possible_waves_count, GetMaximumWavesCount());
229 }
230
GetWaveSize(bool full_wave) const231 int AdrenoInfo::GetWaveSize(bool full_wave) const {
232 if (IsAdreno6xx()) {
233 return full_wave ? 128 : 64;
234 } else if (IsAdreno5xx() || IsAdreno4xx()) {
235 return full_wave ? 64 : 32;
236 } else {
237 // all other versions not supported
238 return 1;
239 }
240 }
241
AppleInfo(const std::string & gpu_description)242 AppleInfo::AppleInfo(const std::string& gpu_description) {
243 const std::map<std::string, AppleGpu> kMapping = {
244 {"apple a7 gpu", AppleGpu::kA7}, {"apple a8 gpu", AppleGpu::kA8},
245 {"apple a8x gpu", AppleGpu::kA8X}, {"apple a9 gpu", AppleGpu::kA9},
246 {"apple a9x gpu", AppleGpu::kA9X}, {"apple a10 gpu", AppleGpu::kA10},
247 {"apple a10x gpu", AppleGpu::kA10X}, {"apple a11 gpu", AppleGpu::kA11},
248 {"apple a12 gpu", AppleGpu::kA12}, {"apple a12x gpu", AppleGpu::kA12X},
249 {"apple a12z gpu", AppleGpu::kA12Z}, {"apple a13 gpu", AppleGpu::kA13},
250 {"apple a14 gpu", AppleGpu::kA14},
251 };
252 auto it = kMapping.find(gpu_description);
253 if (it != kMapping.end()) {
254 gpu_type = it->second;
255 } else {
256 gpu_type = AppleGpu::kUnknown;
257 }
258 }
259
IsLocalMemoryPreferredOverGlobal() const260 bool AppleInfo::IsLocalMemoryPreferredOverGlobal() const {
261 return gpu_type == AppleGpu::kA7 || gpu_type == AppleGpu::kA8 ||
262 gpu_type == AppleGpu::kA8X;
263 }
264
IsBionic() const265 bool AppleInfo::IsBionic() const {
266 return gpu_type == AppleGpu::kA11 || gpu_type == AppleGpu::kA12 ||
267 gpu_type == AppleGpu::kA12X || gpu_type == AppleGpu::kA12Z ||
268 gpu_type == AppleGpu::kA13 || gpu_type == AppleGpu::kA14;
269 }
270
IsRoundToNearestSupported() const271 bool AppleInfo::IsRoundToNearestSupported() const { return IsBionic(); }
272
GetComputeUnitsCount() const273 int AppleInfo::GetComputeUnitsCount() const {
274 switch (gpu_type) {
275 case AppleGpu::kA7:
276 return 4;
277 case AppleGpu::kA8:
278 return 4;
279 case AppleGpu::kA8X:
280 return 8;
281 case AppleGpu::kA9:
282 return 6;
283 case AppleGpu::kA9X:
284 return 12;
285 case AppleGpu::kA10:
286 return 6;
287 case AppleGpu::kA10X:
288 return 12;
289 case AppleGpu::kA11:
290 return 3;
291 case AppleGpu::kA12:
292 return 4;
293 case AppleGpu::kA12X:
294 return 7;
295 case AppleGpu::kA12Z:
296 return 8;
297 case AppleGpu::kA13:
298 return 4;
299 case AppleGpu::kA14:
300 return 4;
301 case AppleGpu::kUnknown:
302 return 1;
303 }
304 }
305
MaliInfo(const std::string & gpu_description)306 MaliInfo::MaliInfo(const std::string& gpu_description)
307 : gpu_version(GetMaliGpuVersion(gpu_description)) {}
308
IsMaliT6xx() const309 bool MaliInfo::IsMaliT6xx() const {
310 return gpu_version == MaliGpu::kT604 || gpu_version == MaliGpu::kT622 ||
311 gpu_version == MaliGpu::kT624 || gpu_version == MaliGpu::kT628 ||
312 gpu_version == MaliGpu::kT658 || gpu_version == MaliGpu::kT678;
313 }
314
IsMaliT7xx() const315 bool MaliInfo::IsMaliT7xx() const {
316 return gpu_version == MaliGpu::kT720 || gpu_version == MaliGpu::kT760;
317 }
318
IsMaliT8xx() const319 bool MaliInfo::IsMaliT8xx() const {
320 return gpu_version == MaliGpu::kT820 || gpu_version == MaliGpu::kT830 ||
321 gpu_version == MaliGpu::kT860 || gpu_version == MaliGpu::kT880;
322 }
323
IsMidgard() const324 bool MaliInfo::IsMidgard() const {
325 return IsMaliT6xx() || IsMaliT7xx() || IsMaliT8xx();
326 }
327
IsBifrostGen1() const328 bool MaliInfo::IsBifrostGen1() const {
329 return gpu_version == MaliGpu::kG31 || gpu_version == MaliGpu::kG51 ||
330 gpu_version == MaliGpu::kG71;
331 }
332
IsBifrostGen2() const333 bool MaliInfo::IsBifrostGen2() const {
334 return gpu_version == MaliGpu::kG52 || gpu_version == MaliGpu::kG72;
335 }
336
IsBifrostGen3() const337 bool MaliInfo::IsBifrostGen3() const { return gpu_version == MaliGpu::kG76; }
338
IsBifrost() const339 bool MaliInfo::IsBifrost() const {
340 return IsBifrostGen1() || IsBifrostGen2() || IsBifrostGen3();
341 }
342
IsValhall() const343 bool MaliInfo::IsValhall() const {
344 return gpu_version == MaliGpu::kG57 || gpu_version == MaliGpu::kG77 ||
345 gpu_version == MaliGpu::kG68 || gpu_version == MaliGpu::kG78;
346 }
347
GetGpuInfoFromDeviceDescription(const std::string & gpu_description,GpuApi gpu_api,GpuInfo * gpu_info)348 void GetGpuInfoFromDeviceDescription(const std::string& gpu_description,
349 GpuApi gpu_api, GpuInfo* gpu_info) {
350 gpu_info->gpu_api = gpu_api;
351 std::string lowered = gpu_description;
352 absl::AsciiStrToLower(&lowered);
353 gpu_info->vendor = GetGpuVendor(lowered);
354 if (gpu_info->IsAdreno()) {
355 gpu_info->adreno_info = AdrenoInfo(lowered);
356 } else if (gpu_info->IsApple()) {
357 gpu_info->apple_info = AppleInfo(lowered);
358 gpu_info->supported_subgroup_sizes = {32};
359 } else if (gpu_info->IsMali()) {
360 gpu_info->mali_info = MaliInfo(lowered);
361 }
362 }
363
OpenClVersionToString(OpenClVersion version)364 std::string OpenClVersionToString(OpenClVersion version) {
365 switch (version) {
366 case OpenClVersion::kCl1_0:
367 return "1.0";
368 case OpenClVersion::kCl1_1:
369 return "1.1";
370 case OpenClVersion::kCl1_2:
371 return "1.2";
372 case OpenClVersion::kCl2_0:
373 return "2.0";
374 case OpenClVersion::kCl2_1:
375 return "2.1";
376 case OpenClVersion::kCl2_2:
377 return "2.2";
378 case OpenClVersion::kCl3_0:
379 return "3.0";
380 default:
381 return "Unknown OpenCL version";
382 }
383 }
384
IsAdreno() const385 bool GpuInfo::IsAdreno() const { return vendor == GpuVendor::kQualcomm; }
386
IsApple() const387 bool GpuInfo::IsApple() const { return vendor == GpuVendor::kApple; }
388
IsMali() const389 bool GpuInfo::IsMali() const { return vendor == GpuVendor::kMali; }
390
IsPowerVR() const391 bool GpuInfo::IsPowerVR() const { return vendor == GpuVendor::kPowerVR; }
392
IsNvidia() const393 bool GpuInfo::IsNvidia() const { return vendor == GpuVendor::kNvidia; }
394
IsAMD() const395 bool GpuInfo::IsAMD() const { return vendor == GpuVendor::kAMD; }
396
IsIntel() const397 bool GpuInfo::IsIntel() const { return vendor == GpuVendor::kIntel; }
398
IsRoundToNearestSupported() const399 bool GpuInfo::IsRoundToNearestSupported() const {
400 if (IsApiOpenCl()) {
401 return opencl_info.supports_fp16_rtn || opencl_info.supports_fp32_rtn;
402 }
403 if (IsApple()) {
404 return apple_info.IsRoundToNearestSupported();
405 }
406 return true;
407 }
408
SupportsFP16() const409 bool GpuInfo::SupportsFP16() const {
410 if (IsApiOpenCl()) {
411 return opencl_info.supports_fp16;
412 }
413 return true;
414 }
415
SupportsTextureArray() const416 bool GpuInfo::SupportsTextureArray() const {
417 if (!SupportsImages()) {
418 return false;
419 }
420 if (IsApiOpenCl()) {
421 return opencl_info.cl_version >= OpenClVersion::kCl1_2;
422 }
423 return true;
424 }
425
SupportsImageBuffer() const426 bool GpuInfo::SupportsImageBuffer() const {
427 if (!SupportsImages()) {
428 return false;
429 }
430 if (IsApiOpenCl()) {
431 return opencl_info.cl_version >= OpenClVersion::kCl1_2;
432 }
433 return true;
434 }
435
SupportsImage3D() const436 bool GpuInfo::SupportsImage3D() const {
437 if (!SupportsImages()) {
438 return false;
439 }
440 if (IsApiOpenCl()) {
441 if (IsMali() && mali_info.IsMidgard()) {
442 // On Mali T880 read_imageh doesn't compile with image3d_t
443 return false;
444 }
445 return opencl_info.supports_image3d_writes;
446 }
447 return true;
448 }
449
SupportsImages() const450 bool GpuInfo::SupportsImages() const {
451 if (IsApiOpenCl()) {
452 return opencl_info.supports_images;
453 }
454 return true;
455 }
456
IsWaveSizeEqualTo32() const457 bool GpuInfo::IsWaveSizeEqualTo32() const {
458 return supported_subgroup_sizes.size() == 1 &&
459 supported_subgroup_sizes[0] == 32;
460 }
461
SupportsExtension(const std::string & extension) const462 bool GpuInfo::SupportsExtension(const std::string& extension) const {
463 const std::vector<std::string>* extensions = nullptr;
464 if (IsApiOpenGl()) {
465 extensions = &opengl_info.extensions;
466 } else if (IsApiVulkan()) {
467 extensions = &vulkan_info.extensions;
468 } else if (IsApiOpenCl()) {
469 extensions = &opencl_info.extensions;
470 }
471 if (!extensions) {
472 return false;
473 }
474 for (const auto& ext : *extensions) {
475 if (ext == extension) {
476 return true;
477 }
478 }
479 return false;
480 }
481
SupportsSubGroupWithSize(int sub_group_size) const482 bool GpuInfo::SupportsSubGroupWithSize(int sub_group_size) const {
483 for (auto subgroup_size : supported_subgroup_sizes) {
484 if (sub_group_size == subgroup_size) {
485 return true;
486 }
487 }
488 return false;
489 }
490
SupportsFloatImage2D(DataType data_type,int channels) const491 bool GpuInfo::SupportsFloatImage2D(DataType data_type, int channels) const {
492 if (IsApiOpenCl()) {
493 if (channels == 1) {
494 return data_type == DataType::FLOAT32 ? opencl_info.supports_r_f32_tex2d
495 : opencl_info.supports_r_f16_tex2d;
496 } else if (channels == 2) {
497 return data_type == DataType::FLOAT32 ? opencl_info.supports_rg_f32_tex2d
498 : opencl_info.supports_rg_f16_tex2d;
499 } else if (channels == 3) {
500 return data_type == DataType::FLOAT32
501 ? opencl_info.supports_rgb_f32_tex2d
502 : opencl_info.supports_rgb_f16_tex2d;
503 } else if (channels == 4) {
504 return data_type == DataType::FLOAT32
505 ? opencl_info.supports_rgba_f32_tex2d
506 : opencl_info.supports_rgba_f16_tex2d;
507 } else {
508 return false;
509 }
510 }
511 return false;
512 }
513
GetComputeUnitsCount() const514 int GpuInfo::GetComputeUnitsCount() const {
515 if (IsApiOpenCl()) {
516 return opencl_info.compute_units_count;
517 }
518 if (IsApple()) {
519 return apple_info.GetComputeUnitsCount();
520 }
521 return 1;
522 }
523
GetMaxWorkGroupSizeForX() const524 int GpuInfo::GetMaxWorkGroupSizeForX() const {
525 if (IsApiOpenGl()) {
526 return opengl_info.max_compute_work_group_size_x;
527 }
528 if (IsApiVulkan()) {
529 return vulkan_info.max_compute_work_group_size_x;
530 }
531 if (IsApiOpenCl()) {
532 return opencl_info.max_work_group_size_x;
533 }
534 if (IsApiMetal()) {
535 return metal_info.max_work_group_size_x;
536 }
537 return 256;
538 }
539
GetMaxWorkGroupSizeForY() const540 int GpuInfo::GetMaxWorkGroupSizeForY() const {
541 if (IsApiOpenGl()) {
542 return opengl_info.max_compute_work_group_size_y;
543 }
544 if (IsApiVulkan()) {
545 return vulkan_info.max_compute_work_group_size_y;
546 }
547 if (IsApiOpenCl()) {
548 return opencl_info.max_work_group_size_y;
549 }
550 if (IsApiMetal()) {
551 return metal_info.max_work_group_size_y;
552 }
553 return 256;
554 }
555
GetMaxWorkGroupSizeForZ() const556 int GpuInfo::GetMaxWorkGroupSizeForZ() const {
557 if (IsApiOpenGl()) {
558 return opengl_info.max_compute_work_group_size_z;
559 }
560 if (IsApiVulkan()) {
561 return vulkan_info.max_compute_work_group_size_z;
562 }
563 if (IsApiOpenCl()) {
564 return opencl_info.max_work_group_size_z;
565 }
566 if (IsApiMetal()) {
567 return metal_info.max_work_group_size_z;
568 }
569 return 64;
570 }
571
GetMaxWorkGroupTotalSize() const572 int GpuInfo::GetMaxWorkGroupTotalSize() const {
573 if (IsApiOpenGl()) {
574 return opengl_info.max_work_group_invocations;
575 }
576 if (IsApiVulkan()) {
577 return vulkan_info.max_compute_work_group_invocations;
578 }
579 if (IsApiOpenCl()) {
580 return opencl_info.max_work_group_total_size;
581 }
582 if (IsApiMetal()) {
583 int max_size = metal_info.max_work_group_size_x;
584 max_size = std::max(max_size, metal_info.max_work_group_size_y);
585 max_size = std::max(max_size, metal_info.max_work_group_size_z);
586 return max_size;
587 }
588 return 256;
589 }
590
GetMaxImage2DWidth() const591 uint64_t GpuInfo::GetMaxImage2DWidth() const {
592 if (IsApiOpenGl()) {
593 return opengl_info.max_texture_size;
594 }
595 if (IsApiVulkan()) {
596 return vulkan_info.max_image_dimension_2d;
597 }
598 if (IsApiOpenCl()) {
599 return opencl_info.image2d_max_width;
600 }
601 return 2048;
602 }
603
GetMaxImage2DHeight() const604 uint64_t GpuInfo::GetMaxImage2DHeight() const {
605 if (IsApiOpenGl()) {
606 return opengl_info.max_texture_size;
607 }
608 if (IsApiVulkan()) {
609 return vulkan_info.max_image_dimension_2d;
610 }
611 if (IsApiOpenCl()) {
612 return opencl_info.image2d_max_height;
613 }
614 return 2048;
615 }
616
GetMaxImage2DArrayLayers() const617 uint64_t GpuInfo::GetMaxImage2DArrayLayers() const {
618 if (IsApiOpenGl()) {
619 return opengl_info.max_array_texture_layers;
620 }
621 if (IsApiVulkan()) {
622 return vulkan_info.max_image_array_layers;
623 }
624 if (IsApiOpenCl()) {
625 return opencl_info.image_array_max_layers;
626 }
627 return 256;
628 }
629
GetMaxImage3DWidth() const630 uint64_t GpuInfo::GetMaxImage3DWidth() const {
631 if (IsApiOpenCl()) {
632 return opencl_info.image3d_max_width;
633 }
634 return 256;
635 }
636
GetMaxImage3DHeight() const637 uint64_t GpuInfo::GetMaxImage3DHeight() const {
638 if (IsApiOpenCl()) {
639 return opencl_info.image3d_max_height;
640 }
641 return 256;
642 }
643
GetMaxImage3DDepth() const644 uint64_t GpuInfo::GetMaxImage3DDepth() const {
645 if (IsApiOpenCl()) {
646 return opencl_info.image3d_max_depth;
647 }
648 return 256;
649 }
650
GetMaxBufferSize() const651 uint64_t GpuInfo::GetMaxBufferSize() const {
652 if (IsApiOpenCl()) {
653 return opencl_info.buffer_max_size;
654 } else if (IsApiMetal()) {
655 return metal_info.buffer_max_size;
656 }
657 return 128 * 1024 * 1024;
658 }
659
GetMaxImageBufferWidth() const660 uint64_t GpuInfo::GetMaxImageBufferWidth() const {
661 if (IsApiOpenCl()) {
662 return opencl_info.image_buffer_max_size;
663 }
664 return 64 * 1024;
665 }
666
GetMaxImageArguments() const667 int GpuInfo::GetMaxImageArguments() const {
668 if (IsApiOpenGl()) {
669 return opengl_info.max_image_units;
670 }
671 if (IsApiVulkan()) {
672 return vulkan_info.max_per_stage_descriptor_sampled_images;
673 }
674 if (IsApiMetal()) {
675 return 32;
676 }
677 if (IsApiOpenCl()) {
678 return 128;
679 }
680 return 1;
681 }
682
IsApiOpenGl() const683 bool GpuInfo::IsApiOpenGl() const { return gpu_api == GpuApi::kOpenGl; }
684
IsApiOpenGl31OrAbove() const685 bool GpuInfo::IsApiOpenGl31OrAbove() const {
686 if (!IsApiOpenGl()) {
687 return false;
688 }
689 return (opengl_info.major_version == 3 && opengl_info.minor_version >= 1) ||
690 opengl_info.major_version > 3;
691 }
692
IsApiVulkan() const693 bool GpuInfo::IsApiVulkan() const { return gpu_api == GpuApi::kVulkan; }
694
IsApiMetal() const695 bool GpuInfo::IsApiMetal() const { return gpu_api == GpuApi::kMetal; }
696
IsApiOpenCl() const697 bool GpuInfo::IsApiOpenCl() const { return gpu_api == GpuApi::kOpenCl; }
698
IsCL20OrHigher() const699 bool GpuInfo::IsCL20OrHigher() const {
700 if (!IsApiOpenCl()) {
701 return false;
702 }
703 return opencl_info.cl_version != OpenClVersion::kCl1_0 &&
704 opencl_info.cl_version != OpenClVersion::kCl1_1 &&
705 opencl_info.cl_version != OpenClVersion::kCl1_2;
706 }
707
IsCL30OrHigher() const708 bool GpuInfo::IsCL30OrHigher() const {
709 if (!IsApiOpenCl()) {
710 return false;
711 }
712 return IsCL20OrHigher() && opencl_info.cl_version != OpenClVersion::kCl2_0 &&
713 opencl_info.cl_version != OpenClVersion::kCl2_1 &&
714 opencl_info.cl_version != OpenClVersion::kCl2_2;
715 }
716
717 } // namespace gpu
718 } // namespace tflite
719