/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/task/util.h"

#include <cfloat>

#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"

namespace tflite {
namespace gpu {

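// Maps a MemoryType to the matching OpenCL address space qualifier
// (GLOBAL -> "__global", CONSTANT -> "__constant", LOCAL -> "__local").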
std::string MemoryTypeToCLType(MemoryType type) {
  switch (type) {
    case MemoryType::GLOBAL:
      return "__global";
    case MemoryType::CONSTANT:
      return "__constant";
    case MemoryType::LOCAL:
      return "__local";
  }
  return "";
}

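// Maps a MemoryType to the matching Metal address space qualifier
// (GLOBAL -> "device", CONSTANT -> "constant", LOCAL -> "threadgroup").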
std::string MemoryTypeToMetalType(MemoryType type) {
  switch (type) {
    case MemoryType::GLOBAL:
      return "device";
    case MemoryType::CONSTANT:
      return "constant";
    case MemoryType::LOCAL:
      return "threadgroup";
  }
  return "";
}

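// Builds a source-code expression that maps a batched X coordinate onto the
// strided and padded source X coordinate. As a rough illustration (the
// argument strings here are made up), GetXStrideCorrected("x", "B", "S", "P")
// yields "(((x) / B) * S * B + ((x) % B) + P)".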
std::string GetXStrideCorrected(const std::string& src_x,
                                const std::string& batch_size,
                                const std::string& stride_x,
                                const std::string& padding_x) {
  // int p0 = src_x / batch_size;
  // int b0 = src_x % batch_size;
  // return p0 * stride_x * batch_size + b0 + padding_x;
  return absl::Substitute("((($0) / $1) * $2 * $1 + (($0) % $1) + $3)", src_x,
                          batch_size, stride_x, padding_x);
}

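// Variant of GetXStrideCorrected that keeps the result interleaved by batch:
// with the same made-up arguments it yields
// "((((x) / B) * S + P) * B + (x) % B)".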
std::string GetXStrideCorrectedV2(const std::string& src_x,
                                  const std::string& batch_size,
                                  const std::string& stride_x,
                                  const std::string& padding_x) {
  // int p0 = src_x / batch_size;
  // int b0 = src_x % batch_size;
  // return (p0 * stride_x + padding_x) * batch_size + b0;
  return absl::Substitute("(((($0) / $1) * $2 + $3) * $1 + ($0) % $1)", src_x,
                          batch_size, stride_x, padding_x);
}

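// Builds a per-channel mask for the last 4-channel slice: valid channels get
// 1.0f, padding channels get 0.0f. For example, channels = 6 leaves two valid
// channels in the last slice, so the mask is (1, 1, 0, 0); any multiple of 4
// gives (1, 1, 1, 1).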
float4 GetMaskForLastPlane(int channels) {
  float4 mask = float4(0.0f);
  const int remainder = channels % 4 == 0 ? 4 : channels % 4;
  for (int i = 0; i < remainder; ++i) {
    mask[i] = 1.0f;
  }
  return mask;
}

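// Picks a convolution block size from the per-compute-unit task size. GPUs
// other than Mali always get a block size of 1; on Mali the thresholds below
// are tuned per architecture generation and precision. For illustration, with
// the F16 / Bifrost Gen1 thresholds a task_size_per_cu of 300 is above
// threshold_1 (256) but below threshold_2 (1024), so the block size is 2.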
int GetRecommendedBlockSizeForConv(const GpuInfo& gpu_info,
                                   CalculationsPrecision precision,
                                   int task_size) {
  const float task_size_per_cu =
      task_size / static_cast<float>(gpu_info.GetComputeUnitsCount());
  int block_size = 1;
  float threshold_1 = FLT_MAX;
  float threshold_2 = FLT_MAX;
  float threshold_4 = FLT_MAX;
  if (!gpu_info.IsMali()) {
    return 1;
  }
  MaliInfo mali_info = gpu_info.mali_info;
  switch (precision) {
    case CalculationsPrecision::F16:
      if (mali_info.IsBifrostGen1()) {
        threshold_1 = 256.0f;
        threshold_2 = 256.0f * 4.0f;
        threshold_4 = 256.0f * 8.0f;
      } else if (mali_info.IsBifrostGen2()) {
        threshold_1 = 256.0f * 2.0f;
        threshold_2 = 256.0f * 8.0f;
        threshold_4 = 256.0f * 16.0f;
      } else if (mali_info.IsBifrostGen3() || mali_info.IsValhall()) {
        threshold_1 = 256.0f;
        threshold_2 = 256.0f * 6.0f;
        threshold_4 = 256.0f * 16.0f;
      } else if (mali_info.IsMidgard()) {
        threshold_1 = 256.0f * 4.0f;
        threshold_2 = 256.0f * 16.0f;
      }
      break;
    case CalculationsPrecision::F32_F16:
      if (mali_info.IsBifrostGen1()) {
        threshold_1 = 256.0f;
        threshold_2 = 256.0f * 3.0f;
        threshold_4 = 256.0f * 32.0f;
      } else if (mali_info.IsBifrostGen2()) {
        threshold_1 = 256.0f * 2.0f;
        threshold_2 = 256.0f * 8.0f;
      } else if (mali_info.IsBifrostGen3() || mali_info.IsValhall()) {
        threshold_1 = 256.0f;
        threshold_2 = 256.0f * 8.0f;
      } else if (mali_info.IsMidgard()) {
        threshold_1 = 256.0f * 4.0f;
      }
      break;
    case CalculationsPrecision::F32:
      if (mali_info.IsBifrostGen1()) {
        threshold_1 = 256.0f;
        threshold_2 = 256.0f * 4.0f;
      } else if (mali_info.IsBifrostGen2()) {
        threshold_1 = 128.0f;
        threshold_2 = 256.0f * 4.0f;
      } else if (mali_info.IsBifrostGen3() || mali_info.IsValhall()) {
        threshold_1 = 256.0f;
        threshold_2 = 256.0f * 12.0f;
      } else if (mali_info.IsMidgard()) {
        threshold_1 = 256.0f * 16.0f;
      }
      break;
  }
  if (task_size_per_cu <= threshold_1) {
    block_size = 1;
  } else if (task_size_per_cu <= threshold_2) {
    block_size = 2;
  } else if (task_size_per_cu <= threshold_4) {
    block_size = 4;
  } else {
    block_size = 8;
  }
  return block_size;
}

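// Divides the grid size by the work group size, rounding up in every
// dimension; e.g. grid_size (65, 32, 5) with work_group_size (8, 8, 1) gives
// (9, 4, 5) work groups.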
int3 GetWorkGroupsCount(const int3& grid_size, const int3& work_group_size) {
  int3 work_groups_count;
  work_groups_count.x = DivideRoundUp(grid_size.x, work_group_size.x);
  work_groups_count.y = DivideRoundUp(grid_size.y, work_group_size.y);
  work_groups_count.z = DivideRoundUp(grid_size.z, work_group_size.z);
  return work_groups_count;
}

}  // namespace gpu
}  // namespace tflite