• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/tasks/convolution_transposed.h"
17 
18 #include <string>
19 #include <utility>
20 #include <vector>
21 
22 #include "absl/strings/substitute.h"
23 #include "tensorflow/lite/delegates/gpu/common/shape.h"
24 #include "tensorflow/lite/delegates/gpu/common/status.h"
25 #include "tensorflow/lite/delegates/gpu/common/task/storage_type_util.h"
26 #include "tensorflow/lite/delegates/gpu/common/task/weights_layout.h"
27 #include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
28 
29 namespace tflite {
30 namespace gpu {
31 
ConvolutionTransposed(const OperationDef & definition,const ConvolutionTransposedAttributes & attr,const GpuInfo & gpu_info,bool weights_are_buffer)32 ConvolutionTransposed::ConvolutionTransposed(
33     const OperationDef& definition, const ConvolutionTransposedAttributes& attr,
34     const GpuInfo& gpu_info, bool weights_are_buffer)
35     : GPUOperation(definition),
36       stride_(attr.stride.w, attr.stride.h, 1, 1),
37       block_size_(2, 2, 1, 2) {
38   if (weights_are_buffer) {
39     if (gpu_info.IsApple()) {
40       weights_layout_ = WeightsLayout::kOHWIOGroupO4I4;
41     } else {
42       weights_layout_ = WeightsLayout::kOHWIOGroupI4O4;
43     }
44   } else {
45     if (gpu_info.IsApple()) {
46       weights_layout_ = WeightsLayout::k2DX4O4YIsHWIAndXIsOOGroupI4;
47     } else {
48       weights_layout_ = WeightsLayout::k2DX4I4YIsHWIAndXIsOOGroupO4;
49     }
50   }
51   const bool is_f16 = definition.precision == CalculationsPrecision::F16;
52   if (gpu_info.IsMali()) {
53     if (gpu_info.mali_info.IsMidgard()) {
54       block_size_ = is_f16 ? int4(2, 1, 1, 2) : int4(2, 1, 1, 1);
55     } else {
56       block_size_ = is_f16 ? int4(2, 2, 1, 2) : int4(2, 2, 1, 1);
57     }
58   }
59   const int dst_depth = DivideRoundUp(attr.weights.shape.o, 4);
60   if (dst_depth == 1 || dst_depth == 3) {
61     if (!gpu_info.IsMali()) {
62       block_size_.y *= block_size_.w;
63     }
64     block_size_.w = 1;
65   }
66 
67   args_.AddInt("stride_x", stride_.x);
68   args_.AddInt("stride_y", stride_.y);
69   args_.AddInt("padding_x", attr.padding.prepended.w);
70   args_.AddInt("padding_y", attr.padding.prepended.h);
71   args_.AddInt("kernel_size_x", attr.weights.shape.w);
72   args_.AddInt("kernel_size_y", attr.weights.shape.h);
73   code_ = GenerateConvolutionTransposedCode(definition_, gpu_info,
74                                             weights_are_buffer, block_size_);
75 }
76 
ConvolutionTransposed(const OperationDef & definition,const ConvolutionTransposed3DAttributes & attr,const GpuInfo & gpu_info,bool weights_are_buffer)77 ConvolutionTransposed::ConvolutionTransposed(
78     const OperationDef& definition,
79     const ConvolutionTransposed3DAttributes& attr, const GpuInfo& gpu_info,
80     bool weights_are_buffer)
81     : GPUOperation(definition),
82       stride_(attr.stride.w, attr.stride.h, attr.stride.d, 1),
83       block_size_(2, 2, 1, 2) {
84   if (weights_are_buffer) {
85     if (gpu_info.IsApple()) {
86       weights_layout_ = WeightsLayout::kOHWIOGroupO4I4;
87     } else {
88       weights_layout_ = WeightsLayout::kOHWIOGroupI4O4;
89     }
90   } else {
91     if (gpu_info.IsApple()) {
92       weights_layout_ = WeightsLayout::k2DX4O4YIsHWIAndXIsOOGroupI4;
93     } else {
94       weights_layout_ = WeightsLayout::k2DX4I4YIsHWIAndXIsOOGroupO4;
95     }
96   }
97   const bool is_f16 = definition.precision == CalculationsPrecision::F16;
98   if (gpu_info.IsMali()) {
99     if (gpu_info.mali_info.IsMidgard()) {
100       block_size_ = is_f16 ? int4(2, 1, 1, 2) : int4(2, 1, 1, 1);
101     } else {
102       block_size_ = is_f16 ? int4(2, 2, 1, 2) : int4(2, 2, 1, 1);
103     }
104   }
105   const int dst_depth = DivideRoundUp(attr.weights.shape.o, 4);
106   if (dst_depth == 1 || dst_depth == 3) {
107     if (!gpu_info.IsMali()) {
108       block_size_.y *= block_size_.w;
109     }
110     block_size_.w = 1;
111   }
112 
113   args_.AddInt("stride_x", stride_.x);
114   args_.AddInt("stride_y", stride_.y);
115   args_.AddInt("stride_z", stride_.z);
116   args_.AddInt("padding_x", attr.padding.prepended.w);
117   args_.AddInt("padding_y", attr.padding.prepended.h);
118   args_.AddInt("padding_z", attr.padding.prepended.d);
119   args_.AddInt("kernel_size_x", attr.weights.shape.w);
120   args_.AddInt("kernel_size_y", attr.weights.shape.h);
121   args_.AddInt("kernel_size_z", attr.weights.shape.d);
122   args_.AddInt("grid_size_y");
123   code_ = GenerateConvolutionTransposedCode(definition_, gpu_info,
124                                             weights_are_buffer, block_size_);
125 }
126 
GenerateConvolutionTransposedCode(const OperationDef & op_def,const GpuInfo & gpu_info,bool weights_are_buffer,const int4 & block_size)127 std::string ConvolutionTransposed::GenerateConvolutionTransposedCode(
128     const OperationDef& op_def, const GpuInfo& gpu_info,
129     bool weights_are_buffer, const int4& block_size) {
130   auto src_desc = op_def.src_tensors[0];
131   src_desc.SetAddressMode(AddressMode::kZero);
132   AddSrcTensor("src_tensor", src_desc);
133   AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
134 
135   if (op_def.src_tensors.size() != 1) {
136     // dynamic weights
137     if (weights_layout_ == WeightsLayout::kOHWIOGroupI4O4 ||
138         weights_layout_ == WeightsLayout::kOHWIOGroupO4I4) {
139       BufferDescriptor desc;
140       desc.element_type = op_def.src_tensors[1].data_type;
141       desc.element_size = 16;
142       desc.memory_type = MemoryType::GLOBAL;
143       AddSrcBuffer("weights", desc);
144     } else {
145       for (int i = 0; i < 4; ++i) {
146         Texture2DDescriptor desc;
147         desc.element_type = op_def.src_tensors[1 + i].data_type;
148         const std::string name = "weights" + std::to_string(i);
149         AddSrcTexture2D("weights" + std::to_string(i), desc);
150       }
151     }
152   }
153 
154   const auto& src_def = op_def.src_tensors[0];
155 
156   std::string c;
157 
158   for (int s = 0; s < block_size.w; ++s) {
159     const std::string f0 = weights_are_buffer ? "FLT16_0123(weights_cache[" +
160                                                     std::to_string(s) + "])"
161                                               : "f" + std::to_string(s * 4 + 0);
162     const std::string f1 = weights_are_buffer ? "FLT16_4567(weights_cache[" +
163                                                     std::to_string(s) + "])"
164                                               : "f" + std::to_string(s * 4 + 1);
165     const std::string f2 = weights_are_buffer ? "FLT16_89ab(weights_cache[" +
166                                                     std::to_string(s) + "])"
167                                               : "f" + std::to_string(s * 4 + 2);
168     const std::string f3 = weights_are_buffer ? "FLT16_cdef(weights_cache[" +
169                                                     std::to_string(s) + "])"
170                                               : "f" + std::to_string(s * 4 + 3);
171     if (GetWeightsDescription().IsI4O4()) {
172       switch (op_def.precision) {
173         case CalculationsPrecision::F32:
174         case CalculationsPrecision::F16:
175           c += "#define CONV" + std::to_string(s) + "(R, S)    \\\n";
176           c += "R += S.x * " + f0 + "; \\\n";
177           c += "R += S.y * " + f1 + "; \\\n";
178           c += "R += S.z * " + f2 + "; \\\n";
179           c += "R += S.w * " + f3 + ";   \n";
180           break;
181         case CalculationsPrecision::F32_F16:
182           c += "#define CONV" + std::to_string(s) + "(R, S) \\\n";
183           c += "R += TO_ACCUM_TYPE(S.x * " + f0 + " + S.y * " + f1 +
184                " + S.z * " + f2 + " + S.w * " + f3 + ");\n";
185           break;
186       }
187     } else {
188       // O4I4
189       c += "#define CONV" + std::to_string(s) + "(R, S)    \\\n";
190       c += "R.x += dot(S, " + f0 + "); \\\n";
191       c += "R.y += dot(S, " + f1 + "); \\\n";
192       c += "R.z += dot(S, " + f2 + "); \\\n";
193       c += "R.w += dot(S, " + f3 + ");   \n";
194     }
195   }
196 
197   auto generate_id = [&](const std::string& x, const std::string& y,
198                          const std::string& z) {
199     std::string id;
200     if (src_def.HasAxis(Axis::WIDTH)) {
201       id += "_w" + x;
202     }
203     if (src_def.HasAxis(Axis::HEIGHT)) {
204       id += "_h" + y;
205     }
206     if (src_def.HasAxis(Axis::DEPTH)) {
207       id += "_d" + z;
208     }
209     return id;
210   };
211 
212   auto generate_id_full = [&](const std::string& x, const std::string& y,
213                               const std::string& z, const std::string& s) {
214     return generate_id(x, y, z) + "_s" + s;
215   };
216 
217   auto generate_check = [&](const std::string& x, const std::string& y,
218                             const std::string& z) {
219     std::string check;
220     const std::vector<Axis> axes{Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH};
221     const std::vector<std::string> names{"in_x", "in_y", "in_z"};
222     const std::vector<std::string> coords{x, y, z};
223     for (int i = 0; i < axes.size(); ++i) {
224       const auto& axis = axes[i];
225       if (src_def.HasAxis(axis) && !src_def.SupportsZeroClamp(axis) &&
226           block_size[i] != 1) {
227         if (!check.empty()) {
228           check += " && ";
229         }
230         check += names[i] + coords[i];
231       }
232     }
233     return check;
234   };
235 
236   switch (op_def.precision) {
237     case CalculationsPrecision::F32:
238       c += "#define FLT16 float16\n";
239       break;
240     case CalculationsPrecision::F32_F16:
241     case CalculationsPrecision::F16:
242       c += "#define FLT16 half16\n";
243       break;
244   }
245 
246   c += "MAIN_FUNCTION($0) {\n";
247   if (op_def.IsBatchSupported()) {
248     c += "  int linear_id = GLOBAL_ID_0;\n";
249     c += "  int dst_x = (linear_id / args.dst_tensor.Batch());\n";
250     c += "  int B = linear_id % args.dst_tensor.Batch();\n";
251     c += "  args.dst_tensor.SetBatchRef(B);\n";
252     c += "  args.src_tensor.SetBatchRef(B);\n";
253   } else {
254     c += "  int dst_x = GLOBAL_ID_0;\n";
255   }
256   c += "  int rem_x = dst_x % args.stride_x;\n";
257   c += "  int ceil_x = dst_x / args.stride_x;\n";
258   c += "  dst_x = ceil_x * args.stride_x * " + std::to_string(block_size.x) +
259        " + rem_x;\n";
260   if (src_def.HasAxis(Axis::DEPTH)) {
261     c += "  int linear_id_y = GLOBAL_ID_1;\n";
262     c += "  int dst_y = linear_id_y % args.grid_size_y;\n";
263     c += "  int dst_z = linear_id_y / args.grid_size_y;\n";
264     c += "  int rem_z = dst_z % args.stride_z;\n";
265     c += "  int ceil_z = dst_z / args.stride_z;\n";
266     c += "  dst_z = ceil_z * args.stride_z * " + std::to_string(block_size.z) +
267          " + rem_z;\n";
268     c += "  if (dst_z >= args.dst_tensor.Depth()) return;\n";
269   } else {
270     c += "  int dst_y = GLOBAL_ID_1;\n";
271   }
272   c += "  int rem_y = dst_y % args.stride_y;\n";
273   c += "  int ceil_y = dst_y / args.stride_y;\n";
274   c += "  dst_y = ceil_y * args.stride_y * " + std::to_string(block_size.y) +
275        " + rem_y;\n";
276   c += "  int dst_s = GLOBAL_ID_2 * " + std::to_string(block_size.w) + ";\n";
277   c += "  if (dst_x >= args.dst_tensor.Width() || dst_y >= "
278        "args.dst_tensor.Height() || dst_s >= "
279        "args.dst_tensor.Slices()) return;\n";
280   if (weights_are_buffer) {
281     c += "  int f_base = dst_s * args.src_tensor.Slices() * args.kernel_size_x "
282          "* args.kernel_size_y";
283     if (src_def.HasAxis(Axis::DEPTH)) {
284       c += " * args.kernel_size_z";
285     }
286     c += ";\n";
287   }
288   for (int s = 0; s < block_size.w; ++s) {
289     const std::string sind = std::to_string(s);
290     for (int z = 0; z < block_size.z; ++z) {
291       const std::string zind = std::to_string(z);
292       for (int y = 0; y < block_size.y; ++y) {
293         const std::string yind = std::to_string(y);
294         for (int x = 0; x < block_size.x; ++x) {
295           const std::string xind = std::to_string(x);
296           c += "  ACCUM_FLT4 r" + generate_id_full(xind, yind, zind, sind) +
297                " = INIT_ACCUM_FLT4(0.0f);\n";
298         }
299       }
300     }
301   }
302   c += "  int kernel_first_dst_x = dst_x + args.padding_x;\n";
303   c += "  int kernel_first_dst_y = dst_y + args.padding_y;\n";
304   c += "  int kernel_last_dst_x = kernel_first_dst_x - args.kernel_size_x;\n";
305   c += "  int kernel_last_dst_y = kernel_first_dst_y - args.kernel_size_y;\n";
306   c += "  int offset_x = abs(args.padding_x);\n";
307   c += "  int offset_x_strided = offset_x * args.stride_x;\n";
308   c +=
309       "  int src_x = (kernel_first_dst_x + offset_x_strided) / args.stride_x - "
310       "offset_x;\n";
311   c += "  int offset_y = abs(args.padding_y);\n";
312   c += "  int offset_y_strided = offset_y * args.stride_y;\n";
313   c +=
314       "  int src_y = (kernel_first_dst_y + offset_y_strided) / args.stride_y - "
315       "offset_y;\n";
316   if (src_def.HasAxis(Axis::DEPTH)) {
317     c += "  int kernel_first_dst_z = dst_z + args.padding_z;\n";
318     c += "  int kernel_last_dst_z = kernel_first_dst_z - args.kernel_size_z;\n";
319     c += "  int offset_z = abs(args.padding_z);\n";
320     c += "  int offset_z_strided = offset_z * args.stride_z;\n";
321     c += "  int src_z = (kernel_first_dst_z + offset_z_strided) / "
322          "args.stride_z - offset_z;\n";
323     c += "  int src_as_dst_z = src_z * args.stride_z;\n";
324     c +=
325         "  for (;src_as_dst_z > kernel_last_dst_z; src_z -= 1, src_as_dst_z -= "
326         "args.stride_z) {\n";
327     for (int z = 0; z < block_size.z; ++z) {
328       const std::string zindex = std::to_string(z);
329       c += "    int sz" + zindex + " = src_z + " + zindex + ";\n";
330       if (!src_def.SupportsZeroClamp(Axis::DEPTH)) {
331         c += "    bool in_z" + zindex + " = sz" + zindex + " >= 0 && sz" +
332              zindex + " < args.src_tensor.Depth();\n";
333         if (!src_def.CanReadOutOfBorder(Axis::DEPTH)) {
334           c += "    sz" + zindex + " = clamp(sz" + zindex +
335                ", 0, args.src_tensor.Depth() - 1);\n";
336         }
337       }
338     }
339     if (block_size.z == 1 && !src_def.SupportsZeroClamp(Axis::DEPTH)) {
340       c += "    if (!in_z0) continue;\n";
341     }
342     c += "    int kernel_z = kernel_first_dst_z - src_as_dst_z;\n";
343     c += "    int src_as_dst_y = src_y * args.stride_y;\n";
344     c += "    int src_y_copy = src_y;\n";
345     c += "    for (;src_as_dst_y > kernel_last_dst_y; src_y_copy -= 1, "
346          "src_as_dst_y -= args.stride_y) {\n";
347   } else {
348     c += "  int src_as_dst_y = src_y * args.stride_y;\n";
349     c += "  for (;src_as_dst_y > kernel_last_dst_y; src_y -= 1, src_as_dst_y "
350          "-= args.stride_y) {\n";
351   }
352   for (int y = 0; y < block_size.y; ++y) {
353     const std::string yindex = std::to_string(y);
354     const std::string src_y =
355         src_def.HasAxis(Axis::DEPTH) ? "src_y_copy" : "src_y";
356     c += "    int sy" + yindex + " = " + src_y + " + " + yindex + ";\n";
357     if (!src_def.SupportsZeroClamp(Axis::HEIGHT)) {
358       c += "    bool in_y" + yindex + " = sy" + yindex + " >= 0 && sy" +
359            yindex + " < args.src_tensor.Height();\n";
360       if (!src_def.CanReadOutOfBorder(Axis::HEIGHT)) {
361         c += "    sy" + yindex + " = clamp(sy" + yindex +
362              ", 0, args.src_tensor.Height() - 1);\n";
363       }
364     }
365   }
366   if (block_size.y == 1 && !src_def.SupportsZeroClamp(Axis::HEIGHT)) {
367     c += "      if (!in_y0) continue;\n";
368   }
369   c += "    int kernel_y = kernel_first_dst_y - src_as_dst_y;\n";
370   c += "    int src_as_dst_x = src_x * args.stride_x;\n";
371   c += "    int src_x_copy = src_x;\n";
372   c += "    for (;src_as_dst_x > kernel_last_dst_x; src_x_copy -= 1, "
373        "src_as_dst_x "
374        "-= args.stride_x) {\n";
375   for (int x = 0; x < block_size.x; ++x) {
376     const std::string xindex = std::to_string(x);
377     c += "      int sx" + xindex + " = src_x_copy + " + xindex + ";\n";
378     if (!src_def.SupportsZeroClamp(Axis::WIDTH)) {
379       c += "      bool in_x" + xindex + " = sx" + xindex + " >= 0 && sx" +
380            xindex + " < args.src_tensor.Width();\n";
381       if (!src_def.CanReadOutOfBorder(Axis::WIDTH)) {
382         c += "      sx" + xindex + " = clamp(sx" + xindex +
383              ", 0, args.src_tensor.Width() - 1);\n";
384       }
385     }
386   }
387   if (block_size.x == 1 && !src_def.SupportsZeroClamp(Axis::WIDTH)) {
388     c += "      if (!in_x0) continue;\n";
389   }
390   for (int z = 0; z < block_size.z; ++z) {
391     const std::string zind = std::to_string(z);
392     for (int y = 0; y < block_size.y; ++y) {
393       const std::string yind = std::to_string(y);
394       for (int x = 0; x < block_size.x; ++x) {
395         const std::string xind = std::to_string(x);
396         const std::string id = generate_id(xind, yind, zind);
397         const std::string check = generate_check(xind, yind, zind);
398         std::string coords = "sx" + xind + ", sy" + yind;
399         if (src_def.HasAxis(Axis::DEPTH)) {
400           coords += ", sz" + zind;
401         }
402         if (src_def.IsLinear()) {
403           c += "      args.src_tensor.GetAddress(addr" + id + ", " + coords +
404                ", 0);\n";
405         }
406         if (src_def.ReturnsZeroForNegOneRead()) {
407           c += "      addr" + id + " = select(-1, addr" + id + ", (" + check +
408                "));\n";
409           c += "      int ds" + id +
410                " = select(0, args.src_tensor.SliceStride(), (" + check +
411                "));\n";
412         }
413       }
414     }
415   }
416   if (src_def.storage_type == TensorStorageType::BUFFER) {
417     c += "      int ds = args.src_tensor.SliceStride();\n";
418   }
419   c += "      int kernel_x = kernel_first_dst_x - src_as_dst_x;\n";
420   if (src_def.HasAxis(Axis::DEPTH)) {
421     c += "      int kernel_index = (kernel_z * args.kernel_size_y + kernel_y) "
422          "*  args.kernel_size_x + kernel_x;\n";
423   } else {
424     c += "      int kernel_index = kernel_y * args.kernel_size_x + kernel_x;\n";
425   }
426   if (weights_are_buffer) {
427     c += "      int f_offset = f_base + kernel_index * "
428          "args.src_tensor.Slices() * " +
429          std::to_string(block_size.w) + ";\n";
430   } else {
431     c += "      int x_c = kernel_index * args.src_tensor.Slices();\n";
432   }
433   c += "      for (int s = 0; s < args.src_tensor.Slices(); ++s) {\n";
434   const bool conditional_read = gpu_info.IsMali();
435   for (int z = 0; z < block_size.z; ++z) {
436     const std::string zind = std::to_string(z);
437     for (int y = 0; y < block_size.y; ++y) {
438       const std::string yind = std::to_string(y);
439       for (int x = 0; x < block_size.x; ++x) {
440         const std::string xind = std::to_string(x);
441         const std::string id = generate_id(xind, yind, zind);
442         std::string address;
443         if (src_def.IsLinear()) {
444           address = "addr" + id;
445         } else {
446           address = "sx" + xind + ", sy" + yind;
447           if (src_def.HasAxis(Axis::DEPTH)) {
448             address += ", sz" + zind;
449           }
450           address += ", s";
451         }
452         if (src_def.ReturnsZeroForNegOneRead()) {
453           c += "        FLT4 src" + id + " = args.src_tensor.Read(" + address +
454                "); " + address + " += ds" + id + ";\n";
455         } else {
456           const std::string check = generate_check(xind, yind, zind);
457           if (!check.empty()) {
458             if (conditional_read) {
459               c += "        FLT4 src" + id + " = " + check +
460                    " ? args.src_tensor.Read(" + address + ") : (FLT4)(0.0f);\n";
461             } else {
462               c += "        FLT4 src" + id + " = args.src_tensor.Read(" +
463                    address + ") * INIT_FLT(" + check + ");\n";
464             }
465           } else {
466             c += "        FLT4 src" + id + " = args.src_tensor.Read(" +
467                  address + ");\n";
468           }
469           if (src_def.IsLinear()) {
470             c += "        addr" + id + " += ds;\n";
471           }
472         }
473       }
474     }
475   }
476   if (weights_are_buffer) {
477     c += "        __global FLT16* weights_cache = "
478          "args.weights.GetPtr(f_offset);\n";
479     c += "        f_offset += " + std::to_string(block_size.w) + ";\n";
480   } else {
481     for (int s = 0; s < block_size.w; ++s) {
482       c += absl::Substitute(
483           R"(        FLT4 f$1 = args.weights0.Read(dst_s + $0, x_c);
484         FLT4 f$2 = args.weights1.Read(dst_s + $0, x_c);
485         FLT4 f$3 = args.weights2.Read(dst_s + $0, x_c);
486         FLT4 f$4 = args.weights3.Read(dst_s + $0, x_c);
487 )",
488           s, s * 4 + 0, s * 4 + 1, s * 4 + 2, s * 4 + 3);
489     }
490     c += "        x_c++;\n";
491   }
492   for (int s = 0; s < block_size.w; ++s) {
493     const std::string sind = std::to_string(s);
494     for (int z = 0; z < block_size.z; ++z) {
495       const std::string zind = std::to_string(z);
496       for (int y = 0; y < block_size.y; ++y) {
497         const std::string yind = std::to_string(y);
498         for (int x = 0; x < block_size.x; ++x) {
499           const std::string xind = std::to_string(x);
500           const std::string id = generate_id(xind, yind, zind);
501           const std::string full_id = generate_id_full(xind, yind, zind, sind);
502           c += "        CONV" + sind + "(r" + full_id + ", src" + id + ");\n";
503         }
504       }
505     }
506   }
507   c += "      }\n";
508   c += "    }\n";
509   c += "  }\n";
510   if (src_def.HasAxis(Axis::DEPTH)) {
511     c += "  }\n";
512   }
513   for (int s = 0; s < block_size.w; ++s) {
514     const std::string sind = std::to_string(s);
515     c += "  if (dst_s < args.dst_tensor.Slices()) {\n";
516     c += "    FLT4 bias_val = args.biases.Read(dst_s);\n";
517     for (int z = 0; z < block_size.z; ++z) {
518       const std::string zind = std::to_string(z);
519       for (int y = 0; y < block_size.y; ++y) {
520         const std::string yind = std::to_string(y);
521         for (int x = 0; x < block_size.x; ++x) {
522           const std::string xind = std::to_string(x);
523           const std::string id = generate_id_full(xind, yind, zind, sind);
524           std::string checks =
525               "xc < args.dst_tensor.Width() && yc < args.dst_tensor.Height()";
526           std::string coords = "xc, yc";
527           c += "    {\n";
528           c += "      int xc = dst_x + args.stride_x * " + xind + ";\n";
529           c += "      int yc = dst_y + args.stride_y * " + yind + ";\n";
530           if (src_def.HasAxis(Axis::DEPTH)) {
531             c += "      int zc = dst_z + args.stride_z * " + zind + ";\n";
532             checks += " && zc < args.dst_tensor.Depth()";
533             coords += ", zc";
534           }
535           c += "      if (" + checks + ") {\n";
536           c += "        FLT4 res = TO_FLT4(r" + id + ") + bias_val;\n";
537           c += "        args.dst_tensor.Write(res, " + coords + ", dst_s);\n";
538           c += "      }\n";
539           c += "    }\n";
540         }
541       }
542     }
543     c += "  }\n";
544     c += "  dst_s++;\n";
545   }
546   c += "}\n";
547   return c;
548 }
549 
BindArguments(ArgumentsBinder * args)550 absl::Status ConvolutionTransposed::BindArguments(ArgumentsBinder* args) {
551   if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) {
552     const int aligned_h =
553         AlignByN(dst_[0]->Height(), stride_.y * block_size_.y);
554     RETURN_IF_ERROR(
555         args->SetInt("grid_size_y", DivideRoundUp(aligned_h, block_size_.y)));
556   }
557   return absl::OkStatus();
558 }
559 
GetGridSize() const560 int3 ConvolutionTransposed::GetGridSize() const {
561   const int aligned_w = AlignByN(dst_[0]->Width(), stride_.x * block_size_.x);
562   const int aligned_h = AlignByN(dst_[0]->Height(), stride_.y * block_size_.y);
563   const int aligned_d = AlignByN(dst_[0]->Depth(), stride_.z * block_size_.z);
564   const int grid_x = DivideRoundUp(aligned_w, block_size_.x) * dst_[0]->Batch();
565   const int grid_y = DivideRoundUp(aligned_h, block_size_.y) *
566                      DivideRoundUp(aligned_d, block_size_.z);
567   const int grid_z = DivideRoundUp(dst_[0]->Slices(), block_size_.w);
568   return int3(grid_x, grid_y, grid_z);
569 }
570 
GetPossibleKernelWorkGroups(TuningType tuning_type,const GpuInfo & gpu_info,const KernelInfo & kernel_info,std::vector<int3> * work_groups) const571 void ConvolutionTransposed::GetPossibleKernelWorkGroups(
572     TuningType tuning_type, const GpuInfo& gpu_info,
573     const KernelInfo& kernel_info, std::vector<int3>* work_groups) const {
574   GetPossibleWorkGroupsConv(tuning_type, gpu_info, kernel_info, grid_size_,
575                             work_groups);
576 }
577 
CreateConvolutionTransposed(const GpuInfo & gpu_info,const OperationDef & definition,const ConvolutionTransposedAttributes & attr)578 ConvolutionTransposed CreateConvolutionTransposed(
579     const GpuInfo& gpu_info, const OperationDef& definition,
580     const ConvolutionTransposedAttributes& attr) {
581   const bool weights_are_buffer = gpu_info.IsMali() || gpu_info.IsApple();
582   ConvolutionTransposed result(definition, attr, gpu_info, weights_are_buffer);
583   result.UploadWeights(attr.weights, weights_are_buffer);
584 
585   TensorLinearDescriptor desc;
586   desc.storage_type =
587       DeduceLinearStorageType(definition.GetPrimaryStorageType());
588   desc.element_type = definition.GetDataType();
589   desc.UploadLinearData(attr.bias);
590   result.args_.AddObject(
591       "biases", absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
592   return result;
593 }
594 
CreateConvolutionTransposed3D(const GpuInfo & gpu_info,const OperationDef & definition,const ConvolutionTransposed3DAttributes & attr)595 ConvolutionTransposed CreateConvolutionTransposed3D(
596     const GpuInfo& gpu_info, const OperationDef& definition,
597     const ConvolutionTransposed3DAttributes& attr) {
598   const bool weights_are_buffer = gpu_info.IsMali() || gpu_info.IsApple();
599   ConvolutionTransposed result(definition, attr, gpu_info, weights_are_buffer);
600   result.UploadWeights(attr.weights, weights_are_buffer);
601 
602   TensorLinearDescriptor desc;
603   desc.storage_type =
604       DeduceLinearStorageType(definition.GetPrimaryStorageType());
605   desc.element_type = definition.GetDataType();
606   desc.UploadLinearData(attr.bias);
607   result.args_.AddObject(
608       "biases", absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
609   return result;
610 }
611 
CreateConvolutionTransposedDynamicWeights(const GpuInfo & gpu_info,const OperationDef & definition,const ConvolutionTransposedAttributes & attr)612 ConvolutionTransposed CreateConvolutionTransposedDynamicWeights(
613     const GpuInfo& gpu_info, const OperationDef& definition,
614     const ConvolutionTransposedAttributes& attr) {
615   const bool weights_are_buffer = gpu_info.IsMali();
616   OperationDef new_def = definition;
617   new_def.src_tensors = {
618       definition.src_tensors[0]};  // leaving only src_tensor def, weights defs
619                                    // will be added later
620   const DataType weights_type = definition.GetDataType();
621   if (weights_are_buffer) {
622     // add 1 src_tensor(buffer) for weights
623     new_def.src_tensors.push_back(
624         {weights_type, TensorStorageType::BUFFER, Layout::HWC});
625   } else {
626     // add 4 src_tensors(4X textures 2d) for weights
627     new_def.src_tensors.push_back(
628         {weights_type, TensorStorageType::TEXTURE_2D, Layout::HWC});
629     new_def.src_tensors.push_back(
630         {weights_type, TensorStorageType::TEXTURE_2D, Layout::HWC});
631     new_def.src_tensors.push_back(
632         {weights_type, TensorStorageType::TEXTURE_2D, Layout::HWC});
633     new_def.src_tensors.push_back(
634         {weights_type, TensorStorageType::TEXTURE_2D, Layout::HWC});
635   }
636   ConvolutionTransposed result(new_def, attr, gpu_info, weights_are_buffer);
637 
638   TensorLinearDescriptor desc;
639   desc.storage_type = DeduceLinearStorageType(new_def.GetPrimaryStorageType());
640   desc.element_type = new_def.GetDataType();
641   desc.UploadLinearData(attr.bias);
642   result.args_.AddObject(
643       "biases", absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
644   return result;
645 }
646 
647 }  // namespace gpu
648 }  // namespace tflite
649