/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/tasks/convolution_transposed.h"

#include <string>
#include <utility>
#include <vector>

#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/task/storage_type_util.h"
#include "tensorflow/lite/delegates/gpu/common/task/weights_layout.h"
#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"

namespace tflite {
namespace gpu {

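// 2D case: the stride carries only (w, h). The weights layout is picked per
// backend: buffer-backed weights use an O-spatial-I-O-group layout (O4I4 on
// Apple, I4O4 elsewhere) and texture-backed weights use the corresponding 2D
// layouts. The default block size of (2, 2, 1, 2) is tuned down on Mali GPUs,
// and the slice component of the block is dropped (folded into the y
// component on non-Mali GPUs) when the output has 1 or 3 slices.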
ConvolutionTransposed::ConvolutionTransposed(
    const OperationDef& definition, const ConvolutionTransposedAttributes& attr,
    const GpuInfo& gpu_info, bool weights_are_buffer)
    : GPUOperation(definition),
      stride_(attr.stride.w, attr.stride.h, 1, 1),
      block_size_(2, 2, 1, 2) {
  if (weights_are_buffer) {
    if (gpu_info.IsApple()) {
      weights_layout_ = WeightsLayout::kOSpatialIOGroupO4I4;
    } else {
      weights_layout_ = WeightsLayout::kOSpatialIOGroupI4O4;
    }
  } else {
    if (gpu_info.IsApple()) {
      weights_layout_ = WeightsLayout::k2DX4O4YIsSpatialIAndXIsOOGroupI4;
    } else {
      weights_layout_ = WeightsLayout::k2DX4I4YIsSpatialIAndXIsOOGroupO4;
    }
  }
  const bool is_f16 = definition.precision == CalculationsPrecision::F16;
  if (gpu_info.IsMali()) {
    if (gpu_info.mali_info.IsMidgard()) {
      block_size_ = is_f16 ? int4(2, 1, 1, 2) : int4(2, 1, 1, 1);
    } else {
      block_size_ = is_f16 ? int4(2, 2, 1, 2) : int4(2, 2, 1, 1);
    }
  }
  const int dst_depth = DivideRoundUp(attr.weights.shape.o, 4);
  if (dst_depth == 1 || dst_depth == 3) {
    if (!gpu_info.IsMali()) {
      block_size_.y *= block_size_.w;
    }
    block_size_.w = 1;
  }

  args_.AddInt("stride_x", stride_.x);
  args_.AddInt("stride_y", stride_.y);
  args_.AddInt("padding_x", attr.padding.prepended.w);
  args_.AddInt("padding_y", attr.padding.prepended.h);
  args_.AddInt("kernel_size_x", attr.weights.shape.w);
  args_.AddInt("kernel_size_y", attr.weights.shape.h);
  code_ = GenerateConvolutionTransposedCode(definition_, gpu_info,
                                            weights_are_buffer, block_size_);
}

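// 3D case: same layout and block-size selection as above, but the stride also
// carries the depth component and z variants of the kernel size and padding
// are added. grid_size_y is registered here and bound later in
// BindArguments().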
ConvolutionTransposed::ConvolutionTransposed(
    const OperationDef& definition,
    const ConvolutionTransposed3DAttributes& attr, const GpuInfo& gpu_info,
    bool weights_are_buffer)
    : GPUOperation(definition),
      stride_(attr.stride.w, attr.stride.h, attr.stride.d, 1),
      block_size_(2, 2, 1, 2) {
  if (weights_are_buffer) {
    if (gpu_info.IsApple()) {
      weights_layout_ = WeightsLayout::kOSpatialIOGroupO4I4;
    } else {
      weights_layout_ = WeightsLayout::kOSpatialIOGroupI4O4;
    }
  } else {
    if (gpu_info.IsApple()) {
      weights_layout_ = WeightsLayout::k2DX4O4YIsSpatialIAndXIsOOGroupI4;
    } else {
      weights_layout_ = WeightsLayout::k2DX4I4YIsSpatialIAndXIsOOGroupO4;
    }
  }
  const bool is_f16 = definition.precision == CalculationsPrecision::F16;
  if (gpu_info.IsMali()) {
    if (gpu_info.mali_info.IsMidgard()) {
      block_size_ = is_f16 ? int4(2, 1, 1, 2) : int4(2, 1, 1, 1);
    } else {
      block_size_ = is_f16 ? int4(2, 2, 1, 2) : int4(2, 2, 1, 1);
    }
  }
  const int dst_depth = DivideRoundUp(attr.weights.shape.o, 4);
  if (dst_depth == 1 || dst_depth == 3) {
    if (!gpu_info.IsMali()) {
      block_size_.y *= block_size_.w;
    }
    block_size_.w = 1;
  }

  args_.AddInt("stride_x", stride_.x);
  args_.AddInt("stride_y", stride_.y);
  args_.AddInt("stride_z", stride_.z);
  args_.AddInt("padding_x", attr.padding.prepended.w);
  args_.AddInt("padding_y", attr.padding.prepended.h);
  args_.AddInt("padding_z", attr.padding.prepended.d);
  args_.AddInt("kernel_size_x", attr.weights.shape.w);
  args_.AddInt("kernel_size_y", attr.weights.shape.h);
  args_.AddInt("kernel_size_z", attr.weights.shape.d);
  args_.AddInt("grid_size_y");
  code_ = GenerateConvolutionTransposedCode(definition_, gpu_info,
                                            weights_are_buffer, block_size_);
}

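// Builds the kernel source in the macro-based dialect shared by the common
// GPU tasks (MAIN_FUNCTION, GLOBAL_ID_*, FLT4, ...). The generated kernel
// walks the transposed-convolution footprint backwards: for each output block
// it visits the source positions whose strided projection covers the block,
// accumulates per-slice CONVs, then adds the bias and writes the elements
// that fall inside the destination tensor.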
std::string ConvolutionTransposed::GenerateConvolutionTransposedCode(
    const OperationDef& op_def, const GpuInfo& gpu_info,
    bool weights_are_buffer, const int4& block_size) {
  auto src_desc = op_def.src_tensors[0];
  src_desc.SetAddressMode(AddressMode::kZero);
  AddSrcTensor("src_tensor", src_desc);
  AddDstTensor("dst_tensor", op_def.dst_tensors[0]);

  if (op_def.src_tensors.size() != 1) {
    // dynamic weights
    if (weights_layout_ == WeightsLayout::kOSpatialIOGroupI4O4 ||
        weights_layout_ == WeightsLayout::kOSpatialIOGroupO4I4) {
      BufferDescriptor desc;
      desc.element_type = op_def.src_tensors[1].data_type;
      desc.element_size = 16;
      desc.memory_type = MemoryType::GLOBAL;
      AddSrcBuffer("weights", desc);
    } else {
      for (int i = 0; i < 4; ++i) {
        Texture2DDescriptor desc;
        desc.element_type = op_def.src_tensors[1 + i].data_type;
        const std::string name = "weights" + std::to_string(i);
        AddSrcTexture2D(name, desc);
      }
    }
  }

  const auto& src_def = op_def.src_tensors[0];

  std::string c;

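  // Emit one CONV<s> macro per output-slice element of the block. For I4O4
  // weights the accumulation is expressed as S.{x,y,z,w} * f<i>; for O4I4
  // weights it is expressed with dot products.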
  for (int s = 0; s < block_size.w; ++s) {
    std::string f0, f1, f2, f3;
    if (weights_are_buffer) {
      if (gpu_info.SupportsPointersInKernels()) {
        f0 = "FLT16_0123(weights_cache[" + std::to_string(s) + "])";
        f1 = "FLT16_4567(weights_cache[" + std::to_string(s) + "])";
        f2 = "FLT16_89ab(weights_cache[" + std::to_string(s) + "])";
        f3 = "FLT16_cdef(weights_cache[" + std::to_string(s) + "])";
      } else {
        f0 = "FLT16_0123(flt16val)";
        f1 = "FLT16_4567(flt16val)";
        f2 = "FLT16_89ab(flt16val)";
        f3 = "FLT16_cdef(flt16val)";
      }
    } else {
      f0 = "f" + std::to_string(s * 4 + 0);
      f1 = "f" + std::to_string(s * 4 + 1);
      f2 = "f" + std::to_string(s * 4 + 2);
      f3 = "f" + std::to_string(s * 4 + 3);
    }
    if (GetWeightsDescription().IsI4O4()) {
      switch (op_def.precision) {
        case CalculationsPrecision::F32:
        case CalculationsPrecision::F16:
          c += "#define CONV" + std::to_string(s) + "(R, S)    \\\n";
          c += "R += S.x * " + f0 + "; \\\n";
          c += "R += S.y * " + f1 + "; \\\n";
          c += "R += S.z * " + f2 + "; \\\n";
          c += "R += S.w * " + f3 + ";   \n";
          break;
        case CalculationsPrecision::F32_F16:
          c += "#define CONV" + std::to_string(s) + "(R, S) \\\n";
          c += "R += TO_ACCUM_TYPE(S.x * " + f0 + " + S.y * " + f1 +
               " + S.z * " + f2 + " + S.w * " + f3 + ");\n";
          break;
      }
    } else {
      // O4I4
      c += "#define CONV" + std::to_string(s) + "(R, S)    \\\n";
      c += "R.x += dot(S, " + f0 + "); \\\n";
      c += "R.y += dot(S, " + f1 + "); \\\n";
      c += "R.z += dot(S, " + f2 + "); \\\n";
      c += "R.w += dot(S, " + f3 + ");   \n";
    }
  }

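  // Helpers that build per-coordinate name suffixes ("_w0_h1...") and the
  // boundary-check expression used when the source tensor does not support
  // zero clamping along an axis.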
  auto generate_id = [&](const std::string& x, const std::string& y,
                         const std::string& z) {
    std::string id;
    if (src_def.HasAxis(Axis::WIDTH)) {
      id += "_w" + x;
    }
    if (src_def.HasAxis(Axis::HEIGHT)) {
      id += "_h" + y;
    }
    if (src_def.HasAxis(Axis::DEPTH)) {
      id += "_d" + z;
    }
    return id;
  };

  auto generate_id_full = [&](const std::string& x, const std::string& y,
                              const std::string& z, const std::string& s) {
    return generate_id(x, y, z) + "_s" + s;
  };

  auto generate_check = [&](const std::string& x, const std::string& y,
                            const std::string& z) {
    std::string check;
    const std::vector<Axis> axes{Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH};
    const std::vector<std::string> names{"in_x", "in_y", "in_z"};
    const std::vector<std::string> coords{x, y, z};
    for (int i = 0; i < axes.size(); ++i) {
      const auto& axis = axes[i];
      if (src_def.HasAxis(axis) && !src_def.SupportsZeroClamp(axis) &&
          block_size[i] != 1) {
        if (!check.empty()) {
          check += " && ";
        }
        check += names[i] + coords[i];
      }
    }
    return check;
  };

  switch (op_def.precision) {
    case CalculationsPrecision::F32:
      c += "#define FLT16 float16\n";
      break;
    case CalculationsPrecision::F32_F16:
    case CalculationsPrecision::F16:
      c += "#define FLT16 half16\n";
      break;
  }

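  // Decompose the global id into destination coordinates; dst_x, dst_y (and
  // dst_z in the 3D case) address the first element of the output block and
  // dst_s the first output slice.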
253   c += "MAIN_FUNCTION($0) {\n";
254   if (op_def.IsBatchSupported()) {
255     c += "  int linear_id = GLOBAL_ID_0;\n";
256     c += "  int dst_x = (linear_id / args.dst_tensor.Batch());\n";
257     c += "  int B = linear_id % args.dst_tensor.Batch();\n";
258     c += "  args.dst_tensor.SetBatchRef(B);\n";
259     c += "  args.src_tensor.SetBatchRef(B);\n";
260   } else {
261     c += "  int dst_x = GLOBAL_ID_0;\n";
262   }
263   c += "  int rem_x = dst_x % args.stride_x;\n";
264   c += "  int ceil_x = dst_x / args.stride_x;\n";
265   c += "  dst_x = ceil_x * args.stride_x * " + std::to_string(block_size.x) +
266        " + rem_x;\n";
267   if (src_def.HasAxis(Axis::DEPTH)) {
268     c += "  int linear_id_y = GLOBAL_ID_1;\n";
269     c += "  int dst_y = linear_id_y % args.grid_size_y;\n";
270     c += "  int dst_z = linear_id_y / args.grid_size_y;\n";
271     c += "  int rem_z = dst_z % args.stride_z;\n";
272     c += "  int ceil_z = dst_z / args.stride_z;\n";
273     c += "  dst_z = ceil_z * args.stride_z * " + std::to_string(block_size.z) +
274          " + rem_z;\n";
275     c += "  if (dst_z >= args.dst_tensor.Depth()) return;\n";
276   } else {
277     c += "  int dst_y = GLOBAL_ID_1;\n";
278   }
279   c += "  int rem_y = dst_y % args.stride_y;\n";
280   c += "  int ceil_y = dst_y / args.stride_y;\n";
281   c += "  dst_y = ceil_y * args.stride_y * " + std::to_string(block_size.y) +
282        " + rem_y;\n";
283   c += "  int dst_s = GLOBAL_ID_2 * " + std::to_string(block_size.w) + ";\n";
284   c += "  if (dst_x >= args.dst_tensor.Width() || dst_y >= "
285        "args.dst_tensor.Height() || dst_s >= "
286        "args.dst_tensor.Slices()) return;\n";
287   if (weights_are_buffer) {
288     c += "  int f_base = dst_s * args.src_tensor.Slices() * args.kernel_size_x "
289          "* args.kernel_size_y";
290     if (src_def.HasAxis(Axis::DEPTH)) {
291       c += " * args.kernel_size_z";
292     }
293     c += ";\n";
294   }
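  // One accumulator per (x, y, z, s) element of the block.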
  for (int s = 0; s < block_size.w; ++s) {
    const std::string sind = std::to_string(s);
    for (int z = 0; z < block_size.z; ++z) {
      const std::string zind = std::to_string(z);
      for (int y = 0; y < block_size.y; ++y) {
        const std::string yind = std::to_string(y);
        for (int x = 0; x < block_size.x; ++x) {
          const std::string xind = std::to_string(x);
          c += "  ACCUM_FLT4 r" + generate_id_full(xind, yind, zind, sind) +
               " = INIT_ACCUM_FLT4(0.0f);\n";
        }
      }
    }
  }
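  // Map the output block back onto the range of source positions and kernel
  // taps that contribute to it. The loops below step the source coordinate
  // downwards while the matching kernel index grows.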
309   c += "  int kernel_first_dst_x = dst_x + args.padding_x;\n";
310   c += "  int kernel_first_dst_y = dst_y + args.padding_y;\n";
311   c += "  int kernel_last_dst_x = kernel_first_dst_x - args.kernel_size_x;\n";
312   c += "  int kernel_last_dst_y = kernel_first_dst_y - args.kernel_size_y;\n";
313   c += "  int offset_x = abs(args.padding_x);\n";
314   c += "  int offset_x_strided = offset_x * args.stride_x;\n";
315   c +=
316       "  int src_x = (kernel_first_dst_x + offset_x_strided) / args.stride_x - "
317       "offset_x;\n";
318   c += "  int offset_y = abs(args.padding_y);\n";
319   c += "  int offset_y_strided = offset_y * args.stride_y;\n";
320   c +=
321       "  int src_y = (kernel_first_dst_y + offset_y_strided) / args.stride_y - "
322       "offset_y;\n";
323   if (src_def.HasAxis(Axis::DEPTH)) {
324     c += "  int kernel_first_dst_z = dst_z + args.padding_z;\n";
325     c += "  int kernel_last_dst_z = kernel_first_dst_z - args.kernel_size_z;\n";
326     c += "  int offset_z = abs(args.padding_z);\n";
327     c += "  int offset_z_strided = offset_z * args.stride_z;\n";
328     c += "  int src_z = (kernel_first_dst_z + offset_z_strided) / "
329          "args.stride_z - offset_z;\n";
330     c += "  int src_as_dst_z = src_z * args.stride_z;\n";
331     c +=
332         "  for (;src_as_dst_z > kernel_last_dst_z; src_z -= 1, src_as_dst_z -= "
333         "args.stride_z) {\n";
334     for (int z = 0; z < block_size.z; ++z) {
335       const std::string zindex = std::to_string(z);
336       c += "    int sz" + zindex + " = src_z + " + zindex + ";\n";
337       if (!src_def.SupportsZeroClamp(Axis::DEPTH)) {
338         c += "    bool in_z" + zindex + " = sz" + zindex + " >= 0 && sz" +
339              zindex + " < args.src_tensor.Depth();\n";
340         if (!src_def.CanReadOutOfBorder(Axis::DEPTH)) {
341           c += "    sz" + zindex + " = clamp(sz" + zindex +
342                ", 0, args.src_tensor.Depth() - 1);\n";
343         }
344       }
345     }
346     if (block_size.z == 1 && !src_def.SupportsZeroClamp(Axis::DEPTH)) {
347       c += "    if (!in_z0) continue;\n";
348     }
349     c += "    int kernel_z = kernel_first_dst_z - src_as_dst_z;\n";
350     c += "    int src_as_dst_y = src_y * args.stride_y;\n";
351     c += "    int src_y_copy = src_y;\n";
352     c += "    for (;src_as_dst_y > kernel_last_dst_y; src_y_copy -= 1, "
353          "src_as_dst_y -= args.stride_y) {\n";
354   } else {
355     c += "  int src_as_dst_y = src_y * args.stride_y;\n";
356     c += "  for (;src_as_dst_y > kernel_last_dst_y; src_y -= 1, src_as_dst_y "
357          "-= args.stride_y) {\n";
358   }
359   for (int y = 0; y < block_size.y; ++y) {
360     const std::string yindex = std::to_string(y);
361     const std::string src_y =
362         src_def.HasAxis(Axis::DEPTH) ? "src_y_copy" : "src_y";
363     c += "    int sy" + yindex + " = " + src_y + " + " + yindex + ";\n";
364     if (!src_def.SupportsZeroClamp(Axis::HEIGHT)) {
365       c += "    bool in_y" + yindex + " = sy" + yindex + " >= 0 && sy" +
366            yindex + " < args.src_tensor.Height();\n";
367       if (!src_def.CanReadOutOfBorder(Axis::HEIGHT)) {
368         c += "    sy" + yindex + " = clamp(sy" + yindex +
369              ", 0, args.src_tensor.Height() - 1);\n";
370       }
371     }
372   }
373   if (block_size.y == 1 && !src_def.SupportsZeroClamp(Axis::HEIGHT)) {
374     c += "      if (!in_y0) continue;\n";
375   }
376   c += "    int kernel_y = kernel_first_dst_y - src_as_dst_y;\n";
377   c += "    int src_as_dst_x = src_x * args.stride_x;\n";
378   c += "    int src_x_copy = src_x;\n";
379   c += "    for (;src_as_dst_x > kernel_last_dst_x; src_x_copy -= 1, "
380        "src_as_dst_x "
381        "-= args.stride_x) {\n";
382   for (int x = 0; x < block_size.x; ++x) {
383     const std::string xindex = std::to_string(x);
384     c += "      int sx" + xindex + " = src_x_copy + " + xindex + ";\n";
385     if (!src_def.SupportsZeroClamp(Axis::WIDTH)) {
386       c += "      bool in_x" + xindex + " = sx" + xindex + " >= 0 && sx" +
387            xindex + " < args.src_tensor.Width();\n";
388       if (!src_def.CanReadOutOfBorder(Axis::WIDTH)) {
389         c += "      sx" + xindex + " = clamp(sx" + xindex +
390              ", 0, args.src_tensor.Width() - 1);\n";
391       }
392     }
393   }
394   if (block_size.x == 1 && !src_def.SupportsZeroClamp(Axis::WIDTH)) {
395     c += "      if (!in_x0) continue;\n";
396   }
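  // For linearly addressed source tensors, precompute per-element addresses
  // (and, when reads at address -1 return zero, per-element slice strides)
  // so the slice loop below can simply advance them.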
  for (int z = 0; z < block_size.z; ++z) {
    const std::string zind = std::to_string(z);
    for (int y = 0; y < block_size.y; ++y) {
      const std::string yind = std::to_string(y);
      for (int x = 0; x < block_size.x; ++x) {
        const std::string xind = std::to_string(x);
        const std::string id = generate_id(xind, yind, zind);
        const std::string check = generate_check(xind, yind, zind);
        std::string coords = "sx" + xind + ", sy" + yind;
        if (src_def.HasAxis(Axis::DEPTH)) {
          coords += ", sz" + zind;
        }
        if (src_def.IsLinear()) {
          c += "      args.src_tensor.GetAddress(addr" + id + ", " + coords +
               ", 0);\n";
        }
        if (src_def.ReturnsZeroForNegOneRead()) {
          c += "      addr" + id + " = select(-1, addr" + id + ", (" + check +
               "));\n";
          c += "      int ds" + id +
               " = select(0, args.src_tensor.SliceStride(), (" + check +
               "));\n";
        }
      }
    }
  }
  if (src_def.storage_type == TensorStorageType::BUFFER) {
    c += "      int ds = args.src_tensor.SliceStride();\n";
  }
  c += "      int kernel_x = kernel_first_dst_x - src_as_dst_x;\n";
  if (src_def.HasAxis(Axis::DEPTH)) {
    c += "      int kernel_index = (kernel_z * args.kernel_size_y + kernel_y) "
         "*  args.kernel_size_x + kernel_x;\n";
  } else {
    c += "      int kernel_index = kernel_y * args.kernel_size_x + kernel_x;\n";
  }
  if (weights_are_buffer) {
    c += "      int f_offset = f_base + kernel_index * "
         "args.src_tensor.Slices() * " +
         std::to_string(block_size.w) + ";\n";
  } else {
    c += "      int x_c = kernel_index * args.src_tensor.Slices();\n";
  }
  c += "      for (int s = 0; s < args.src_tensor.Slices(); ++s) {\n";
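  // Read one source value per block element; out-of-range reads are handled
  // either by the -1 address trick above or, when zero clamping is not
  // supported, by predicating the read (Mali) or multiplying it by the
  // boundary check.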
  const bool conditional_read = gpu_info.IsMali();
  for (int z = 0; z < block_size.z; ++z) {
    const std::string zind = std::to_string(z);
    for (int y = 0; y < block_size.y; ++y) {
      const std::string yind = std::to_string(y);
      for (int x = 0; x < block_size.x; ++x) {
        const std::string xind = std::to_string(x);
        const std::string id = generate_id(xind, yind, zind);
        std::string address;
        if (src_def.IsLinear()) {
          address = "addr" + id;
        } else {
          address = "sx" + xind + ", sy" + yind;
          if (src_def.HasAxis(Axis::DEPTH)) {
            address += ", sz" + zind;
          }
          address += ", s";
        }
        if (src_def.ReturnsZeroForNegOneRead()) {
          c += "        FLT4 src" + id + " = args.src_tensor.Read(" + address +
               "); " + address + " += ds" + id + ";\n";
        } else {
          const std::string check = generate_check(xind, yind, zind);
          if (!check.empty()) {
            if (conditional_read) {
              c += "        FLT4 src" + id + " = " + check +
                   " ? args.src_tensor.Read(" + address + ") : (FLT4)(0.0f);\n";
            } else {
              c += "        FLT4 src" + id + " = args.src_tensor.Read(" +
                   address + ") * INIT_FLT(" + check + ");\n";
            }
          } else {
            c += "        FLT4 src" + id + " = args.src_tensor.Read(" +
                 address + ");\n";
          }
          if (src_def.IsLinear()) {
            c += "        addr" + id + " += ds;\n";
          }
        }
      }
    }
  }
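  // Fetch the filter values: through a pointer into the weights buffer when
  // the backend supports pointers in kernels, from the four weight textures
  // otherwise (the no-pointer buffer case reads an FLT16 value per slice in
  // the loop below).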
  if (weights_are_buffer) {
    if (gpu_info.SupportsPointersInKernels()) {
      c += "        __global FLT16* weights_cache = "
           "args.weights.GetPtr(f_offset);\n";
    }
  } else {
    for (int s = 0; s < block_size.w; ++s) {
      c += absl::Substitute(
          R"(        FLT4 f$1 = args.weights0.Read(dst_s + $0, x_c);
        FLT4 f$2 = args.weights1.Read(dst_s + $0, x_c);
        FLT4 f$3 = args.weights2.Read(dst_s + $0, x_c);
        FLT4 f$4 = args.weights3.Read(dst_s + $0, x_c);
)",
          s, s * 4 + 0, s * 4 + 1, s * 4 + 2, s * 4 + 3);
    }
    c += "        x_c++;\n";
  }
  for (int s = 0; s < block_size.w; ++s) {
    if (weights_are_buffer && !gpu_info.SupportsPointersInKernels()) {
      c += "        FLT16 flt16val = args.weights.Read(f_offset + " +
           std::to_string(s) + ");\n";
    }
    const std::string sind = std::to_string(s);
    for (int z = 0; z < block_size.z; ++z) {
      const std::string zind = std::to_string(z);
      for (int y = 0; y < block_size.y; ++y) {
        const std::string yind = std::to_string(y);
        for (int x = 0; x < block_size.x; ++x) {
          const std::string xind = std::to_string(x);
          const std::string id = generate_id(xind, yind, zind);
          const std::string full_id = generate_id_full(xind, yind, zind, sind);
          c += "        CONV" + sind + "(r" + full_id + ", src" + id + ");\n";
        }
      }
    }
  }
  if (weights_are_buffer) {
    c += "        f_offset += " + std::to_string(block_size.w) + ";\n";
  }
  c += "      }\n";
  c += "    }\n";
  c += "  }\n";
  if (src_def.HasAxis(Axis::DEPTH)) {
    c += "  }\n";
  }
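  // Epilogue: add the bias and write every block element that lands inside
  // the destination tensor.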
  for (int s = 0; s < block_size.w; ++s) {
    const std::string sind = std::to_string(s);
    c += "  if (dst_s < args.dst_tensor.Slices()) {\n";
    c += "    FLT4 bias_val = args.biases.Read(dst_s);\n";
    for (int z = 0; z < block_size.z; ++z) {
      const std::string zind = std::to_string(z);
      for (int y = 0; y < block_size.y; ++y) {
        const std::string yind = std::to_string(y);
        for (int x = 0; x < block_size.x; ++x) {
          const std::string xind = std::to_string(x);
          const std::string id = generate_id_full(xind, yind, zind, sind);
          std::string checks =
              "xc < args.dst_tensor.Width() && yc < args.dst_tensor.Height()";
          std::string coords = "xc, yc";
          c += "    {\n";
          c += "      int xc = dst_x + args.stride_x * " + xind + ";\n";
          c += "      int yc = dst_y + args.stride_y * " + yind + ";\n";
          if (src_def.HasAxis(Axis::DEPTH)) {
            c += "      int zc = dst_z + args.stride_z * " + zind + ";\n";
            checks += " && zc < args.dst_tensor.Depth()";
            coords += ", zc";
          }
          c += "      if (" + checks + ") {\n";
          c += "        FLT4 res = TO_FLT4(r" + id + ") + bias_val;\n";
          c += "        args.dst_tensor.Write(res, " + coords + ", dst_s);\n";
          c += "      }\n";
          c += "    }\n";
        }
      }
    }
    c += "  }\n";
    c += "  dst_s++;\n";
  }
  c += "}\n";
  return c;
}

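// When the source has a DEPTH axis, the second grid dimension packs height
// and depth together, so the generated kernel needs grid_size_y to split
// them apart.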
absl::Status ConvolutionTransposed::BindArguments(ArgumentsBinder* args) {
  if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) {
    const int aligned_h =
        AlignByN(dst_[0]->Height(), stride_.y * block_size_.y);
    RETURN_IF_ERROR(
        args->SetInt("grid_size_y", DivideRoundUp(aligned_h, block_size_.y)));
  }
  return absl::OkStatus();
}

int3 ConvolutionTransposed::GetGridSize() const {
  const int aligned_w = AlignByN(dst_[0]->Width(), stride_.x * block_size_.x);
  const int aligned_h = AlignByN(dst_[0]->Height(), stride_.y * block_size_.y);
  const int aligned_d = AlignByN(dst_[0]->Depth(), stride_.z * block_size_.z);
  const int grid_x = DivideRoundUp(aligned_w, block_size_.x) * dst_[0]->Batch();
  const int grid_y = DivideRoundUp(aligned_h, block_size_.y) *
                     DivideRoundUp(aligned_d, block_size_.z);
  const int grid_z = DivideRoundUp(dst_[0]->Slices(), block_size_.w);
  return int3(grid_x, grid_y, grid_z);
}

void ConvolutionTransposed::GetPossibleKernelWorkGroups(
    TuningType tuning_type, const GpuInfo& gpu_info,
    const KernelInfo& kernel_info, std::vector<int3>* work_groups) const {
  GetPossibleWorkGroupsConv(tuning_type, gpu_info, kernel_info, grid_size_,
                            work_groups);
}

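// Factory helpers for constant weights: UploadWeights() stores them in a
// buffer on Mali and Apple and in 2D textures elsewhere; the bias is added
// as a linear tensor object.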
ConvolutionTransposed CreateConvolutionTransposed(
    const GpuInfo& gpu_info, const OperationDef& definition,
    const ConvolutionTransposedAttributes& attr) {
  const bool weights_are_buffer = gpu_info.IsMali() || gpu_info.IsApple();
  ConvolutionTransposed result(definition, attr, gpu_info, weights_are_buffer);
  result.UploadWeights(attr.weights, weights_are_buffer);

  TensorLinearDescriptor desc;
  desc.storage_type =
      DeduceLinearStorageType(definition.GetPrimaryStorageType());
  desc.element_type = definition.GetDataType();
  desc.UploadLinearData(attr.bias);
  result.args_.AddObject(
      "biases", absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
  return result;
}

ConvolutionTransposed CreateConvolutionTransposed3D(
    const GpuInfo& gpu_info, const OperationDef& definition,
    const ConvolutionTransposed3DAttributes& attr) {
  const bool weights_are_buffer = gpu_info.IsMali() || gpu_info.IsApple();
  ConvolutionTransposed result(definition, attr, gpu_info, weights_are_buffer);
  result.UploadWeights(attr.weights, weights_are_buffer);

  TensorLinearDescriptor desc;
  desc.storage_type =
      DeduceLinearStorageType(definition.GetPrimaryStorageType());
  desc.element_type = definition.GetDataType();
  desc.UploadLinearData(attr.bias);
  result.args_.AddObject(
      "biases", absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
  return result;
}

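// Dynamic-weights variant: the weights arrive at runtime as extra source
// tensors (one buffer on Mali, four 2D textures otherwise), so only the bias
// is uploaded here.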
ConvolutionTransposed CreateConvolutionTransposedDynamicWeights(
    const GpuInfo& gpu_info, const OperationDef& definition,
    const ConvolutionTransposedAttributes& attr) {
  const bool weights_are_buffer = gpu_info.IsMali();
  OperationDef new_def = definition;
  new_def.src_tensors = {
      definition.src_tensors[0]};  // leaving only src_tensor def, weights defs
                                   // will be added later
  const DataType weights_type = definition.GetDataType();
  if (weights_are_buffer) {
    // add 1 src_tensor(buffer) for weights
    new_def.src_tensors.push_back(
        {weights_type, TensorStorageType::BUFFER, Layout::HWC});
  } else {
    // add 4 src_tensors(4X textures 2d) for weights
    new_def.src_tensors.push_back(
        {weights_type, TensorStorageType::TEXTURE_2D, Layout::HWC});
    new_def.src_tensors.push_back(
        {weights_type, TensorStorageType::TEXTURE_2D, Layout::HWC});
    new_def.src_tensors.push_back(
        {weights_type, TensorStorageType::TEXTURE_2D, Layout::HWC});
    new_def.src_tensors.push_back(
        {weights_type, TensorStorageType::TEXTURE_2D, Layout::HWC});
  }
  ConvolutionTransposed result(new_def, attr, gpu_info, weights_are_buffer);

  TensorLinearDescriptor desc;
  desc.storage_type = DeduceLinearStorageType(new_def.GetPrimaryStorageType());
  desc.element_type = new_def.GetDataType();
  desc.UploadLinearData(attr.bias);
  result.args_.AddObject(
      "biases", absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
  return result;
}

}  // namespace gpu
}  // namespace tflite