1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/task/util.h"
17
18 #include <cfloat>
19 #include <string>
20
21 #include "absl/strings/substitute.h"
22 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
23 #include "tensorflow/lite/delegates/gpu/common/util.h"
24
25 namespace tflite {
26 namespace gpu {
27 namespace {
GetGlslConversion(const GpuInfo & gpu_info,DataType src_type,DataType dst_type,int vec_size)28 std::string GetGlslConversion(const GpuInfo& gpu_info, DataType src_type,
29 DataType dst_type, int vec_size) {
30 if (src_type == dst_type) {
31 return "";
32 }
33 bool need_explicit_conversion = true;
34 switch (dst_type) {
35 case DataType::FLOAT32:
36 case DataType::FLOAT16:
37 if (gpu_info.IsGlslSupportsExplicitFp16()) {
38 if (src_type == dst_type) {
39 need_explicit_conversion = false;
40 }
41 } else {
42 if (src_type == DataType::FLOAT32 || src_type == DataType::FLOAT16) {
43 need_explicit_conversion = false;
44 }
45 }
46 break;
47 case DataType::INT32:
48 case DataType::INT16:
49 case DataType::INT8:
50 if (src_type == DataType::INT32 || src_type == DataType::INT16 ||
51 src_type == DataType::INT8) {
52 need_explicit_conversion = false;
53 }
54 break;
55 case DataType::UINT32:
56 case DataType::UINT16:
57 case DataType::UINT8:
58 if (src_type == DataType::UINT32 || src_type == DataType::UINT16 ||
59 src_type == DataType::UINT8) {
60 need_explicit_conversion = false;
61 }
62 break;
63 case DataType::BOOL:
64 need_explicit_conversion = true;
65 break;
66 default:
67 break;
68 }
69 if (need_explicit_conversion) {
70 return ToGlslShaderDataType(
71 dst_type, vec_size,
72 /*add_precision*/ false,
73 /*explicit_fp16*/ gpu_info.IsGlslSupportsExplicitFp16());
74 } else {
75 return "";
76 }
77 }
78 } // namespace
79
MemoryTypeToCLType(MemoryType type)80 std::string MemoryTypeToCLType(MemoryType type) {
81 switch (type) {
82 case MemoryType::GLOBAL:
83 return "__global";
84 case MemoryType::CONSTANT:
85 return "__constant";
86 case MemoryType::LOCAL:
87 return "__local";
88 }
89 return "";
90 }
91
MemoryTypeToMetalType(MemoryType type)92 std::string MemoryTypeToMetalType(MemoryType type) {
93 switch (type) {
94 case MemoryType::GLOBAL:
95 return "device";
96 case MemoryType::CONSTANT:
97 return "constant";
98 break;
99 case MemoryType::LOCAL:
100 return "threadgroup";
101 }
102 return "";
103 }
104
GetXStrideCorrected(const std::string & src_x,const std::string & batch_size,const std::string & stride_x,const std::string & padding_x)105 std::string GetXStrideCorrected(const std::string& src_x,
106 const std::string& batch_size,
107 const std::string& stride_x,
108 const std::string& padding_x) {
109 // int p0 = src_x / batch_size;\n";
110 // int b0 = src_x % batch_size;\n";
111 // return p0 * stride_x * batch_size + b0 + padding_x;\n";
112 return absl::Substitute("((($0) / $1) * $2 * $1 + (($0) % $1) + $3)", src_x,
113 batch_size, stride_x, padding_x);
114 }
115
GetXStrideCorrectedV2(const std::string & src_x,const std::string & batch_size,const std::string & stride_x,const std::string & padding_x)116 std::string GetXStrideCorrectedV2(const std::string& src_x,
117 const std::string& batch_size,
118 const std::string& stride_x,
119 const std::string& padding_x) {
120 // int p0 = src_x / batch_size;\n";
121 // int b0 = src_x % batch_size;\n";
122 // return (p0 * stride_x + padding_x) * batch_size + b0;\n";
123 return absl::Substitute("(((($0) / $1) * $2 + $3) * $1 + ($0) % $1)", src_x,
124 batch_size, stride_x, padding_x);
125 }
126
GetMaskForLastPlane(int channels)127 float4 GetMaskForLastPlane(int channels) {
128 float4 mask = float4(0.0f);
129 const int reminder = channels % 4 == 0 ? 4 : channels % 4;
130 for (int i = 0; i < reminder; ++i) {
131 mask[i] = 1.0f;
132 }
133 return mask;
134 }
135
GetRecommendedBlockSizeForConv(const GpuInfo & gpu_info,CalculationsPrecision precision,int task_size)136 int GetRecommendedBlockSizeForConv(const GpuInfo& gpu_info,
137 CalculationsPrecision precision,
138 int task_size) {
139 const float task_size_per_cu =
140 task_size / static_cast<float>(gpu_info.GetComputeUnitsCount());
141 int block_size = 1;
142 float threshold_1 = FLT_MAX;
143 float threshold_2 = FLT_MAX;
144 float threshold_4 = FLT_MAX;
145 if (!gpu_info.IsMali()) {
146 return 1;
147 }
148 MaliInfo mali_info = gpu_info.mali_info;
149 switch (precision) {
150 case CalculationsPrecision::F16:
151 if (mali_info.IsBifrostGen1()) {
152 threshold_1 = 256.0f;
153 threshold_2 = 256.0f * 4.0f;
154 threshold_4 = 256.0f * 8.0f;
155 } else if (mali_info.IsBifrostGen2()) {
156 threshold_1 = 256.0f * 2.0f;
157 threshold_2 = 256.0f * 8.0f;
158 threshold_4 = 256.0f * 16.0f;
159 } else if (mali_info.IsBifrostGen3() || mali_info.IsValhall()) {
160 threshold_1 = 256.0f;
161 threshold_2 = 256.0f * 6.0f;
162 threshold_4 = 256.0f * 16.0f;
163 } else if (mali_info.IsMidgard()) {
164 threshold_1 = 256.0f * 4.0f;
165 threshold_2 = 256.0f * 16.0f;
166 }
167 break;
168 case CalculationsPrecision::F32_F16:
169 if (mali_info.IsBifrostGen1()) {
170 threshold_1 = 256.0f;
171 threshold_2 = 256.0f * 3.0f;
172 threshold_4 = 256.0f * 32.0f;
173 } else if (mali_info.IsBifrostGen2()) {
174 threshold_1 = 256.0f * 2.0f;
175 threshold_2 = 256.0f * 8.0f;
176 } else if (mali_info.IsBifrostGen3() || mali_info.IsValhall()) {
177 threshold_1 = 256.0f;
178 threshold_2 = 256.0f * 8.0f;
179 } else if (mali_info.IsMidgard()) {
180 threshold_1 = 256.0f * 4.0f;
181 }
182 break;
183 case CalculationsPrecision::F32:
184 if (mali_info.IsBifrostGen1()) {
185 threshold_1 = 256.0f;
186 threshold_2 = 256.0f * 4.0f;
187 } else if (mali_info.IsBifrostGen2()) {
188 threshold_1 = 128.0f;
189 threshold_2 = 256.0f * 4.0f;
190 } else if (mali_info.IsBifrostGen3() || mali_info.IsValhall()) {
191 threshold_1 = 256.0f;
192 threshold_2 = 256.0f * 12.0f;
193 } else if (mali_info.IsMidgard()) {
194 threshold_1 = 256.0f * 16.0f;
195 }
196 break;
197 }
198 if (task_size_per_cu <= threshold_1) {
199 block_size = 1;
200 } else if (task_size_per_cu <= threshold_2) {
201 block_size = 2;
202 } else if (task_size_per_cu <= threshold_4) {
203 block_size = 4;
204 } else {
205 block_size = 8;
206 }
207 return block_size;
208 }
209
// Returns, per axis, DivideRoundUp(grid_size, work_group_size), i.e. the
// number of work groups computed independently for x, y and z.
int3 GetWorkGroupsCount(const int3& grid_size, const int3& work_group_size) {
  const int groups_x = DivideRoundUp(grid_size.x, work_group_size.x);
  const int groups_y = DivideRoundUp(grid_size.y, work_group_size.y);
  const int groups_z = DivideRoundUp(grid_size.z, work_group_size.z);

  int3 result;
  result.x = groups_x;
  result.y = groups_y;
  result.z = groups_z;
  return result;
}
217
GetTypeDeclaration(const GpuInfo & gpu_info,DataType data_type,int vec_size)218 std::string GetTypeDeclaration(const GpuInfo& gpu_info, DataType data_type,
219 int vec_size) {
220 if (gpu_info.IsApiOpenCl()) {
221 return ToCLDataType(data_type, vec_size);
222 } else if (gpu_info.IsApiMetal()) {
223 return ToMetalDataType(data_type, vec_size);
224 } else if (gpu_info.IsGlsl()) {
225 return ToGlslShaderDataType(data_type, vec_size, true,
226 gpu_info.IsGlslSupportsExplicitFp16());
227 } else {
228 return "";
229 }
230 }
231
GetZeroValue(const GpuInfo & gpu_info,DataType data_type,int vec_size)232 std::string GetZeroValue(const GpuInfo& gpu_info, DataType data_type,
233 int vec_size) {
234 if (gpu_info.IsApiOpenCl()) {
235 return "(" + ToCLDataType(data_type, vec_size) + ")(0)";
236 } else if (gpu_info.IsApiMetal()) {
237 return ToMetalDataType(data_type, vec_size) + "(0)";
238 } else if (gpu_info.IsGlsl()) {
239 return ToGlslShaderDataType(data_type, vec_size, false,
240 gpu_info.IsGlslSupportsExplicitFp16()) +
241 "(0)";
242 } else {
243 return "";
244 }
245 }
246
GetOneValue(const GpuInfo & gpu_info,DataType data_type,int vec_size)247 std::string GetOneValue(const GpuInfo& gpu_info, DataType data_type,
248 int vec_size) {
249 if (gpu_info.IsApiOpenCl()) {
250 return "(" + ToCLDataType(data_type, vec_size) + ")(1)";
251 } else if (gpu_info.IsApiMetal()) {
252 return ToMetalDataType(data_type, vec_size) + "(1)";
253 } else if (gpu_info.IsGlsl()) {
254 return ToGlslShaderDataType(data_type, vec_size, false,
255 gpu_info.IsGlslSupportsExplicitFp16()) +
256 "(1)";
257 } else {
258 return "";
259 }
260 }
261
GetTypeConversion(const GpuInfo & gpu_info,DataType src_type,DataType dst_type,int vec_size)262 std::string GetTypeConversion(const GpuInfo& gpu_info, DataType src_type,
263 DataType dst_type, int vec_size) {
264 if (src_type != dst_type) {
265 if (gpu_info.IsApiOpenCl()) {
266 if (dst_type == DataType::BOOL && vec_size != 1) {
267 // In OpenCL for bool4 we are using uchar4
268 // From OpenCL specification for "Relational and Equality Operators":
269 // "These functions shall return a 0 if the specified relation is
270 // false and a -1 (i.e. all bits set) if the specified relation is
271 // true for vector argument types."
272 // (convert_uchar4((value) != 0) & (uchar4)(1))
273 return "(convert_" + ToCLDataType(DataType::UINT8, vec_size) +
274 "(($0) != " + GetZeroValue(gpu_info, src_type, vec_size) +
275 ") & " + GetOneValue(gpu_info, DataType::UINT8, vec_size) + ")";
276 } else {
277 return "convert_" + ToCLDataType(dst_type, vec_size) + "($0)";
278 }
279 } else if (gpu_info.IsApiMetal()) {
280 return ToMetalDataType(dst_type, vec_size) + "($0)";
281 } else if (gpu_info.IsGlsl()) {
282 const std::string conversion =
283 GetGlslConversion(gpu_info, src_type, dst_type, vec_size);
284 if (!conversion.empty()) {
285 return conversion + "($0)";
286 } else {
287 return "$0";
288 }
289 }
290 }
291 return "$0";
292 }
293
294 } // namespace gpu
295 } // namespace tflite
296