/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/tasks/conv_constants.h"

#include <algorithm>
#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/lite/delegates/gpu/common/task/util.h"
#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"

namespace tflite {
namespace gpu {

namespace {
// Adreno can provide up to ~3-4KB of constant memory, but in some cases even
// 3KB can perform very poorly.
int GetAdrenoOptimalMaxConstantSize(const AdrenoInfo& adreno_info) {
  if (adreno_info.IsAdreno3xx() || adreno_info.IsAdreno4xx() ||
      adreno_info.IsAdreno5xx()) {
    return 256 * 10;  // 2.5KB
  } else {
    return 256 * 14;  // 3.5KB
  }
}

int GetOptimalMaxConstantSize(const GpuInfo& info) {
  if (!info.IsAdreno()) {
    // In general we do not expect this kernel to be used on non-Adreno GPUs,
    // since it is tuned for __constant memory, which gives a big benefit on
    // Adreno.
    return 1024;  // 1KB
  } else {
    return GetAdrenoOptimalMaxConstantSize(info.adreno_info);
  }
}

// src_size and dst_size must be <= 4.
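// For illustration (example values, not from the original file): a call with
// use_dot_conv = true, src_size = 3, dst_size = 2, const_mem_offset = 8,
// dst = "r0", src = "src" emits:
//     r0.x += dot(src, constants[8].xyz);
//     r0.y += dot(src, constants[9].xyz);
// while use_dot_conv = false (same arguments, F16 precision) emits:
//     r0.xy += src.x * constants[8].xy;
//     r0.xy += src.y * constants[9].xy;
//     r0.xy += src.z * constants[10].xy;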
std::string GenerateConv(int src_size, int dst_size, bool use_dot_conv,
                         int const_mem_offset, CalculationsPrecision precision,
                         const std::string& dst, const std::string& src) {
  std::string result;
  const std::string postfixes[] = {".x", ".y", ".z", ".w"};
  if (use_dot_conv) {
    // One dot() per destination channel; the constant vector is narrowed to
    // the src_size components that are actually present.
    const std::string src_postfixes[] = {".x", ".xy", ".xyz", ""};
    const std::string src_postfix = src_postfixes[src_size - 1];
    for (int i = 0; i < dst_size; ++i) {
      result += "    " + dst + postfixes[i] + " += dot(" + src +
                ", constants[" + std::to_string(const_mem_offset + i) + "]" +
                src_postfix + ");\n";
    }
  } else {
    // One multiply-accumulate per source channel; each update writes all
    // destination channels at once.
    const std::string dst_postfixes[] = {".x", ".xy", ".xyz", ""};
    const std::string dst_postfix = dst_postfixes[dst_size - 1];
    if (precision == CalculationsPrecision::F32_F16) {
      // Products are computed in FLT (half) precision and the sum is
      // converted to float once per destination write.
      for (int i = 0; i < src_size; ++i) {
        if (i != 0) {
          result += " + ";
        }
        std::string src_name = src;
        if (src_size != 1) {
          src_name += postfixes[i];
        }
        result += src_name + " * constants[" +
                  std::to_string(const_mem_offset + i) + "]" + dst_postfix;
      }
      std::string size = dst_size == 1 ? "" : std::to_string(dst_size);
      result = "    " + dst + dst_postfix + " += convert_float" + size + "(" +
               result + ");\n";
    } else {
      for (int i = 0; i < src_size; ++i) {
        std::string src_name = src;
        if (src_size != 1) {
          src_name += postfixes[i];
        }
        result += "    " + dst + dst_postfix + " += " + src_name +
                  " * constants[" + std::to_string(const_mem_offset + i) + "]" +
                  dst_postfix + ";\n";
      }
    }
  }
  return result;
}

std::string GenerateConvolutionConstantCode(const OperationDef& op_def,
                                            const OHWI& weights_shape,
                                            bool stride_correction,
                                            bool use_dot_conv,
                                            GPUOperation* op) {
  auto src_desc = op_def.src_tensors[0];
  src_desc.SetAddressMode(AddressMode::kZero);
  if (op_def.IsBatchSupported()) {
    src_desc.SetStateVar("BatchedWidth", "true");
  }
  op->AddSrcTensor("src_tensor", src_desc);

  auto dst_desc = op_def.dst_tensors[0];
  if (op_def.IsBatchSupported()) {
    dst_desc.SetStateVar("BatchedWidth", "true");
  }
  op->AddDstTensor("dst_tensor", dst_desc);

  const int out_z = DivideRoundUp(weights_shape.o, 4);
  const std::string kOutZ = std::to_string(out_z);
  const int src_depth = DivideRoundUp(weights_shape.i, 4);

  const std::string postfixes[] = {".x", ".xy", ".xyz", ""};

  std::string c;
  c += "__kernel void main_function(\n";
  c += "$0) {\n";
  c += "  int X = get_global_id(0);\n";
  c += "  int Y = get_global_id(1);\n";
  c += "  if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height()) "
       "return;\n";
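  // With BatchedWidth, the batch is folded into the X coordinate, so a
  // non-unit stride in x needs a corrected start; GetXStrideCorrectedV2
  // emits that expression.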
  if (stride_correction) {
    c += "  int start_x = " +
         GetXStrideCorrectedV2("X", "args.src_tensor.Batch()", "args.stride_x",
                               "args.padding_x") +
         ";\n";
  } else {
    if (op_def.IsBatchSupported()) {
      c += "  int start_x = X * args.stride_x + args.padding_x * "
           "args.src_tensor.Batch();\n";
    } else {
      c += "  int start_x = X * args.stride_x + args.padding_x;\n";
    }
  }
  c += "  int start_y = Y * args.stride_y + args.padding_y;\n";
  c += "  __constant FLT4* constants = args.weights.GetPtr();\n";
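  // All weights are read through __constant memory; IsConvConstantsSupported
  // guarantees they fit the per-GPU constant budget.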
  for (int i = 0; i < out_z; ++i) {
    c += "  ACCUM_FLT4 r" + std::to_string(i) +
         " = (ACCUM_FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n";
  }
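  // Builds the out-of-bounds predicate (e.g. "x_out || y_out") from the axes
  // the source tensor cannot zero-clamp on read.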
  auto generate_check = [&]() {
    std::string check;
    const std::vector<Axis> axes{Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH};
    const std::vector<std::string> names{"x_out", "y_out", "z_out"};
    for (int i = 0; i < axes.size(); ++i) {
      const auto& axis = axes[i];
      if (src_desc.HasAxis(axis) && !src_desc.SupportsZeroClamp(axis)) {
        if (!check.empty()) {
          check += " || ";
        }
        check += names[i];
      }
    }
    return check;
  };
  const std::string check = generate_check();
  int filters_counter = 0;
  for (int s = 0; s < src_depth; ++s) {
    const int src_ch_count = std::min(4, weights_shape.i - s * 4);
    const std::string s_count =
        src_ch_count == 1 ? "" : std::to_string(src_ch_count);
    const std::string s_type = absl::StrCat("FLT", s_count);
    const std::string s_postfix = postfixes[src_ch_count - 1];
    const std::string dilation_x =
        op_def.IsBatchSupported() ? "args.dilation_x * args.src_tensor.Batch()"
                                  : "args.dilation_x";
    for (int ky = 0; ky < weights_shape.h; ++ky) {
      std::string s_y = absl::StrCat("(start_y + ", ky, " * args.dilation_y)");
      if (!src_desc.SupportsZeroClamp(Axis::HEIGHT)) {
        c += "  {\n";
        c += "  bool y_out = " + s_y + " < 0 || " + s_y +
             " >= args.src_tensor.Height();\n";
      }
      for (int kx = 0; kx < weights_shape.w; ++kx) {
        c += "  {\n";
        std::string s_x =
            absl::StrCat("(start_x + ", kx, " * " + dilation_x + ")");
        if (!src_desc.SupportsZeroClamp(Axis::WIDTH)) {
          c += "    bool x_out = " + s_x + " < 0 || " + s_x +
               " >= args.src_tensor.Width();\n";
        }
        if (check.empty()) {
          c += "    " + s_type + " src = args.src_tensor.Read(" + s_x + ", " +
               s_y + ", " + std::to_string(s) + ")" + s_postfix + ";\n";
        } else {
          c += "    " + s_type + " src = " + check + " ? ";
          c += "(" + s_type + ")(0.0) : args.src_tensor.Read(" + s_x + ", " +
               s_y + ", " + std::to_string(s) + ")" + s_postfix + ";\n";
        }
        for (int d = 0; d < out_z; ++d) {
          const int dst_ch_count = std::min(4, weights_shape.o - d * 4);
          c += GenerateConv(src_ch_count, dst_ch_count, use_dot_conv,
                            filters_counter, op_def.precision,
                            "r" + std::to_string(d), "src");
          filters_counter += use_dot_conv ? dst_ch_count : src_ch_count;
        }
        c += "  }\n";
      }
      if (!src_desc.SupportsZeroClamp(Axis::HEIGHT)) {
        c += "  }\n";
      }
    }
  }
  for (int i = 0; i < out_z; ++i) {
    std::string s_i = std::to_string(i);
    c += "  {\n";
    c += "    FLT4 res = TO_FLT4(r" + s_i + ") + args.biases.Read(" + s_i +
         ");\n";
    c += "    args.dst_tensor.Write(res, X, Y, " + s_i + ");\n";
    c += "  }\n";
  }
  c += "}\n";
  return c;
}

bool IsDotConvBetter(int src_channels, int dst_channels) {
  if (dst_channels % 4 == 0) {
    return false;
  }

  // dst_channels % 4 != 0
  if (src_channels % 4 == 0) {
    return true;
  }

  // dst_channels % 4 != 0 && src_channels % 4 != 0
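  // Worked example (values chosen for illustration): src_channels = 5,
  // dst_channels = 2 gives src_depth = 2, dst_depth = 1. The dot layout
  // stores dst_channels * src_depth * 4 = 16 aligned weights per spatial tap
  // versus src_channels * dst_depth * 4 = 20 for the per-channel layout, so
  // dot is preferred (2 * 2 < 5 * 1).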
  const int src_depth = DivideRoundUp(src_channels, 4);
  const int dst_depth = DivideRoundUp(dst_channels, 4);
  return dst_channels * src_depth < src_channels * dst_depth;
}

}  // namespace

bool IsConvConstantsSupported(const GpuInfo& gpu_info,
                              const OperationDef& definition,
                              const Convolution2DAttributes& attr) {
  if (gpu_info.IsAMD() && definition.precision != CalculationsPrecision::F32 &&
      definition.src_tensors[0].storage_type != TensorStorageType::BUFFER) {
    // BUG: some AMD GPUs crash without this early-out.
    return false;
  }

  const bool use_dot_conv =
      IsDotConvBetter(attr.weights.shape.i, attr.weights.shape.o);
  const auto& w_shape = attr.weights.shape;
  const int src_depth = DivideRoundUp(w_shape.i, 4);
  const int dst_depth = DivideRoundUp(w_shape.o, 4);
  const int aligned_ch_count =
      use_dot_conv ? w_shape.o * src_depth * 4 : w_shape.i * dst_depth * 4;
  const int filters_count = aligned_ch_count * w_shape.h * w_shape.w;
  const int float_size = definition.precision == CalculationsPrecision::F32
                             ? sizeof(float)
                             : sizeof(half);
  const int filters_buffer_size = filters_count * float_size;
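  // For illustration (example values, not from the original file): a 3x3
  // convolution with i = 8, o = 4 in F16 uses the per-channel layout
  // (o % 4 == 0 disables dot conv), so aligned_ch_count = 8 * 1 * 4 = 32,
  // filters_count = 32 * 3 * 3 = 288, and filters_buffer_size = 288 * 2 =
  // 576 bytes, comfortably inside the ~2.5-3.5KB Adreno budget.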
  const int kConstantMaxSize = GetOptimalMaxConstantSize(gpu_info);
  const int flt4_registers = DivideRoundUp(w_shape.o, 4);
  return filters_buffer_size <= kConstantMaxSize && flt4_registers <= 8;
}

GPUOperation CreateConvConstants(const GpuInfo& gpu_info,
                                 const OperationDef& definition,
                                 const Convolution2DAttributes& attr) {
  const bool use_dot_conv =
      IsDotConvBetter(attr.weights.shape.i, attr.weights.shape.o);
  GPUOperation op(definition);
  UploadWeightsForConvConstants(attr.weights, definition.precision,
                                use_dot_conv, &op);
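  // Padding is stored negated so the generated code can compute
  // start_x = X * args.stride_x + args.padding_x as a direct source offset.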
  op.args_.AddInt("stride_x", attr.strides.w);
  op.args_.AddInt("stride_y", attr.strides.h);
  op.args_.AddInt("padding_x", -attr.padding.prepended.w);
  op.args_.AddInt("padding_y", -attr.padding.prepended.h);
  op.args_.AddInt("dilation_x", attr.dilations.w);
  op.args_.AddInt("dilation_y", attr.dilations.h);
  op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_ZIs1;

  const bool stride_correction =
      definition.IsBatchSupported() && attr.strides.w != 1;

  op.code_ = GenerateConvolutionConstantCode(
      definition, attr.weights.shape, stride_correction, use_dot_conv, &op);
  if (definition.precision == CalculationsPrecision::F16 &&
      gpu_info.IsAdreno() && gpu_info.adreno_info.IsAdreno3xx()) {
    op.compiler_options_.push_back(CompilerOptions::kAdrenoFullSimd);
  }
  if (definition.precision != CalculationsPrecision::F32 &&
      gpu_info.IsPowerVR()) {
    // BUG: some PowerVRs (GE8320) produce incorrect results without it.
    op.compiler_options_.push_back(CompilerOptions::kClDisableOptimizations);
  }

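  // Biases are uploaded into constant memory as well, next to the weights.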
  TensorLinearDescriptor desc;
  desc.storage_type = LinearStorageType::BUFFER;
  desc.element_type = definition.GetDataType();
  desc.memory_type = MemoryType::CONSTANT;
  desc.UploadLinearData(attr.bias);
  op.args_.AddObject(
      "biases", absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
  return op;
}

}  // namespace gpu
}  // namespace tflite