1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
17
18 #include <algorithm>
19 #include <map>
20 #include <string>
21
22 #include "absl/strings/ascii.h"
23
24 namespace tflite {
25 namespace gpu {
26 namespace {
27
GetGpuVendor(const std::string & gpu_description)28 GpuVendor GetGpuVendor(const std::string& gpu_description) {
29 const std::map<std::string, GpuVendor> kMapping = {
30 {"adreno", GpuVendor::kQualcomm},
31 {"apple", GpuVendor::kApple},
32 {"qualcomm", GpuVendor::kQualcomm},
33 {"mali", GpuVendor::kMali},
34 {"powervr", GpuVendor::kPowerVR},
35 {"advanced micro devices", GpuVendor::kAMD},
36 {"intel", GpuVendor::kIntel},
37 {"nvidia", GpuVendor::kNvidia},
38 {"amd", GpuVendor::kAMD},
39 {"power", GpuVendor::kPowerVR},
40 };
41 for (const auto& v : kMapping) {
42 if (gpu_description.find(v.first) != std::string::npos) {
43 return v.second;
44 }
45 }
46 return GpuVendor::kUnknown;
47 }
48
GetAdrenoGpuVersion(const std::string & gpu_description)49 AdrenoGpu GetAdrenoGpuVersion(const std::string& gpu_description) {
50 const std::map<std::string, AdrenoGpu> kMapping = {
51 // Adreno 6xx series
52 {"685", AdrenoGpu::kAdreno685},
53 {"680", AdrenoGpu::kAdreno680},
54 {"675", AdrenoGpu::kAdreno675},
55 {"660", AdrenoGpu::kAdreno660},
56 {"650", AdrenoGpu::kAdreno650},
57 {"640", AdrenoGpu::kAdreno640},
58 {"630", AdrenoGpu::kAdreno630},
59 {"620", AdrenoGpu::kAdreno620},
60 {"618", AdrenoGpu::kAdreno618},
61 {"616", AdrenoGpu::kAdreno616},
62 {"615", AdrenoGpu::kAdreno615},
63 {"612", AdrenoGpu::kAdreno612},
64 {"610", AdrenoGpu::kAdreno610},
65 {"605", AdrenoGpu::kAdreno605},
66 // Adreno 5xx series
67 {"540", AdrenoGpu::kAdreno540},
68 {"530", AdrenoGpu::kAdreno530},
69 {"512", AdrenoGpu::kAdreno512},
70 {"510", AdrenoGpu::kAdreno510},
71 {"509", AdrenoGpu::kAdreno509},
72 {"508", AdrenoGpu::kAdreno508},
73 {"506", AdrenoGpu::kAdreno506},
74 {"505", AdrenoGpu::kAdreno505},
75 {"504", AdrenoGpu::kAdreno504},
76 // Adreno 4xx series
77 {"430", AdrenoGpu::kAdreno430},
78 {"420", AdrenoGpu::kAdreno420},
79 {"418", AdrenoGpu::kAdreno418},
80 {"405", AdrenoGpu::kAdreno405},
81 // Adreno 3xx series
82 {"330", AdrenoGpu::kAdreno330},
83 {"320", AdrenoGpu::kAdreno320},
84 {"308", AdrenoGpu::kAdreno308},
85 {"306", AdrenoGpu::kAdreno306},
86 {"305", AdrenoGpu::kAdreno305},
87 {"304", AdrenoGpu::kAdreno304},
88 // Adreno 2xx series
89 {"225", AdrenoGpu::kAdreno225},
90 {"220", AdrenoGpu::kAdreno220},
91 {"205", AdrenoGpu::kAdreno205},
92 {"203", AdrenoGpu::kAdreno203},
93 {"200", AdrenoGpu::kAdreno200},
94 // Adreno 1xx series
95 {"130", AdrenoGpu::kAdreno130},
96 {"120", AdrenoGpu::kAdreno120},
97 };
98
99 for (const auto& v : kMapping) {
100 if (gpu_description.find(v.first) != std::string::npos) {
101 return v.second;
102 }
103 }
104 return AdrenoGpu::kUnknown;
105 }
106
GetMaliGpuVersion(const std::string & gpu_description)107 MaliGpu GetMaliGpuVersion(const std::string& gpu_description) {
108 const std::map<std::string, MaliGpu> kMapping = {
109 {"t604", MaliGpu::kT604}, {"t622", MaliGpu::kT622},
110 {"t624", MaliGpu::kT624}, {"t628", MaliGpu::kT628},
111 {"t658", MaliGpu::kT658}, {"t678", MaliGpu::kT678},
112 {"t720", MaliGpu::kT720}, {"t760", MaliGpu::kT760},
113 {"t820", MaliGpu::kT820}, {"t830", MaliGpu::kT830},
114 {"t860", MaliGpu::kT860}, {"t880", MaliGpu::kT880},
115 {"g31", MaliGpu::kG31}, {"g51", MaliGpu::kG51},
116 {"g71", MaliGpu::kG71}, {"g52", MaliGpu::kG52},
117 {"g72", MaliGpu::kG72}, {"g76", MaliGpu::kG76},
118 {"g57", MaliGpu::kG57}, {"g77", MaliGpu::kG77},
119 {"g68", MaliGpu::kG68}, {"g78", MaliGpu::kG78},
120 };
121 for (const auto& v : kMapping) {
122 if (gpu_description.find(v.first) != std::string::npos) {
123 return v.second;
124 }
125 }
126 return MaliGpu::kUnknown;
127 }
128
129 } // namespace
130
AdrenoInfo(const std::string & device_version)131 AdrenoInfo::AdrenoInfo(const std::string& device_version)
132 : adreno_gpu(GetAdrenoGpuVersion(device_version)) {}
133
IsAdreno1xx() const134 bool AdrenoInfo::IsAdreno1xx() const {
135 return adreno_gpu == AdrenoGpu::kAdreno120 ||
136 adreno_gpu == AdrenoGpu::kAdreno130;
137 }
138
IsAdreno2xx() const139 bool AdrenoInfo::IsAdreno2xx() const {
140 return adreno_gpu == AdrenoGpu::kAdreno200 ||
141 adreno_gpu == AdrenoGpu::kAdreno203 ||
142 adreno_gpu == AdrenoGpu::kAdreno205 ||
143 adreno_gpu == AdrenoGpu::kAdreno220 ||
144 adreno_gpu == AdrenoGpu::kAdreno225;
145 }
146
IsAdreno3xx() const147 bool AdrenoInfo::IsAdreno3xx() const {
148 return adreno_gpu == AdrenoGpu::kAdreno304 ||
149 adreno_gpu == AdrenoGpu::kAdreno305 ||
150 adreno_gpu == AdrenoGpu::kAdreno306 ||
151 adreno_gpu == AdrenoGpu::kAdreno308 ||
152 adreno_gpu == AdrenoGpu::kAdreno320 ||
153 adreno_gpu == AdrenoGpu::kAdreno330;
154 }
155
IsAdreno4xx() const156 bool AdrenoInfo::IsAdreno4xx() const {
157 return adreno_gpu == AdrenoGpu::kAdreno405 ||
158 adreno_gpu == AdrenoGpu::kAdreno418 ||
159 adreno_gpu == AdrenoGpu::kAdreno420 ||
160 adreno_gpu == AdrenoGpu::kAdreno430;
161 }
162
IsAdreno5xx() const163 bool AdrenoInfo::IsAdreno5xx() const {
164 return adreno_gpu == AdrenoGpu::kAdreno504 ||
165 adreno_gpu == AdrenoGpu::kAdreno505 ||
166 adreno_gpu == AdrenoGpu::kAdreno506 ||
167 adreno_gpu == AdrenoGpu::kAdreno508 ||
168 adreno_gpu == AdrenoGpu::kAdreno509 ||
169 adreno_gpu == AdrenoGpu::kAdreno510 ||
170 adreno_gpu == AdrenoGpu::kAdreno512 ||
171 adreno_gpu == AdrenoGpu::kAdreno530 ||
172 adreno_gpu == AdrenoGpu::kAdreno540;
173 }
174
IsAdreno6xx() const175 bool AdrenoInfo::IsAdreno6xx() const {
176 return adreno_gpu == AdrenoGpu::kAdreno605 ||
177 adreno_gpu == AdrenoGpu::kAdreno610 ||
178 adreno_gpu == AdrenoGpu::kAdreno612 ||
179 adreno_gpu == AdrenoGpu::kAdreno615 ||
180 adreno_gpu == AdrenoGpu::kAdreno616 ||
181 adreno_gpu == AdrenoGpu::kAdreno618 ||
182 adreno_gpu == AdrenoGpu::kAdreno620 ||
183 adreno_gpu == AdrenoGpu::kAdreno630 ||
184 adreno_gpu == AdrenoGpu::kAdreno640 ||
185 adreno_gpu == AdrenoGpu::kAdreno650 ||
186 adreno_gpu == AdrenoGpu::kAdreno660 ||
187 adreno_gpu == AdrenoGpu::kAdreno675 ||
188 adreno_gpu == AdrenoGpu::kAdreno680 ||
189 adreno_gpu == AdrenoGpu::kAdreno685;
190 }
191
IsAdreno6xxOrHigher() const192 bool AdrenoInfo::IsAdreno6xxOrHigher() const {
193 return !compiler_bugs_in_a6xx && IsAdreno6xx();
194 }
195
GetMaximumWavesCount() const196 int AdrenoInfo::GetMaximumWavesCount() const {
197 if (IsAdreno6xx()) {
198 if (adreno_gpu == AdrenoGpu::kAdreno640) {
199 return 30;
200 } else {
201 return 16;
202 }
203 } else {
204 // all other versions not supported
205 return 1;
206 }
207 }
208
GetRegisterMemorySizePerComputeUnit() const209 int AdrenoInfo::GetRegisterMemorySizePerComputeUnit() const {
210 if (IsAdreno6xx()) {
211 if (adreno_gpu == AdrenoGpu::kAdreno640) {
212 return 128 * 144 * 16;
213 } else if (adreno_gpu == AdrenoGpu::kAdreno620 ||
214 adreno_gpu == AdrenoGpu::kAdreno650 ||
215 adreno_gpu == AdrenoGpu::kAdreno660) {
216 return 128 * 64 * 16;
217 } else {
218 return 128 * 96 * 16;
219 }
220 } else {
221 // all other versions not supported
222 return 1;
223 }
224 }
225
GetMaximumWavesCount(int register_footprint_per_tread,bool full_wave) const226 int AdrenoInfo::GetMaximumWavesCount(int register_footprint_per_tread,
227 bool full_wave) const {
228 const int register_usage_per_wave =
229 GetWaveSize(full_wave) * register_footprint_per_tread;
230 const int possible_waves_count =
231 GetRegisterMemorySizePerComputeUnit() / register_usage_per_wave;
232 return std::min(possible_waves_count, GetMaximumWavesCount());
233 }
234
GetWaveSize(bool full_wave) const235 int AdrenoInfo::GetWaveSize(bool full_wave) const {
236 if (IsAdreno6xx()) {
237 return full_wave ? 128 : 64;
238 } else if (IsAdreno5xx() || IsAdreno4xx()) {
239 return full_wave ? 64 : 32;
240 } else {
241 // all other versions not supported
242 return 1;
243 }
244 }
245
AppleInfo(const std::string & gpu_description)246 AppleInfo::AppleInfo(const std::string& gpu_description) {
247 const std::map<std::string, AppleGpu> kMapping = {
248 {"apple a7 gpu", AppleGpu::kA7}, {"apple a8 gpu", AppleGpu::kA8},
249 {"apple a8x gpu", AppleGpu::kA8X}, {"apple a9 gpu", AppleGpu::kA9},
250 {"apple a9x gpu", AppleGpu::kA9X}, {"apple a10 gpu", AppleGpu::kA10},
251 {"apple a10x gpu", AppleGpu::kA10X}, {"apple a11 gpu", AppleGpu::kA11},
252 {"apple a12 gpu", AppleGpu::kA12}, {"apple a12x gpu", AppleGpu::kA12X},
253 {"apple a12z gpu", AppleGpu::kA12Z}, {"apple a13 gpu", AppleGpu::kA13},
254 {"apple a14 gpu", AppleGpu::kA14},
255 };
256 auto it = kMapping.find(gpu_description);
257 if (it != kMapping.end()) {
258 gpu_type = it->second;
259 } else {
260 gpu_type = AppleGpu::kUnknown;
261 }
262 }
263
IsLocalMemoryPreferredOverGlobal() const264 bool AppleInfo::IsLocalMemoryPreferredOverGlobal() const {
265 return gpu_type == AppleGpu::kA7 || gpu_type == AppleGpu::kA8 ||
266 gpu_type == AppleGpu::kA8X;
267 }
268
IsBionic() const269 bool AppleInfo::IsBionic() const {
270 return gpu_type == AppleGpu::kA11 || gpu_type == AppleGpu::kA12 ||
271 gpu_type == AppleGpu::kA12X || gpu_type == AppleGpu::kA12Z ||
272 gpu_type == AppleGpu::kA13 || gpu_type == AppleGpu::kA14;
273 }
274
IsRoundToNearestSupported() const275 bool AppleInfo::IsRoundToNearestSupported() const { return IsBionic(); }
276
GetComputeUnitsCount() const277 int AppleInfo::GetComputeUnitsCount() const {
278 switch (gpu_type) {
279 case AppleGpu::kA7:
280 return 4;
281 case AppleGpu::kA8:
282 return 4;
283 case AppleGpu::kA8X:
284 return 8;
285 case AppleGpu::kA9:
286 return 6;
287 case AppleGpu::kA9X:
288 return 12;
289 case AppleGpu::kA10:
290 return 6;
291 case AppleGpu::kA10X:
292 return 12;
293 case AppleGpu::kA11:
294 return 3;
295 case AppleGpu::kA12:
296 return 4;
297 case AppleGpu::kA12X:
298 return 7;
299 case AppleGpu::kA12Z:
300 return 8;
301 case AppleGpu::kA13:
302 return 4;
303 case AppleGpu::kA14:
304 return 4;
305 case AppleGpu::kUnknown:
306 return 1;
307 }
308 }
309
MaliInfo(const std::string & gpu_description)310 MaliInfo::MaliInfo(const std::string& gpu_description)
311 : gpu_version(GetMaliGpuVersion(gpu_description)) {}
312
IsMaliT6xx() const313 bool MaliInfo::IsMaliT6xx() const {
314 return gpu_version == MaliGpu::kT604 || gpu_version == MaliGpu::kT622 ||
315 gpu_version == MaliGpu::kT624 || gpu_version == MaliGpu::kT628 ||
316 gpu_version == MaliGpu::kT658 || gpu_version == MaliGpu::kT678;
317 }
318
IsMaliT7xx() const319 bool MaliInfo::IsMaliT7xx() const {
320 return gpu_version == MaliGpu::kT720 || gpu_version == MaliGpu::kT760;
321 }
322
IsMaliT8xx() const323 bool MaliInfo::IsMaliT8xx() const {
324 return gpu_version == MaliGpu::kT820 || gpu_version == MaliGpu::kT830 ||
325 gpu_version == MaliGpu::kT860 || gpu_version == MaliGpu::kT880;
326 }
327
IsMidgard() const328 bool MaliInfo::IsMidgard() const {
329 return IsMaliT6xx() || IsMaliT7xx() || IsMaliT8xx();
330 }
331
IsBifrostGen1() const332 bool MaliInfo::IsBifrostGen1() const {
333 return gpu_version == MaliGpu::kG31 || gpu_version == MaliGpu::kG51 ||
334 gpu_version == MaliGpu::kG71;
335 }
336
IsBifrostGen2() const337 bool MaliInfo::IsBifrostGen2() const {
338 return gpu_version == MaliGpu::kG52 || gpu_version == MaliGpu::kG72;
339 }
340
IsBifrostGen3() const341 bool MaliInfo::IsBifrostGen3() const { return gpu_version == MaliGpu::kG76; }
342
IsBifrost() const343 bool MaliInfo::IsBifrost() const {
344 return IsBifrostGen1() || IsBifrostGen2() || IsBifrostGen3();
345 }
346
IsValhall() const347 bool MaliInfo::IsValhall() const {
348 return gpu_version == MaliGpu::kG57 || gpu_version == MaliGpu::kG77 ||
349 gpu_version == MaliGpu::kG68 || gpu_version == MaliGpu::kG78;
350 }
351
GetGpuInfoFromDeviceDescription(const std::string & gpu_description,GpuApi gpu_api,GpuInfo * gpu_info)352 void GetGpuInfoFromDeviceDescription(const std::string& gpu_description,
353 GpuApi gpu_api, GpuInfo* gpu_info) {
354 gpu_info->gpu_api = gpu_api;
355 std::string lowered = gpu_description;
356 absl::AsciiStrToLower(&lowered);
357 gpu_info->vendor = GetGpuVendor(lowered);
358 if (gpu_info->IsAdreno()) {
359 gpu_info->adreno_info = AdrenoInfo(lowered);
360 } else if (gpu_info->IsApple()) {
361 gpu_info->apple_info = AppleInfo(lowered);
362 gpu_info->supported_subgroup_sizes = {32};
363 } else if (gpu_info->IsMali()) {
364 gpu_info->mali_info = MaliInfo(lowered);
365 }
366 }
367
OpenClVersionToString(OpenClVersion version)368 std::string OpenClVersionToString(OpenClVersion version) {
369 switch (version) {
370 case OpenClVersion::kCl1_0:
371 return "1.0";
372 case OpenClVersion::kCl1_1:
373 return "1.1";
374 case OpenClVersion::kCl1_2:
375 return "1.2";
376 case OpenClVersion::kCl2_0:
377 return "2.0";
378 case OpenClVersion::kCl2_1:
379 return "2.1";
380 case OpenClVersion::kCl2_2:
381 return "2.2";
382 case OpenClVersion::kCl3_0:
383 return "3.0";
384 default:
385 return "Unknown OpenCL version";
386 }
387 }
388
IsImage2dFromBufferSupported() const389 bool OpenClInfo::IsImage2dFromBufferSupported() const {
390 if (image_pitch_alignment == 0) {
391 return false;
392 }
393 if (cl_version == OpenClVersion::kCl2_0 ||
394 cl_version == OpenClVersion::kCl2_1 ||
395 cl_version == OpenClVersion::kCl2_2) {
396 return true;
397 }
398 for (const auto& ext : extensions) {
399 if (ext == "cl_khr_image2d_from_buffer") {
400 return true;
401 }
402 }
403 return false;
404 }
405
IsAdreno() const406 bool GpuInfo::IsAdreno() const { return vendor == GpuVendor::kQualcomm; }
407
IsApple() const408 bool GpuInfo::IsApple() const { return vendor == GpuVendor::kApple; }
409
IsMali() const410 bool GpuInfo::IsMali() const { return vendor == GpuVendor::kMali; }
411
IsPowerVR() const412 bool GpuInfo::IsPowerVR() const { return vendor == GpuVendor::kPowerVR; }
413
IsNvidia() const414 bool GpuInfo::IsNvidia() const { return vendor == GpuVendor::kNvidia; }
415
IsAMD() const416 bool GpuInfo::IsAMD() const { return vendor == GpuVendor::kAMD; }
417
IsIntel() const418 bool GpuInfo::IsIntel() const { return vendor == GpuVendor::kIntel; }
419
IsRoundToNearestSupported() const420 bool GpuInfo::IsRoundToNearestSupported() const {
421 if (IsApiOpenCl()) {
422 return opencl_info.supports_fp16_rtn || opencl_info.supports_fp32_rtn;
423 }
424 if (IsApple()) {
425 return apple_info.IsRoundToNearestSupported();
426 }
427 if (IsAdreno()) {
428 if (adreno_info.IsAdreno1xx() || adreno_info.IsAdreno2xx() ||
429 adreno_info.IsAdreno3xx()) {
430 return false;
431 }
432 }
433 if (IsPowerVR()) {
434 return false;
435 }
436 return true;
437 }
438
SupportsFP16() const439 bool GpuInfo::SupportsFP16() const {
440 if (IsApiOpenCl()) {
441 return opencl_info.supports_fp16;
442 }
443 return true;
444 }
445
SupportsTextureArray() const446 bool GpuInfo::SupportsTextureArray() const {
447 if (!SupportsImages()) {
448 return false;
449 }
450 if (IsApiOpenCl()) {
451 return opencl_info.cl_version >= OpenClVersion::kCl1_2;
452 }
453 return true;
454 }
455
SupportsImageBuffer() const456 bool GpuInfo::SupportsImageBuffer() const {
457 if (!SupportsImages()) {
458 return false;
459 }
460 if (IsApiOpenCl()) {
461 return opencl_info.cl_version >= OpenClVersion::kCl1_2;
462 }
463 return true;
464 }
465
SupportsImage3D() const466 bool GpuInfo::SupportsImage3D() const {
467 if (!SupportsImages()) {
468 return false;
469 }
470 if (IsApiOpenCl()) {
471 if (IsMali() && mali_info.IsMidgard()) {
472 // On Mali T880 read_imageh doesn't compile with image3d_t
473 return false;
474 }
475 return opencl_info.supports_image3d_writes;
476 }
477 return true;
478 }
479
SupportsImages() const480 bool GpuInfo::SupportsImages() const {
481 if (IsApiOpenCl()) {
482 return opencl_info.supports_images;
483 }
484 return true;
485 }
486
SupportsPointersInKernels() const487 bool GpuInfo::SupportsPointersInKernels() const {
488 return IsApiOpenCl() || IsApiMetal();
489 }
490
IsWaveSizeEqualTo32() const491 bool GpuInfo::IsWaveSizeEqualTo32() const {
492 return supported_subgroup_sizes.size() == 1 &&
493 supported_subgroup_sizes[0] == 32;
494 }
495
SupportsExtension(const std::string & extension) const496 bool GpuInfo::SupportsExtension(const std::string& extension) const {
497 const std::vector<std::string>* extensions = nullptr;
498 if (IsApiOpenGl()) {
499 extensions = &opengl_info.extensions;
500 } else if (IsApiVulkan()) {
501 extensions = &vulkan_info.extensions;
502 } else if (IsApiOpenCl()) {
503 extensions = &opencl_info.extensions;
504 }
505 if (!extensions) {
506 return false;
507 }
508 for (const auto& ext : *extensions) {
509 if (ext == extension) {
510 return true;
511 }
512 }
513 return false;
514 }
515
SupportsSubGroupWithSize(int sub_group_size) const516 bool GpuInfo::SupportsSubGroupWithSize(int sub_group_size) const {
517 for (auto subgroup_size : supported_subgroup_sizes) {
518 if (sub_group_size == subgroup_size) {
519 return true;
520 }
521 }
522 return false;
523 }
524
SupportsFloatImage2D(DataType data_type,int channels) const525 bool GpuInfo::SupportsFloatImage2D(DataType data_type, int channels) const {
526 if (IsApiOpenCl()) {
527 if (channels == 1) {
528 return data_type == DataType::FLOAT32 ? opencl_info.supports_r_f32_tex2d
529 : opencl_info.supports_r_f16_tex2d;
530 } else if (channels == 2) {
531 return data_type == DataType::FLOAT32 ? opencl_info.supports_rg_f32_tex2d
532 : opencl_info.supports_rg_f16_tex2d;
533 } else if (channels == 3) {
534 return data_type == DataType::FLOAT32
535 ? opencl_info.supports_rgb_f32_tex2d
536 : opencl_info.supports_rgb_f16_tex2d;
537 } else if (channels == 4) {
538 return data_type == DataType::FLOAT32
539 ? opencl_info.supports_rgba_f32_tex2d
540 : opencl_info.supports_rgba_f16_tex2d;
541 } else {
542 return false;
543 }
544 }
545 return false;
546 }
547
GetComputeUnitsCount() const548 int GpuInfo::GetComputeUnitsCount() const {
549 if (IsApiOpenCl()) {
550 return opencl_info.compute_units_count;
551 }
552 if (IsApple()) {
553 return apple_info.GetComputeUnitsCount();
554 }
555 if (IsAMD() && IsApiVulkan()) {
556 return amd_info.GetComputeUnitsCount();
557 }
558 return 1;
559 }
560
GetMaxWorkGroupSizeForX() const561 int GpuInfo::GetMaxWorkGroupSizeForX() const {
562 if (IsApiOpenGl()) {
563 return opengl_info.max_compute_work_group_size_x;
564 }
565 if (IsApiVulkan()) {
566 return vulkan_info.max_compute_work_group_size_x;
567 }
568 if (IsApiOpenCl()) {
569 return opencl_info.max_work_group_size_x;
570 }
571 if (IsApiMetal()) {
572 return metal_info.max_work_group_size_x;
573 }
574 return 256;
575 }
576
GetMaxWorkGroupSizeForY() const577 int GpuInfo::GetMaxWorkGroupSizeForY() const {
578 if (IsApiOpenGl()) {
579 return opengl_info.max_compute_work_group_size_y;
580 }
581 if (IsApiVulkan()) {
582 return vulkan_info.max_compute_work_group_size_y;
583 }
584 if (IsApiOpenCl()) {
585 return opencl_info.max_work_group_size_y;
586 }
587 if (IsApiMetal()) {
588 return metal_info.max_work_group_size_y;
589 }
590 return 256;
591 }
592
GetMaxWorkGroupSizeForZ() const593 int GpuInfo::GetMaxWorkGroupSizeForZ() const {
594 if (IsApiOpenGl()) {
595 return opengl_info.max_compute_work_group_size_z;
596 }
597 if (IsApiVulkan()) {
598 return vulkan_info.max_compute_work_group_size_z;
599 }
600 if (IsApiOpenCl()) {
601 return opencl_info.max_work_group_size_z;
602 }
603 if (IsApiMetal()) {
604 return metal_info.max_work_group_size_z;
605 }
606 return 64;
607 }
608
GetMaxWorkGroupTotalSize() const609 int GpuInfo::GetMaxWorkGroupTotalSize() const {
610 if (IsApiOpenGl()) {
611 return opengl_info.max_work_group_invocations;
612 }
613 if (IsApiVulkan()) {
614 return vulkan_info.max_compute_work_group_invocations;
615 }
616 if (IsApiOpenCl()) {
617 return opencl_info.max_work_group_total_size;
618 }
619 if (IsApiMetal()) {
620 int max_size = metal_info.max_work_group_size_x;
621 max_size = std::max(max_size, metal_info.max_work_group_size_y);
622 max_size = std::max(max_size, metal_info.max_work_group_size_z);
623 return max_size;
624 }
625 return 256;
626 }
627
GetMaxImage2DWidth() const628 uint64_t GpuInfo::GetMaxImage2DWidth() const {
629 if (IsApiOpenGl()) {
630 return opengl_info.max_texture_size;
631 }
632 if (IsApiVulkan()) {
633 return vulkan_info.max_image_dimension_2d;
634 }
635 if (IsApiOpenCl()) {
636 return opencl_info.image2d_max_width;
637 }
638 return 2048;
639 }
640
GetMaxImage2DHeight() const641 uint64_t GpuInfo::GetMaxImage2DHeight() const {
642 if (IsApiOpenGl()) {
643 return opengl_info.max_texture_size;
644 }
645 if (IsApiVulkan()) {
646 return vulkan_info.max_image_dimension_2d;
647 }
648 if (IsApiOpenCl()) {
649 return opencl_info.image2d_max_height;
650 }
651 return 2048;
652 }
653
GetMaxImage2DArrayLayers() const654 uint64_t GpuInfo::GetMaxImage2DArrayLayers() const {
655 if (IsApiOpenGl()) {
656 return opengl_info.max_array_texture_layers;
657 }
658 if (IsApiVulkan()) {
659 return vulkan_info.max_image_array_layers;
660 }
661 if (IsApiOpenCl()) {
662 return opencl_info.image_array_max_layers;
663 }
664 return 256;
665 }
666
GetMaxImage3DWidth() const667 uint64_t GpuInfo::GetMaxImage3DWidth() const {
668 if (IsApiOpenCl()) {
669 return opencl_info.image3d_max_width;
670 }
671 return 256;
672 }
673
GetMaxImage3DHeight() const674 uint64_t GpuInfo::GetMaxImage3DHeight() const {
675 if (IsApiOpenCl()) {
676 return opencl_info.image3d_max_height;
677 }
678 return 256;
679 }
680
GetMaxImage3DDepth() const681 uint64_t GpuInfo::GetMaxImage3DDepth() const {
682 if (IsApiOpenCl()) {
683 return opencl_info.image3d_max_depth;
684 }
685 return 256;
686 }
687
GetMaxBufferSize() const688 uint64_t GpuInfo::GetMaxBufferSize() const {
689 if (IsApiOpenCl()) {
690 return opencl_info.buffer_max_size;
691 } else if (IsApiMetal()) {
692 return metal_info.buffer_max_size;
693 }
694 return 128 * 1024 * 1024;
695 }
696
GetMaxMemoryAllocationSize() const697 uint64_t GpuInfo::GetMaxMemoryAllocationSize() const {
698 if (IsApiOpenCl()) {
699 return opencl_info.max_allocation_size;
700 } else if (IsApiMetal()) {
701 return metal_info.buffer_max_size;
702 }
703 return 128 * 1024 * 1024;
704 }
705
GetMaxImageBufferWidth() const706 uint64_t GpuInfo::GetMaxImageBufferWidth() const {
707 if (IsApiOpenCl()) {
708 return opencl_info.image_buffer_max_size;
709 }
710 return 64 * 1024;
711 }
712
GetMaxImageArguments() const713 int GpuInfo::GetMaxImageArguments() const {
714 if (IsApiOpenGl()) {
715 return opengl_info.max_image_units;
716 }
717 if (IsApiVulkan()) {
718 return vulkan_info.max_per_stage_descriptor_sampled_images;
719 }
720 if (IsApiMetal()) {
721 return 32;
722 }
723 if (IsApiOpenCl()) {
724 return 128;
725 }
726 return 1;
727 }
728
IsApiOpenGl() const729 bool GpuInfo::IsApiOpenGl() const { return gpu_api == GpuApi::kOpenGl; }
730
IsApiOpenGl31OrAbove() const731 bool GpuInfo::IsApiOpenGl31OrAbove() const {
732 if (!IsApiOpenGl()) {
733 return false;
734 }
735 return (opengl_info.major_version == 3 && opengl_info.minor_version >= 1) ||
736 opengl_info.major_version > 3;
737 }
738
IsApiVulkan() const739 bool GpuInfo::IsApiVulkan() const { return gpu_api == GpuApi::kVulkan; }
740
IsApiMetal() const741 bool GpuInfo::IsApiMetal() const { return gpu_api == GpuApi::kMetal; }
742
IsApiOpenCl() const743 bool GpuInfo::IsApiOpenCl() const { return gpu_api == GpuApi::kOpenCl; }
744
IsGlsl() const745 bool GpuInfo::IsGlsl() const { return IsApiOpenGl() || IsApiVulkan(); }
746
IsCL11OrHigher() const747 bool GpuInfo::IsCL11OrHigher() const {
748 if (!IsApiOpenCl()) {
749 return false;
750 }
751 return opencl_info.cl_version != OpenClVersion::kCl1_0;
752 }
753
IsCL20OrHigher() const754 bool GpuInfo::IsCL20OrHigher() const {
755 if (!IsApiOpenCl()) {
756 return false;
757 }
758 return opencl_info.cl_version != OpenClVersion::kCl1_0 &&
759 opencl_info.cl_version != OpenClVersion::kCl1_1 &&
760 opencl_info.cl_version != OpenClVersion::kCl1_2;
761 }
762
IsCL30OrHigher() const763 bool GpuInfo::IsCL30OrHigher() const {
764 if (!IsApiOpenCl()) {
765 return false;
766 }
767 return IsCL20OrHigher() && opencl_info.cl_version != OpenClVersion::kCl2_0 &&
768 opencl_info.cl_version != OpenClVersion::kCl2_1 &&
769 opencl_info.cl_version != OpenClVersion::kCl2_2;
770 }
771
772 } // namespace gpu
773 } // namespace tflite
774