• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/task/util.h"
17 
18 #include <cfloat>
19 
20 #include "absl/strings/substitute.h"
21 #include "tensorflow/lite/delegates/gpu/common/util.h"
22 
23 namespace tflite {
24 namespace gpu {
25 
MemoryTypeToCLType(MemoryType type)26 std::string MemoryTypeToCLType(MemoryType type) {
27   switch (type) {
28     case MemoryType::GLOBAL:
29       return "__global";
30     case MemoryType::CONSTANT:
31       return "__constant";
32     case MemoryType::LOCAL:
33       return "__local";
34   }
35   return "";
36 }
37 
MemoryTypeToMetalType(MemoryType type)38 std::string MemoryTypeToMetalType(MemoryType type) {
39   switch (type) {
40     case MemoryType::GLOBAL:
41       return "device";
42     case MemoryType::CONSTANT:
43       return "constant";
44       break;
45     case MemoryType::LOCAL:
46       return "threadgroup";
47   }
48   return "";
49 }
50 
GetXStrideCorrected(const std::string & src_x,const std::string & batch_size,const std::string & stride_x,const std::string & padding_x)51 std::string GetXStrideCorrected(const std::string& src_x,
52                                 const std::string& batch_size,
53                                 const std::string& stride_x,
54                                 const std::string& padding_x) {
55   // int p0 = src_x / batch_size;\n";
56   // int b0 = src_x % batch_size;\n";
57   // return p0 * stride_x * batch_size + b0 + padding_x;\n";
58   return absl::Substitute("((($0) / $1) * $2 * $1 + (($0) % $1) + $3)", src_x,
59                           batch_size, stride_x, padding_x);
60 }
61 
GetXStrideCorrectedV2(const std::string & src_x,const std::string & batch_size,const std::string & stride_x,const std::string & padding_x)62 std::string GetXStrideCorrectedV2(const std::string& src_x,
63                                   const std::string& batch_size,
64                                   const std::string& stride_x,
65                                   const std::string& padding_x) {
66   // int p0 = src_x / batch_size;\n";
67   // int b0 = src_x % batch_size;\n";
68   // return (p0 * stride_x + padding_x) * batch_size + b0;\n";
69   return absl::Substitute("(((($0) / $1) * $2 + $3) * $1 + ($0) % $1)", src_x,
70                           batch_size, stride_x, padding_x);
71 }
72 
GetMaskForLastPlane(int channels)73 float4 GetMaskForLastPlane(int channels) {
74   float4 mask = float4(0.0f);
75   const int reminder = channels % 4 == 0 ? 4 : channels % 4;
76   for (int i = 0; i < reminder; ++i) {
77     mask[i] = 1.0f;
78   }
79   return mask;
80 }
81 
GetRecommendedBlockSizeForConv(const GpuInfo & gpu_info,CalculationsPrecision precision,int task_size)82 int GetRecommendedBlockSizeForConv(const GpuInfo& gpu_info,
83                                    CalculationsPrecision precision,
84                                    int task_size) {
85   const float task_size_per_cu =
86       task_size / static_cast<float>(gpu_info.GetComputeUnitsCount());
87   int block_size = 1;
88   float threshold_1 = FLT_MAX;
89   float threshold_2 = FLT_MAX;
90   float threshold_4 = FLT_MAX;
91   if (!gpu_info.IsMali()) {
92     return 1;
93   }
94   MaliInfo mali_info = gpu_info.mali_info;
95   switch (precision) {
96     case CalculationsPrecision::F16:
97       if (mali_info.IsBifrostGen1()) {
98         threshold_1 = 256.0f;
99         threshold_2 = 256.0f * 4.0f;
100         threshold_4 = 256.0f * 8.0f;
101       } else if (mali_info.IsBifrostGen2()) {
102         threshold_1 = 256.0f * 2.0f;
103         threshold_2 = 256.0f * 8.0f;
104         threshold_4 = 256.0f * 16.0f;
105       } else if (mali_info.IsBifrostGen3() || mali_info.IsValhall()) {
106         threshold_1 = 256.0f;
107         threshold_2 = 256.0f * 6.0f;
108         threshold_4 = 256.0f * 16.0f;
109       } else if (mali_info.IsMidgard()) {
110         threshold_1 = 256.0f * 4.0f;
111         threshold_2 = 256.0f * 16.0f;
112       }
113       break;
114     case CalculationsPrecision::F32_F16:
115       if (mali_info.IsBifrostGen1()) {
116         threshold_1 = 256.0f;
117         threshold_2 = 256.0f * 3.0f;
118         threshold_4 = 256.0f * 32.0f;
119       } else if (mali_info.IsBifrostGen2()) {
120         threshold_1 = 256.0f * 2.0f;
121         threshold_2 = 256.0f * 8.0f;
122       } else if (mali_info.IsBifrostGen3() || mali_info.IsValhall()) {
123         threshold_1 = 256.0f;
124         threshold_2 = 256.0f * 8.0f;
125       } else if (mali_info.IsMidgard()) {
126         threshold_1 = 256.0f * 4.0f;
127       }
128       break;
129     case CalculationsPrecision::F32:
130       if (mali_info.IsBifrostGen1()) {
131         threshold_1 = 256.0f;
132         threshold_2 = 256.0f * 4.0f;
133       } else if (mali_info.IsBifrostGen2()) {
134         threshold_1 = 128.0f;
135         threshold_2 = 256.0f * 4.0f;
136       } else if (mali_info.IsBifrostGen3() || mali_info.IsValhall()) {
137         threshold_1 = 256.0f;
138         threshold_2 = 256.0f * 12.0f;
139       } else if (mali_info.IsMidgard()) {
140         threshold_1 = 256.0f * 16.0f;
141       }
142       break;
143   }
144   if (task_size_per_cu <= threshold_1) {
145     block_size = 1;
146   } else if (task_size_per_cu <= threshold_2) {
147     block_size = 2;
148   } else if (task_size_per_cu <= threshold_4) {
149     block_size = 4;
150   } else {
151     block_size = 8;
152   }
153   return block_size;
154 }
155 
// Computes, independently for x, y and z, DivideRoundUp of the grid extent by
// the work group extent and returns the three results as an int3.
int3 GetWorkGroupsCount(const int3& grid_size, const int3& work_group_size) {
  int3 groups;
  groups.x = DivideRoundUp(grid_size.x, work_group_size.x);
  groups.y = DivideRoundUp(grid_size.y, work_group_size.y);
  groups.z = DivideRoundUp(grid_size.z, work_group_size.z);
  return groups;
}
163 
164 }  // namespace gpu
165 }  // namespace tflite
166