• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/task/util.h"
17 
18 #include <cfloat>
19 #include <string>
20 
21 #include "absl/strings/substitute.h"
22 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
23 #include "tensorflow/lite/delegates/gpu/common/util.h"
24 
25 namespace tflite {
26 namespace gpu {
27 namespace {
GetGlslConversion(const GpuInfo & gpu_info,DataType src_type,DataType dst_type,int vec_size)28 std::string GetGlslConversion(const GpuInfo& gpu_info, DataType src_type,
29                               DataType dst_type, int vec_size) {
30   if (src_type == dst_type) {
31     return "";
32   }
33   bool need_explicit_conversion = true;
34   switch (dst_type) {
35     case DataType::FLOAT32:
36     case DataType::FLOAT16:
37       if (gpu_info.IsGlslSupportsExplicitFp16()) {
38         if (src_type == dst_type) {
39           need_explicit_conversion = false;
40         }
41       } else {
42         if (src_type == DataType::FLOAT32 || src_type == DataType::FLOAT16) {
43           need_explicit_conversion = false;
44         }
45       }
46       break;
47     case DataType::INT32:
48     case DataType::INT16:
49     case DataType::INT8:
50       if (src_type == DataType::INT32 || src_type == DataType::INT16 ||
51           src_type == DataType::INT8) {
52         need_explicit_conversion = false;
53       }
54       break;
55     case DataType::UINT32:
56     case DataType::UINT16:
57     case DataType::UINT8:
58       if (src_type == DataType::UINT32 || src_type == DataType::UINT16 ||
59           src_type == DataType::UINT8) {
60         need_explicit_conversion = false;
61       }
62       break;
63     case DataType::BOOL:
64       need_explicit_conversion = true;
65       break;
66     default:
67       break;
68   }
69   if (need_explicit_conversion) {
70     return ToGlslShaderDataType(
71         dst_type, vec_size,
72         /*add_precision*/ false,
73         /*explicit_fp16*/ gpu_info.IsGlslSupportsExplicitFp16());
74   } else {
75     return "";
76   }
77 }
78 }  // namespace
79 
MemoryTypeToCLType(MemoryType type)80 std::string MemoryTypeToCLType(MemoryType type) {
81   switch (type) {
82     case MemoryType::GLOBAL:
83       return "__global";
84     case MemoryType::CONSTANT:
85       return "__constant";
86     case MemoryType::LOCAL:
87       return "__local";
88   }
89   return "";
90 }
91 
MemoryTypeToMetalType(MemoryType type)92 std::string MemoryTypeToMetalType(MemoryType type) {
93   switch (type) {
94     case MemoryType::GLOBAL:
95       return "device";
96     case MemoryType::CONSTANT:
97       return "constant";
98       break;
99     case MemoryType::LOCAL:
100       return "threadgroup";
101   }
102   return "";
103 }
104 
// Builds the source-code expression
//
//   int p0 = src_x / batch_size;
//   int b0 = src_x % batch_size;
//   return p0 * stride_x * batch_size + b0 + padding_x;
//
// as a single parenthesized string. Every argument is a code fragment that is
// pasted into the result; each one is wrapped in its own parentheses so that
// compound fragments (e.g. "a + b" passed as `batch_size`) keep their intended
// precedence next to `/`, `%` and `*`.
std::string GetXStrideCorrected(const std::string& src_x,
                                const std::string& batch_size,
                                const std::string& stride_x,
                                const std::string& padding_x) {
  const std::string x = "(" + src_x + ")";
  const std::string batch = "(" + batch_size + ")";
  const std::string stride = "(" + stride_x + ")";
  const std::string padding = "(" + padding_x + ")";
  return "((" + x + " / " + batch + ") * " + stride + " * " + batch + " + (" +
         x + " % " + batch + ") + " + padding + ")";
}
115 
// Builds the source-code expression
//
//   int p0 = src_x / batch_size;
//   int b0 = src_x % batch_size;
//   return (p0 * stride_x + padding_x) * batch_size + b0;
//
// as a single parenthesized string. Every argument is a code fragment that is
// pasted into the result; each one is wrapped in its own parentheses so that
// compound fragments (e.g. "a + b" passed as `batch_size` or `stride_x`) keep
// their intended precedence next to `/`, `%` and `*`.
std::string GetXStrideCorrectedV2(const std::string& src_x,
                                  const std::string& batch_size,
                                  const std::string& stride_x,
                                  const std::string& padding_x) {
  const std::string x = "(" + src_x + ")";
  const std::string batch = "(" + batch_size + ")";
  const std::string stride = "(" + stride_x + ")";
  const std::string padding = "(" + padding_x + ")";
  return "(((" + x + " / " + batch + ") * " + stride + " + " + padding +
         ") * " + batch + " + " + x + " % " + batch + ")";
}
126 
GetMaskForLastPlane(int channels)127 float4 GetMaskForLastPlane(int channels) {
128   float4 mask = float4(0.0f);
129   const int reminder = channels % 4 == 0 ? 4 : channels % 4;
130   for (int i = 0; i < reminder; ++i) {
131     mask[i] = 1.0f;
132   }
133   return mask;
134 }
135 
GetRecommendedBlockSizeForConv(const GpuInfo & gpu_info,CalculationsPrecision precision,int task_size)136 int GetRecommendedBlockSizeForConv(const GpuInfo& gpu_info,
137                                    CalculationsPrecision precision,
138                                    int task_size) {
139   const float task_size_per_cu =
140       task_size / static_cast<float>(gpu_info.GetComputeUnitsCount());
141   int block_size = 1;
142   float threshold_1 = FLT_MAX;
143   float threshold_2 = FLT_MAX;
144   float threshold_4 = FLT_MAX;
145   if (!gpu_info.IsMali()) {
146     return 1;
147   }
148   MaliInfo mali_info = gpu_info.mali_info;
149   switch (precision) {
150     case CalculationsPrecision::F16:
151       if (mali_info.IsBifrostGen1()) {
152         threshold_1 = 256.0f;
153         threshold_2 = 256.0f * 4.0f;
154         threshold_4 = 256.0f * 8.0f;
155       } else if (mali_info.IsBifrostGen2()) {
156         threshold_1 = 256.0f * 2.0f;
157         threshold_2 = 256.0f * 8.0f;
158         threshold_4 = 256.0f * 16.0f;
159       } else if (mali_info.IsBifrostGen3() || mali_info.IsValhall()) {
160         threshold_1 = 256.0f;
161         threshold_2 = 256.0f * 6.0f;
162         threshold_4 = 256.0f * 16.0f;
163       } else if (mali_info.IsMidgard()) {
164         threshold_1 = 256.0f * 4.0f;
165         threshold_2 = 256.0f * 16.0f;
166       }
167       break;
168     case CalculationsPrecision::F32_F16:
169       if (mali_info.IsBifrostGen1()) {
170         threshold_1 = 256.0f;
171         threshold_2 = 256.0f * 3.0f;
172         threshold_4 = 256.0f * 32.0f;
173       } else if (mali_info.IsBifrostGen2()) {
174         threshold_1 = 256.0f * 2.0f;
175         threshold_2 = 256.0f * 8.0f;
176       } else if (mali_info.IsBifrostGen3() || mali_info.IsValhall()) {
177         threshold_1 = 256.0f;
178         threshold_2 = 256.0f * 8.0f;
179       } else if (mali_info.IsMidgard()) {
180         threshold_1 = 256.0f * 4.0f;
181       }
182       break;
183     case CalculationsPrecision::F32:
184       if (mali_info.IsBifrostGen1()) {
185         threshold_1 = 256.0f;
186         threshold_2 = 256.0f * 4.0f;
187       } else if (mali_info.IsBifrostGen2()) {
188         threshold_1 = 128.0f;
189         threshold_2 = 256.0f * 4.0f;
190       } else if (mali_info.IsBifrostGen3() || mali_info.IsValhall()) {
191         threshold_1 = 256.0f;
192         threshold_2 = 256.0f * 12.0f;
193       } else if (mali_info.IsMidgard()) {
194         threshold_1 = 256.0f * 16.0f;
195       }
196       break;
197   }
198   if (task_size_per_cu <= threshold_1) {
199     block_size = 1;
200   } else if (task_size_per_cu <= threshold_2) {
201     block_size = 2;
202   } else if (task_size_per_cu <= threshold_4) {
203     block_size = 4;
204   } else {
205     block_size = 8;
206   }
207   return block_size;
208 }
209 
// Computes, independently for x, y and z, the result of
// DivideRoundUp(grid_size.<axis>, work_group_size.<axis>) and returns the
// three values as an int3.
int3 GetWorkGroupsCount(const int3& grid_size, const int3& work_group_size) {
  int3 groups;
  groups.x = DivideRoundUp(grid_size.x, work_group_size.x);
  groups.y = DivideRoundUp(grid_size.y, work_group_size.y);
  groups.z = DivideRoundUp(grid_size.z, work_group_size.z);
  return groups;
}
217 
GetTypeDeclaration(const GpuInfo & gpu_info,DataType data_type,int vec_size)218 std::string GetTypeDeclaration(const GpuInfo& gpu_info, DataType data_type,
219                                int vec_size) {
220   if (gpu_info.IsApiOpenCl()) {
221     return ToCLDataType(data_type, vec_size);
222   } else if (gpu_info.IsApiMetal()) {
223     return ToMetalDataType(data_type, vec_size);
224   } else if (gpu_info.IsGlsl()) {
225     return ToGlslShaderDataType(data_type, vec_size, true,
226                                 gpu_info.IsGlslSupportsExplicitFp16());
227   } else {
228     return "";
229   }
230 }
231 
GetZeroValue(const GpuInfo & gpu_info,DataType data_type,int vec_size)232 std::string GetZeroValue(const GpuInfo& gpu_info, DataType data_type,
233                          int vec_size) {
234   if (gpu_info.IsApiOpenCl()) {
235     return "(" + ToCLDataType(data_type, vec_size) + ")(0)";
236   } else if (gpu_info.IsApiMetal()) {
237     return ToMetalDataType(data_type, vec_size) + "(0)";
238   } else if (gpu_info.IsGlsl()) {
239     return ToGlslShaderDataType(data_type, vec_size, false,
240                                 gpu_info.IsGlslSupportsExplicitFp16()) +
241            "(0)";
242   } else {
243     return "";
244   }
245 }
246 
GetOneValue(const GpuInfo & gpu_info,DataType data_type,int vec_size)247 std::string GetOneValue(const GpuInfo& gpu_info, DataType data_type,
248                         int vec_size) {
249   if (gpu_info.IsApiOpenCl()) {
250     return "(" + ToCLDataType(data_type, vec_size) + ")(1)";
251   } else if (gpu_info.IsApiMetal()) {
252     return ToMetalDataType(data_type, vec_size) + "(1)";
253   } else if (gpu_info.IsGlsl()) {
254     return ToGlslShaderDataType(data_type, vec_size, false,
255                                 gpu_info.IsGlslSupportsExplicitFp16()) +
256            "(1)";
257   } else {
258     return "";
259   }
260 }
261 
GetTypeConversion(const GpuInfo & gpu_info,DataType src_type,DataType dst_type,int vec_size)262 std::string GetTypeConversion(const GpuInfo& gpu_info, DataType src_type,
263                               DataType dst_type, int vec_size) {
264   if (src_type != dst_type) {
265     if (gpu_info.IsApiOpenCl()) {
266       if (dst_type == DataType::BOOL && vec_size != 1) {
267         // In OpenCL for bool4 we are using uchar4
268         // From OpenCL specification for "Relational and Equality Operators":
269         //   "These functions shall return a 0 if the specified relation is
270         //   false and a -1 (i.e. all bits set) if the specified relation is
271         //   true for vector argument types."
272         // (convert_uchar4((value) != 0) & (uchar4)(1))
273         return "(convert_" + ToCLDataType(DataType::UINT8, vec_size) +
274                "(($0) != " + GetZeroValue(gpu_info, src_type, vec_size) +
275                ") & " + GetOneValue(gpu_info, DataType::UINT8, vec_size) + ")";
276       } else {
277         return "convert_" + ToCLDataType(dst_type, vec_size) + "($0)";
278       }
279     } else if (gpu_info.IsApiMetal()) {
280       return ToMetalDataType(dst_type, vec_size) + "($0)";
281     } else if (gpu_info.IsGlsl()) {
282       const std::string conversion =
283           GetGlslConversion(gpu_info, src_type, dst_type, vec_size);
284       if (!conversion.empty()) {
285         return conversion + "($0)";
286       } else {
287         return "$0";
288       }
289     }
290   }
291   return "$0";
292 }
293 
294 }  // namespace gpu
295 }  // namespace tflite
296