/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/tasks/conv_buffer_1x1.h"

#include <array>
#include <string>
#include <utility>

#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/task/util.h"
#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"

namespace tflite {
namespace gpu {
namespace {

// element_size must be 1, 2 or 4
// 1 - FLT4
// 2 - FLT8
// 4 - FLT16
// Generates the code for the arithmetic part of the convolution.
std::string GetComputationPart(const int3& block_size, int element_size,
                               CalculationsPrecision precision,
                               const GpuInfo& gpu_info) {
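  // hexes[i] is the per-API suffix that selects the i-th scalar lane of a
  // loaded source value (OpenCL vector swizzle vs. Metal array/component
  // access).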
  std::string hexes[16];
  if (gpu_info.IsApiOpenCl()) {
    hexes[0] = ".s0";
    hexes[1] = ".s1";
    hexes[2] = ".s2";
    hexes[3] = ".s3";
    hexes[4] = ".s4";
    hexes[5] = ".s5";
    hexes[6] = ".s6";
    hexes[7] = ".s7";
    hexes[8] = ".s8";
    hexes[9] = ".s9";
    hexes[10] = ".sa";
    hexes[11] = ".sb";
    hexes[12] = ".sc";
    hexes[13] = ".sd";
    hexes[14] = ".se";
    hexes[15] = ".sf";
  } else if (gpu_info.IsApiMetal()) {
    hexes[0] = "[0].x";
    hexes[1] = "[0].y";
    hexes[2] = "[0].z";
    hexes[3] = "[0].w";
    hexes[4] = "[1].x";
    hexes[5] = "[1].y";
    hexes[6] = "[1].z";
    hexes[7] = "[1].w";
    hexes[8] = "[2].x";
    hexes[9] = "[2].y";
    hexes[10] = "[2].z";
    hexes[11] = "[2].w";
    hexes[12] = "[3].x";
    hexes[13] = "[3].y";
    hexes[14] = "[3].z";
    hexes[15] = "[3].w";
    if (element_size == 1) {
      hexes[0] = ".x";
      hexes[1] = ".y";
      hexes[2] = ".z";
      hexes[3] = ".w";
    }
  }
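  // Accumulators are named r<z><y><x>; they are declared in
  // GenerateConvBuffer1x1. Each step adds four products, one per FLT4 quarter
  // of the cached FLT16 weights.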
  std::string c;
  for (int z = 0; z < block_size.z; ++z) {
    const std::string z_s = std::to_string(z);
    c += "    FLT16 W" + z_s + " = weights_cache[" + z_s + "];\n";
    for (int y = 0; y < block_size.y; ++y) {
      for (int x = 0; x < block_size.x; ++x) {
        std::string s_index = std::to_string(y * block_size.x + x);
        for (int e = 0; e < element_size; ++e) {
          std::string r_index =
              z_s + std::to_string(y) + std::to_string(x * element_size + e);
          const std::string f0 = "FLT16_0123(W" + z_s + ")";
          const std::string f1 = "FLT16_4567(W" + z_s + ")";
          const std::string f2 = "FLT16_89ab(W" + z_s + ")";
          const std::string f3 = "FLT16_cdef(W" + z_s + ")";
          switch (precision) {
            case CalculationsPrecision::F32:
            case CalculationsPrecision::F16:
              c += "    r" + r_index + " += " + f0 + " * s" + s_index +
                   hexes[e * 4 + 0] + ";\n";
              c += "    r" + r_index + " += " + f1 + " * s" + s_index +
                   hexes[e * 4 + 1] + ";\n";
              c += "    r" + r_index + " += " + f2 + " * s" + s_index +
                   hexes[e * 4 + 2] + ";\n";
              c += "    r" + r_index + " += " + f3 + " * s" + s_index +
                   hexes[e * 4 + 3] + ";\n";
              break;
            case CalculationsPrecision::F32_F16:
              c += "    r" + r_index + " += TO_ACCUM_TYPE(" + f0 + " * s" +
                   s_index + hexes[e * 4 + 0] + " + " + f1 + " * s" + s_index +
                   hexes[e * 4 + 1] + " + " + f2 + " * s" + s_index +
                   hexes[e * 4 + 2] + " + " + f3 + " * s" + s_index +
                   hexes[e * 4 + 3] + ");\n";
              break;
          }
        }
      }
    }
  }
  return c;
}

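// Heuristic choice of element_size and block_size given the output shape.
// Non-Mali GPUs keep the defaults (FLT4 elements, 1x1x1 block).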
ConvBuffer1x1::ConvParams GetBestParams(const GpuInfo& gpu_info,
                                        const OperationDef& definition,
                                        const BHWC& shape, int src_depth,
                                        int dst_depth) {
  ConvBuffer1x1::ConvParams conv_params;
  conv_params.element_size = 4;
  conv_params.block_size = int3(1, 1, 1);
  if (!gpu_info.IsMali()) {
    return conv_params;
  }
  bool can_use_flt8 = (shape.w * shape.b) % 2 == 0 &&
                      definition.precision != CalculationsPrecision::F32;
  bool is_midgard = gpu_info.IsMali() && gpu_info.mali_info.IsMidgard();
  if (is_midgard) {
    if (can_use_flt8) {
      conv_params.element_size = 8;
    }
    if (definition.precision == CalculationsPrecision::F16 || !can_use_flt8) {
      conv_params.block_size.x = 2;
    }
    return conv_params;
  }

  int task_size = shape.w * shape.b * shape.h * dst_depth;
  int block_size =
      GetRecommendedBlockSizeForConv(gpu_info, definition.precision, task_size);

  if (!can_use_flt8 && block_size > 4) {
    block_size = 4;
  }

  if (can_use_flt8 && block_size >= 2) {
    conv_params.element_size = 8;
    block_size /= 2;
  }
  if (block_size == 4) {
    conv_params.block_size.x = 2;
    if (definition.precision == CalculationsPrecision::F32 && dst_depth < 32) {
      conv_params.block_size.y = 2;
    } else {
      conv_params.block_size.z = 2;
    }
  } else if (block_size == 2) {
    if (dst_depth >= 32) {
      conv_params.block_size.z = 2;
    } else {
      conv_params.block_size.x = 2;
    }
  }

  return conv_params;
}

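// Fallback heuristic used when the tensor shape is unknown at creation time.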
ConvBuffer1x1::ConvParams GetBestParams(const GpuInfo& gpu_info,
                                        const OperationDef& definition,
                                        int src_depth, int dst_depth) {
  ConvBuffer1x1::ConvParams conv_params;
  conv_params.element_size = 4;
  conv_params.block_size = int3(1, 1, 1);
  if (gpu_info.IsMali() && definition.precision == CalculationsPrecision::F16 &&
      gpu_info.GetComputeUnitsCount() <= 4) {
    conv_params.block_size.x *= 2;
  }
  return conv_params;
}

}  // namespace

ConvBuffer1x1::ConvBuffer1x1(const OperationDef& definition,
                             const ConvParams& conv_params,
                             const GpuInfo& gpu_info)
    : GPUOperation(definition), conv_params_(conv_params) {
  code_ = GenerateConvBuffer1x1(definition_, conv_params_, gpu_info, &args_);
  work_group_size_ = int3(2, 4, 1);
}

ConvBuffer1x1::ConvBuffer1x1(ConvBuffer1x1&& operation)
    : GPUOperation(std::move(operation)),
      conv_params_(std::move(operation.conv_params_)) {}

ConvBuffer1x1& ConvBuffer1x1::operator=(ConvBuffer1x1&& operation) {
  if (this != &operation) {
    std::swap(conv_params_, operation.conv_params_);
    GPUOperation::operator=(std::move(operation));
  }
  return *this;
}

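// Emits the kernel source. Each work item produces a block of
// block_size.x * (element_size / 4) output columns, block_size.y rows and
// block_size.z output slices.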
std::string ConvBuffer1x1::GenerateConvBuffer1x1(
    const OperationDef& op_def, const ConvBuffer1x1::ConvParams& conv_params,
    const GpuInfo& gpu_info, Arguments* args) {
  auto src_desc = op_def.src_tensors[0];
  if (op_def.IsBatchSupported()) {
    src_desc.SetStateVar("BatchedWidth", "true");
  }
  if (conv_params_.element_size == 8) {
    src_desc.SetStateVar("ElementsX2", "true");
  } else if (conv_params_.element_size == 16) {
    src_desc.SetStateVar("ElementsX4", "true");
  }
  AddSrcTensor("src_tensor", src_desc);
  if (op_def.src_tensors.size() == 2) {
    // dynamic weights
    BufferDescriptor desc;
    desc.element_type = op_def.src_tensors[1].data_type;
    desc.element_size = 16;
    desc.memory_type = MemoryType::GLOBAL;
    AddSrcBuffer("weights", desc);
  }

  auto dst_desc = op_def.dst_tensors[0];
  if (op_def.IsBatchSupported()) {
    dst_desc.SetStateVar("BatchedWidth", "true");
  }
  AddDstTensor("dst_tensor", dst_desc);

  if (gpu_info.IsMali()) {
    compiler_options_.push_back(CompilerOptions::kClFastRelaxedMath);
  }

  std::string c;
  switch (op_def.precision) {
    case CalculationsPrecision::F32:
      c += "#define FLT8 float8\n";
      c += "#define FLT16 float16\n";
      break;
    case CalculationsPrecision::F32_F16:
    case CalculationsPrecision::F16:
      c += "#define FLT8 half8\n";
      c += "#define FLT16 half16\n";
      break;
  }

  const int3 block_size = conv_params.block_size;
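  // conv_params.element_size counts scalar floats (4, 8 or 16); element_size
  // below counts FLT4 vectors (1, 2 or 4).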
  const int element_size = conv_params.element_size / 4;

  c += "MAIN_FUNCTION($0) {\n";
  c += "  int X = GLOBAL_ID_0 * " +
       std::to_string(block_size.x * element_size) + ";\n";
  c += "  int X_SRC = GLOBAL_ID_0 * " + std::to_string(block_size.x) + ";\n";
  c += "  int Y = GLOBAL_ID_1 * " + std::to_string(block_size.y) + ";\n";
  c += "  int Z = GLOBAL_ID_2 * " + std::to_string(block_size.z) + ";\n";
  c += "  if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
       "Z >= args.dst_tensor.Slices()) return;\n";
  if (conv_params.different_weights_for_height) {
    c += "  __global FLT16* weights_cache = args.weights.GetPtr() + (Z * "
         "args.src_tensor.Height() + "
         "Y * " +
         std::to_string(block_size.z) +
         ") * "
         "args.src_tensor.Slices();\n";
  } else {
    c += "  __global FLT16* weights_cache = args.weights.GetPtr() + Z * "
         "args.src_tensor.Slices();\n";
  }
  for (int z = 0; z < block_size.z; ++z) {
    const std::string z_s = std::to_string(z);
    c += "  ACCUM_FLT4 bias_val_" + z_s +
         " = TO_ACCUM_TYPE(args.biases.Read(Z + " + z_s + "));\n";
    for (int y = 0; y < block_size.y; ++y) {
      for (int x = 0; x < block_size.x * element_size; ++x) {
        c += "  ACCUM_FLT4 r" + z_s + std::to_string(y) + std::to_string(x) +
             " = bias_val_" + z_s + ";\n";
      }
    }
  }
  for (int x = 0; x < block_size.x; ++x) {
    std::string x_s = std::to_string(x);
    c += "  int xc" + x_s + " = min(X_SRC + " + std::to_string(x) +
         ", args.src_tensor.Width() - 1);\n";
  }
  for (int y = 0; y < block_size.y; ++y) {
    std::string y_s = std::to_string(y);
    c += "  int yc" + y_s + " = min(Y + " + y_s +
         ", args.src_tensor.Height() - 1);\n";
  }
  for (int y = 0; y < block_size.y; ++y) {
    std::string y_s = std::to_string(y);
    for (int x = 0; x < block_size.x; ++x) {
      std::string x_s = std::to_string(x);
      std::string i_s = std::to_string(y * block_size.x + x);
      c += "  int src_addr_" + i_s + " = (yc" + y_s +
           ") * args.src_tensor.Width() + (xc" + x_s + ");\n";
    }
  }
  c += "  for (int s = 0; s < args.src_tensor.Slices(); ++s) {\n";
  for (int y = 0; y < block_size.y; ++y) {
    std::string y_s = std::to_string(y);
    for (int x = 0; x < block_size.x; ++x) {
      std::string x_s = std::to_string(x);
      std::string i_s = std::to_string(y * block_size.x + x);
      c += "    FLT" + std::to_string(element_size * 4) + " s" + i_s +
           " = args.src_tensor.Read(src_addr_" + i_s + ");\n";
    }
  }
  c += GetComputationPart(block_size, element_size, op_def.precision, gpu_info);
  for (int i = 0; i < block_size.x * block_size.y; ++i) {
    std::string i_s = std::to_string(i);
    c += "    src_addr_" + i_s + " += args.src_tensor.SliceStride();\n";
  }
  c += "    weights_cache += " + std::to_string(block_size.z) + ";\n";
  c += "  }\n";  // SRC_SLICES

  for (int z = 0; z < block_size.z; ++z) {
    const std::string z_s = std::to_string(z);
    if (z != 0) {
      c += "  if (Z + " + z_s + " >= args.dst_tensor.Slices()) return;\n";
    }
    for (int y = 0; y < block_size.y; ++y) {
      const std::string y_s = std::to_string(y);
      for (int x = 0; x < block_size.x * element_size; ++x) {
        const std::string x_s = std::to_string(x);
        c += "  if (X + " + x_s + " < args.dst_tensor.Width() && Y + " + y_s +
             " < args.dst_tensor.Height()) {\n";
        c += "    FLT4 res = TO_FLT4(r" + z_s + y_s + x_s + ");\n";
        c += "    args.dst_tensor.Write(res, X + " + x_s + ", Y + " + y_s +
             ", Z + " + z_s + ");\n";
        c += "  }\n";
      }
    }
  }
  c += "}\n";
  return c;
}

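// Grid x covers the batched output width in groups of element_size / 4 FLT4
// columns; y and z cover height and slices. Each axis is divided by its block
// dimension.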
int3 ConvBuffer1x1::GetGridSize() const {
  const int dst_width_elements = DivideRoundUp(
      dst_[0]->Width() * dst_[0]->Batch(), (conv_params_.element_size / 4));
  const int grid_x =
      DivideRoundUp(dst_width_elements, conv_params_.block_size.x);
  const int grid_y =
      DivideRoundUp(dst_[0]->Height(), conv_params_.block_size.y);
  const int grid_z =
      DivideRoundUp(dst_[0]->Slices(), conv_params_.block_size.z);
  return int3(grid_x, grid_y, grid_z);
}

void ConvBuffer1x1::GetPossibleKernelWorkGroups(
    TuningType tuning_type, const GpuInfo& gpu_info,
    const KernelInfo& kernel_info, std::vector<int3>* work_groups) const {
  GetPossibleWorkGroupsConv(tuning_type, gpu_info, kernel_info, grid_size_,
                            work_groups);
}

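// A convolution maps onto this kernel only for BUFFER storage with a 1x1
// kernel, unit strides and dilations, and no padding.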
bool IsConvBuffer1x1Supported(const OperationDef& definition,
                              const Convolution2DAttributes& attr) {
  auto src_storage_type = definition.src_tensors[0].storage_type;
  return src_storage_type == TensorStorageType::BUFFER &&
         attr.weights.shape.w == 1 && attr.weights.shape.h == 1 &&
         attr.dilations.w == 1 && attr.dilations.h == 1 &&
         attr.strides.w == 1 && attr.strides.h == 1 &&
         attr.padding.prepended.w == 0 && attr.padding.prepended.h == 0 &&
         attr.padding.appended.w == 0 && attr.padding.appended.h == 0;
}

bool IsConvBuffer1x1Supported(const OperationDef& definition,
                              const BHWC& weights_shape,
                              const Convolution2DAttributes& attr) {
  auto src_storage_type = definition.src_tensors[0].storage_type;
  return src_storage_type == TensorStorageType::BUFFER &&
         weights_shape.w == 1 && weights_shape.h == 1 &&
         attr.dilations.w == 1 && attr.dilations.h == 1 &&
         attr.strides.w == 1 && attr.strides.h == 1 &&
         attr.padding.prepended.w == 0 && attr.padding.prepended.h == 0 &&
         attr.padding.appended.w == 0 && attr.padding.appended.h == 0;
}

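// Creates the operation and uploads the constant weights and biases.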
ConvBuffer1x1 CreateConvBuffer1x1(const GpuInfo& gpu_info,
                                  const OperationDef& definition,
                                  const Convolution2DAttributes& attr,
                                  const BHWC* shape) {
  const int dst_depth = DivideRoundUp(attr.weights.shape.o, 4);
  const int src_depth = DivideRoundUp(attr.weights.shape.i, 4);
  ConvBuffer1x1::ConvParams conv_params;
  if (shape) {
    conv_params =
        GetBestParams(gpu_info, definition, *shape, src_depth, dst_depth);
  } else {
    conv_params = GetBestParams(gpu_info, definition, src_depth, dst_depth);
  }
  ConvBuffer1x1 result(definition, conv_params, gpu_info);
  result.UploadData(attr.weights, attr.bias);
  return result;
}

ConvBuffer1x1 CreateConvBuffer1x1(const GpuInfo& gpu_info,
                                  const OperationDef& definition,
                                  const FullyConnectedAttributes& attr,
                                  const BHWC* shape) {
  const int dst_depth = DivideRoundUp(attr.weights.shape.o, 4);
  const int src_depth = DivideRoundUp(attr.weights.shape.i, 4);
  ConvBuffer1x1::ConvParams conv_params;
  if (shape) {
    conv_params =
        GetBestParams(gpu_info, definition, *shape, src_depth, dst_depth);
  } else {
    conv_params = GetBestParams(gpu_info, definition, src_depth, dst_depth);
  }
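  // Fold the y block into x (a fully connected output has height 1).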
  conv_params.block_size.x *= conv_params.block_size.y;
  conv_params.block_size.y = 1;
  ConvBuffer1x1 result(definition, conv_params, gpu_info);
  result.UploadData(attr.weights, attr.bias);
  return result;
}

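// Variant used after the Winograd 4x4-to-6x6 input transform: each output row
// gets its own weight set (different_weights_for_height).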
ConvBuffer1x1 CreateConvBuffer1x1Wino4x4To6x6(
    const GpuInfo& gpu_info, const OperationDef& definition,
    const Convolution2DAttributes& attr, const BHWC* shape) {
  const int dst_depth = DivideRoundUp(attr.weights.shape.o, 4);
  const int src_depth = DivideRoundUp(attr.weights.shape.i, 4);
  ConvBuffer1x1::ConvParams conv_params;
  if (shape) {
    conv_params =
        GetBestParams(gpu_info, definition, *shape, src_depth, dst_depth);
  } else {
    conv_params = GetBestParams(gpu_info, definition, src_depth, dst_depth);
  }
  conv_params.block_size.x *= conv_params.block_size.y;
  conv_params.block_size.y = 1;
  conv_params.different_weights_for_height = true;
  ConvBuffer1x1 result(definition, conv_params, gpu_info);
  result.UploadDataForWinograd4x4To6x6(attr.weights);
  return result;
}

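// The weights arrive at runtime as the second source tensor; only the biases
// are uploaded here.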
ConvBuffer1x1 CreateConvBuffer1x1DynamicWeights(
    const GpuInfo& gpu_info, const OperationDef& definition,
    const Convolution2DAttributes& attr, const BHWC& weights_shape,
    const BHWC* dst_shape) {
  const int dst_depth = DivideRoundUp(weights_shape.b, 4);
  const int src_depth = DivideRoundUp(weights_shape.c, 4);
  ConvBuffer1x1::ConvParams conv_params;
  if (dst_shape) {
    conv_params =
        GetBestParams(gpu_info, definition, *dst_shape, src_depth, dst_depth);
  } else {
    conv_params = GetBestParams(gpu_info, definition, src_depth, dst_depth);
  }
  ConvBuffer1x1 result(definition, conv_params, gpu_info);
  result.UploadBiases(attr.bias);
  return result;
}

}  // namespace gpu
}  // namespace tflite