/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/tasks/convolution_transposed.h"

#include <string>
#include <utility>
#include <vector>

#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/task/storage_type_util.h"
#include "tensorflow/lite/delegates/gpu/common/task/weights_layout.h"
#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"

namespace tflite {
namespace gpu {

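// Chooses a weights layout suited to the target GPU and weights storage
// (buffer vs. textures), picks a compute block size, and generates the kernel
// source for the 2D case.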
ConvolutionTransposed::ConvolutionTransposed(
    const OperationDef& definition, const ConvolutionTransposedAttributes& attr,
    const GpuInfo& gpu_info, bool weights_are_buffer)
    : GPUOperation(definition),
      stride_(attr.stride.w, attr.stride.h, 1, 1),
      block_size_(2, 2, 1, 2) {
  if (weights_are_buffer) {
    if (gpu_info.IsApple()) {
      weights_layout_ = WeightsLayout::kOHWIOGroupO4I4;
    } else {
      weights_layout_ = WeightsLayout::kOHWIOGroupI4O4;
    }
  } else {
    if (gpu_info.IsApple()) {
      weights_layout_ = WeightsLayout::k2DX4O4YIsHWIAndXIsOOGroupI4;
    } else {
      weights_layout_ = WeightsLayout::k2DX4I4YIsHWIAndXIsOOGroupO4;
    }
  }
  const bool is_f16 = definition.precision == CalculationsPrecision::F16;
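  // Mali-specific block sizes: Midgard keeps a single row per block, and FP32
  // precision keeps a single output slice per work-item.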
  if (gpu_info.IsMali()) {
    if (gpu_info.mali_info.IsMidgard()) {
      block_size_ = is_f16 ? int4(2, 1, 1, 2) : int4(2, 1, 1, 1);
    } else {
      block_size_ = is_f16 ? int4(2, 2, 1, 2) : int4(2, 2, 1, 1);
    }
  }
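  // With only 1 or 3 output slices a multi-slice block is not worthwhile:
  // fold the slice factor into the Y block (except on Mali) and process one
  // slice per work-item.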
  const int dst_depth = DivideRoundUp(attr.weights.shape.o, 4);
  if (dst_depth == 1 || dst_depth == 3) {
    if (!gpu_info.IsMali()) {
      block_size_.y *= block_size_.w;
    }
    block_size_.w = 1;
  }

  args_.AddInt("stride_x", stride_.x);
  args_.AddInt("stride_y", stride_.y);
  args_.AddInt("padding_x", attr.padding.prepended.w);
  args_.AddInt("padding_y", attr.padding.prepended.h);
  args_.AddInt("kernel_size_x", attr.weights.shape.w);
  args_.AddInt("kernel_size_y", attr.weights.shape.h);
  code_ = GenerateConvolutionTransposedCode(definition_, gpu_info,
                                            weights_are_buffer, block_size_);
}

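// The 3D variant repeats the same layout and block-size selection and
// additionally exposes depth stride/padding/kernel size plus grid_size_y,
// which the kernel uses to split GLOBAL_ID_1 between the Y and Z axes.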
ConvolutionTransposed::ConvolutionTransposed(
    const OperationDef& definition,
    const ConvolutionTransposed3DAttributes& attr, const GpuInfo& gpu_info,
    bool weights_are_buffer)
    : GPUOperation(definition),
      stride_(attr.stride.w, attr.stride.h, attr.stride.d, 1),
      block_size_(2, 2, 1, 2) {
  if (weights_are_buffer) {
    if (gpu_info.IsApple()) {
      weights_layout_ = WeightsLayout::kOHWIOGroupO4I4;
    } else {
      weights_layout_ = WeightsLayout::kOHWIOGroupI4O4;
    }
  } else {
    if (gpu_info.IsApple()) {
      weights_layout_ = WeightsLayout::k2DX4O4YIsHWIAndXIsOOGroupI4;
    } else {
      weights_layout_ = WeightsLayout::k2DX4I4YIsHWIAndXIsOOGroupO4;
    }
  }
  const bool is_f16 = definition.precision == CalculationsPrecision::F16;
  if (gpu_info.IsMali()) {
    if (gpu_info.mali_info.IsMidgard()) {
      block_size_ = is_f16 ? int4(2, 1, 1, 2) : int4(2, 1, 1, 1);
    } else {
      block_size_ = is_f16 ? int4(2, 2, 1, 2) : int4(2, 2, 1, 1);
    }
  }
  const int dst_depth = DivideRoundUp(attr.weights.shape.o, 4);
  if (dst_depth == 1 || dst_depth == 3) {
    if (!gpu_info.IsMali()) {
      block_size_.y *= block_size_.w;
    }
    block_size_.w = 1;
  }

  args_.AddInt("stride_x", stride_.x);
  args_.AddInt("stride_y", stride_.y);
  args_.AddInt("stride_z", stride_.z);
  args_.AddInt("padding_x", attr.padding.prepended.w);
  args_.AddInt("padding_y", attr.padding.prepended.h);
  args_.AddInt("padding_z", attr.padding.prepended.d);
  args_.AddInt("kernel_size_x", attr.weights.shape.w);
  args_.AddInt("kernel_size_y", attr.weights.shape.h);
  args_.AddInt("kernel_size_z", attr.weights.shape.d);
  args_.AddInt("grid_size_y");
  code_ = GenerateConvolutionTransposedCode(definition_, gpu_info,
                                            weights_are_buffer, block_size_);
}

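// Generates the device kernel source for transposed convolution. Each
// work-item computes a block of block_size.x * block_size.y * block_size.z
// output positions for block_size.w output slices.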
std::string ConvolutionTransposed::GenerateConvolutionTransposedCode(
    const OperationDef& op_def, const GpuInfo& gpu_info,
    bool weights_are_buffer, const int4& block_size) {
  auto src_desc = op_def.src_tensors[0];
  src_desc.SetAddressMode(AddressMode::kZero);
  AddSrcTensor("src_tensor", src_desc);
  AddDstTensor("dst_tensor", op_def.dst_tensors[0]);

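  // With runtime (dynamic) weights there is more than one source tensor: the
  // weights arrive either as a single packed buffer or as four 2D textures.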
  if (op_def.src_tensors.size() != 1) {
    // dynamic weights
    if (weights_layout_ == WeightsLayout::kOHWIOGroupI4O4 ||
        weights_layout_ == WeightsLayout::kOHWIOGroupO4I4) {
      BufferDescriptor desc;
      desc.element_type = op_def.src_tensors[1].data_type;
      desc.element_size = 16;
      desc.memory_type = MemoryType::GLOBAL;
      AddSrcBuffer("weights", desc);
    } else {
      for (int i = 0; i < 4; ++i) {
        Texture2DDescriptor desc;
        desc.element_type = op_def.src_tensors[1 + i].data_type;
        const std::string name = "weights" + std::to_string(i);
        AddSrcTexture2D(name, desc);
      }
    }
  }

  const auto& src_def = op_def.src_tensors[0];

  std::string c;

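  // Emit one CONV<s> macro per output-slice accumulator. I4O4 layouts
  // accumulate with per-channel multiply-adds; O4I4 layouts use dot products.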
  for (int s = 0; s < block_size.w; ++s) {
    const std::string f0 = weights_are_buffer ? "FLT16_0123(weights_cache[" +
                                                    std::to_string(s) + "])"
                                              : "f" + std::to_string(s * 4 + 0);
    const std::string f1 = weights_are_buffer ? "FLT16_4567(weights_cache[" +
                                                    std::to_string(s) + "])"
                                              : "f" + std::to_string(s * 4 + 1);
    const std::string f2 = weights_are_buffer ? "FLT16_89ab(weights_cache[" +
                                                    std::to_string(s) + "])"
                                              : "f" + std::to_string(s * 4 + 2);
    const std::string f3 = weights_are_buffer ? "FLT16_cdef(weights_cache[" +
                                                    std::to_string(s) + "])"
                                              : "f" + std::to_string(s * 4 + 3);
    if (GetWeightsDescription().IsI4O4()) {
      switch (op_def.precision) {
        case CalculationsPrecision::F32:
        case CalculationsPrecision::F16:
          c += "#define CONV" + std::to_string(s) + "(R, S) \\\n";
          c += "R += S.x * " + f0 + "; \\\n";
          c += "R += S.y * " + f1 + "; \\\n";
          c += "R += S.z * " + f2 + "; \\\n";
          c += "R += S.w * " + f3 + "; \n";
          break;
        case CalculationsPrecision::F32_F16:
          c += "#define CONV" + std::to_string(s) + "(R, S) \\\n";
          c += "R += TO_ACCUM_TYPE(S.x * " + f0 + " + S.y * " + f1 +
               " + S.z * " + f2 + " + S.w * " + f3 + ");\n";
          break;
      }
    } else {
      // O4I4
      c += "#define CONV" + std::to_string(s) + "(R, S) \\\n";
      c += "R.x += dot(S, " + f0 + "); \\\n";
      c += "R.y += dot(S, " + f1 + "); \\\n";
      c += "R.z += dot(S, " + f2 + "); \\\n";
      c += "R.w += dot(S, " + f3 + "); \n";
    }
  }

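  // Helpers that build per-block-element name suffixes (e.g. "_w0_h1") and
  // the bounds-check expression for axes that cannot be zero-clamped by the
  // tensor read.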
  auto generate_id = [&](const std::string& x, const std::string& y,
                         const std::string& z) {
    std::string id;
    if (src_def.HasAxis(Axis::WIDTH)) {
      id += "_w" + x;
    }
    if (src_def.HasAxis(Axis::HEIGHT)) {
      id += "_h" + y;
    }
    if (src_def.HasAxis(Axis::DEPTH)) {
      id += "_d" + z;
    }
    return id;
  };

  auto generate_id_full = [&](const std::string& x, const std::string& y,
                              const std::string& z, const std::string& s) {
    return generate_id(x, y, z) + "_s" + s;
  };

  auto generate_check = [&](const std::string& x, const std::string& y,
                            const std::string& z) {
    std::string check;
    const std::vector<Axis> axes{Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH};
    const std::vector<std::string> names{"in_x", "in_y", "in_z"};
    const std::vector<std::string> coords{x, y, z};
    for (int i = 0; i < axes.size(); ++i) {
      const auto& axis = axes[i];
      if (src_def.HasAxis(axis) && !src_def.SupportsZeroClamp(axis) &&
          block_size[i] != 1) {
        if (!check.empty()) {
          check += " && ";
        }
        check += names[i] + coords[i];
      }
    }
    return check;
  };

  switch (op_def.precision) {
    case CalculationsPrecision::F32:
      c += "#define FLT16 float16\n";
      break;
    case CalculationsPrecision::F32_F16:
    case CalculationsPrecision::F16:
      c += "#define FLT16 half16\n";
      break;
  }

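  // Kernel prologue: recover the destination block origin (dst_x, dst_y
  // [, dst_z], dst_s) from the global ids. Batch is folded into GLOBAL_ID_0
  // and, in the 3D case, depth is folded into GLOBAL_ID_1 via grid_size_y.
  // Out-of-range work-items exit early.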
246 c += "MAIN_FUNCTION($0) {\n";
247 if (op_def.IsBatchSupported()) {
248 c += " int linear_id = GLOBAL_ID_0;\n";
249 c += " int dst_x = (linear_id / args.dst_tensor.Batch());\n";
250 c += " int B = linear_id % args.dst_tensor.Batch();\n";
251 c += " args.dst_tensor.SetBatchRef(B);\n";
252 c += " args.src_tensor.SetBatchRef(B);\n";
253 } else {
254 c += " int dst_x = GLOBAL_ID_0;\n";
255 }
256 c += " int rem_x = dst_x % args.stride_x;\n";
257 c += " int ceil_x = dst_x / args.stride_x;\n";
258 c += " dst_x = ceil_x * args.stride_x * " + std::to_string(block_size.x) +
259 " + rem_x;\n";
260 if (src_def.HasAxis(Axis::DEPTH)) {
261 c += " int linear_id_y = GLOBAL_ID_1;\n";
262 c += " int dst_y = linear_id_y % args.grid_size_y;\n";
263 c += " int dst_z = linear_id_y / args.grid_size_y;\n";
264 c += " int rem_z = dst_z % args.stride_z;\n";
265 c += " int ceil_z = dst_z / args.stride_z;\n";
266 c += " dst_z = ceil_z * args.stride_z * " + std::to_string(block_size.z) +
267 " + rem_z;\n";
268 c += " if (dst_z >= args.dst_tensor.Depth()) return;\n";
269 } else {
270 c += " int dst_y = GLOBAL_ID_1;\n";
271 }
272 c += " int rem_y = dst_y % args.stride_y;\n";
273 c += " int ceil_y = dst_y / args.stride_y;\n";
274 c += " dst_y = ceil_y * args.stride_y * " + std::to_string(block_size.y) +
275 " + rem_y;\n";
276 c += " int dst_s = GLOBAL_ID_2 * " + std::to_string(block_size.w) + ";\n";
277 c += " if (dst_x >= args.dst_tensor.Width() || dst_y >= "
278 "args.dst_tensor.Height() || dst_s >= "
279 "args.dst_tensor.Slices()) return;\n";
280 if (weights_are_buffer) {
281 c += " int f_base = dst_s * args.src_tensor.Slices() * args.kernel_size_x "
282 "* args.kernel_size_y";
283 if (src_def.HasAxis(Axis::DEPTH)) {
284 c += " * args.kernel_size_z";
285 }
286 c += ";\n";
287 }
288 for (int s = 0; s < block_size.w; ++s) {
289 const std::string sind = std::to_string(s);
290 for (int z = 0; z < block_size.z; ++z) {
291 const std::string zind = std::to_string(z);
292 for (int y = 0; y < block_size.y; ++y) {
293 const std::string yind = std::to_string(y);
294 for (int x = 0; x < block_size.x; ++x) {
295 const std::string xind = std::to_string(x);
296 c += " ACCUM_FLT4 r" + generate_id_full(xind, yind, zind, sind) +
297 " = INIT_ACCUM_FLT4(0.0f);\n";
298 }
299 }
300 }
301 }
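  // Map the destination block back to the source window that contributes to
  // it: the loops below step the source coordinates down while their
  // upsampled positions (src_as_dst_*) still fall inside the kernel footprint
  // bounded by kernel_first_dst_* and kernel_last_dst_*.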
302 c += " int kernel_first_dst_x = dst_x + args.padding_x;\n";
303 c += " int kernel_first_dst_y = dst_y + args.padding_y;\n";
304 c += " int kernel_last_dst_x = kernel_first_dst_x - args.kernel_size_x;\n";
305 c += " int kernel_last_dst_y = kernel_first_dst_y - args.kernel_size_y;\n";
306 c += " int offset_x = abs(args.padding_x);\n";
307 c += " int offset_x_strided = offset_x * args.stride_x;\n";
308 c +=
309 " int src_x = (kernel_first_dst_x + offset_x_strided) / args.stride_x - "
310 "offset_x;\n";
311 c += " int offset_y = abs(args.padding_y);\n";
312 c += " int offset_y_strided = offset_y * args.stride_y;\n";
313 c +=
314 " int src_y = (kernel_first_dst_y + offset_y_strided) / args.stride_y - "
315 "offset_y;\n";
316 if (src_def.HasAxis(Axis::DEPTH)) {
317 c += " int kernel_first_dst_z = dst_z + args.padding_z;\n";
318 c += " int kernel_last_dst_z = kernel_first_dst_z - args.kernel_size_z;\n";
319 c += " int offset_z = abs(args.padding_z);\n";
320 c += " int offset_z_strided = offset_z * args.stride_z;\n";
321 c += " int src_z = (kernel_first_dst_z + offset_z_strided) / "
322 "args.stride_z - offset_z;\n";
323 c += " int src_as_dst_z = src_z * args.stride_z;\n";
324 c +=
325 " for (;src_as_dst_z > kernel_last_dst_z; src_z -= 1, src_as_dst_z -= "
326 "args.stride_z) {\n";
327 for (int z = 0; z < block_size.z; ++z) {
328 const std::string zindex = std::to_string(z);
329 c += " int sz" + zindex + " = src_z + " + zindex + ";\n";
330 if (!src_def.SupportsZeroClamp(Axis::DEPTH)) {
331 c += " bool in_z" + zindex + " = sz" + zindex + " >= 0 && sz" +
332 zindex + " < args.src_tensor.Depth();\n";
333 if (!src_def.CanReadOutOfBorder(Axis::DEPTH)) {
334 c += " sz" + zindex + " = clamp(sz" + zindex +
335 ", 0, args.src_tensor.Depth() - 1);\n";
336 }
337 }
338 }
339 if (block_size.z == 1 && !src_def.SupportsZeroClamp(Axis::DEPTH)) {
340 c += " if (!in_z0) continue;\n";
341 }
342 c += " int kernel_z = kernel_first_dst_z - src_as_dst_z;\n";
343 c += " int src_as_dst_y = src_y * args.stride_y;\n";
344 c += " int src_y_copy = src_y;\n";
345 c += " for (;src_as_dst_y > kernel_last_dst_y; src_y_copy -= 1, "
346 "src_as_dst_y -= args.stride_y) {\n";
347 } else {
348 c += " int src_as_dst_y = src_y * args.stride_y;\n";
349 c += " for (;src_as_dst_y > kernel_last_dst_y; src_y -= 1, src_as_dst_y "
350 "-= args.stride_y) {\n";
351 }
352 for (int y = 0; y < block_size.y; ++y) {
353 const std::string yindex = std::to_string(y);
354 const std::string src_y =
355 src_def.HasAxis(Axis::DEPTH) ? "src_y_copy" : "src_y";
356 c += " int sy" + yindex + " = " + src_y + " + " + yindex + ";\n";
357 if (!src_def.SupportsZeroClamp(Axis::HEIGHT)) {
358 c += " bool in_y" + yindex + " = sy" + yindex + " >= 0 && sy" +
359 yindex + " < args.src_tensor.Height();\n";
360 if (!src_def.CanReadOutOfBorder(Axis::HEIGHT)) {
361 c += " sy" + yindex + " = clamp(sy" + yindex +
362 ", 0, args.src_tensor.Height() - 1);\n";
363 }
364 }
365 }
366 if (block_size.y == 1 && !src_def.SupportsZeroClamp(Axis::HEIGHT)) {
367 c += " if (!in_y0) continue;\n";
368 }
369 c += " int kernel_y = kernel_first_dst_y - src_as_dst_y;\n";
370 c += " int src_as_dst_x = src_x * args.stride_x;\n";
371 c += " int src_x_copy = src_x;\n";
372 c += " for (;src_as_dst_x > kernel_last_dst_x; src_x_copy -= 1, "
373 "src_as_dst_x "
374 "-= args.stride_x) {\n";
375 for (int x = 0; x < block_size.x; ++x) {
376 const std::string xindex = std::to_string(x);
377 c += " int sx" + xindex + " = src_x_copy + " + xindex + ";\n";
378 if (!src_def.SupportsZeroClamp(Axis::WIDTH)) {
379 c += " bool in_x" + xindex + " = sx" + xindex + " >= 0 && sx" +
380 xindex + " < args.src_tensor.Width();\n";
381 if (!src_def.CanReadOutOfBorder(Axis::WIDTH)) {
382 c += " sx" + xindex + " = clamp(sx" + xindex +
383 ", 0, args.src_tensor.Width() - 1);\n";
384 }
385 }
386 }
387 if (block_size.x == 1 && !src_def.SupportsZeroClamp(Axis::WIDTH)) {
388 c += " if (!in_x0) continue;\n";
389 }
390 for (int z = 0; z < block_size.z; ++z) {
391 const std::string zind = std::to_string(z);
392 for (int y = 0; y < block_size.y; ++y) {
393 const std::string yind = std::to_string(y);
394 for (int x = 0; x < block_size.x; ++x) {
395 const std::string xind = std::to_string(x);
396 const std::string id = generate_id(xind, yind, zind);
397 const std::string check = generate_check(xind, yind, zind);
398 std::string coords = "sx" + xind + ", sy" + yind;
399 if (src_def.HasAxis(Axis::DEPTH)) {
400 coords += ", sz" + zind;
401 }
402 if (src_def.IsLinear()) {
403 c += " args.src_tensor.GetAddress(addr" + id + ", " + coords +
404 ", 0);\n";
405 }
406 if (src_def.ReturnsZeroForNegOneRead()) {
407 c += " addr" + id + " = select(-1, addr" + id + ", (" + check +
408 "));\n";
409 c += " int ds" + id +
410 " = select(0, args.src_tensor.SliceStride(), (" + check +
411 "));\n";
412 }
413 }
414 }
415 }
416 if (src_def.storage_type == TensorStorageType::BUFFER) {
417 c += " int ds = args.src_tensor.SliceStride();\n";
418 }
419 c += " int kernel_x = kernel_first_dst_x - src_as_dst_x;\n";
420 if (src_def.HasAxis(Axis::DEPTH)) {
421 c += " int kernel_index = (kernel_z * args.kernel_size_y + kernel_y) "
422 "* args.kernel_size_x + kernel_x;\n";
423 } else {
424 c += " int kernel_index = kernel_y * args.kernel_size_x + kernel_x;\n";
425 }
426 if (weights_are_buffer) {
427 c += " int f_offset = f_base + kernel_index * "
428 "args.src_tensor.Slices() * " +
429 std::to_string(block_size.w) + ";\n";
430 } else {
431 c += " int x_c = kernel_index * args.src_tensor.Slices();\n";
432 }
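  // Innermost loop over source slices: read the source block (masking or
  // zeroing out-of-bounds taps), fetch the weights for the current kernel
  // tap, and accumulate through the CONV<s> macros. On Mali a conditional
  // read is generated instead of multiplying by the validity mask.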
433 c += " for (int s = 0; s < args.src_tensor.Slices(); ++s) {\n";
434 const bool conditional_read = gpu_info.IsMali();
435 for (int z = 0; z < block_size.z; ++z) {
436 const std::string zind = std::to_string(z);
437 for (int y = 0; y < block_size.y; ++y) {
438 const std::string yind = std::to_string(y);
439 for (int x = 0; x < block_size.x; ++x) {
440 const std::string xind = std::to_string(x);
441 const std::string id = generate_id(xind, yind, zind);
442 std::string address;
443 if (src_def.IsLinear()) {
444 address = "addr" + id;
445 } else {
446 address = "sx" + xind + ", sy" + yind;
447 if (src_def.HasAxis(Axis::DEPTH)) {
448 address += ", sz" + zind;
449 }
450 address += ", s";
451 }
452 if (src_def.ReturnsZeroForNegOneRead()) {
453 c += " FLT4 src" + id + " = args.src_tensor.Read(" + address +
454 "); " + address + " += ds" + id + ";\n";
455 } else {
456 const std::string check = generate_check(xind, yind, zind);
457 if (!check.empty()) {
458 if (conditional_read) {
459 c += " FLT4 src" + id + " = " + check +
460 " ? args.src_tensor.Read(" + address + ") : (FLT4)(0.0f);\n";
461 } else {
462 c += " FLT4 src" + id + " = args.src_tensor.Read(" +
463 address + ") * INIT_FLT(" + check + ");\n";
464 }
465 } else {
466 c += " FLT4 src" + id + " = args.src_tensor.Read(" +
467 address + ");\n";
468 }
469 if (src_def.IsLinear()) {
470 c += " addr" + id + " += ds;\n";
471 }
472 }
473 }
474 }
475 }
476 if (weights_are_buffer) {
477 c += " __global FLT16* weights_cache = "
478 "args.weights.GetPtr(f_offset);\n";
479 c += " f_offset += " + std::to_string(block_size.w) + ";\n";
480 } else {
481 for (int s = 0; s < block_size.w; ++s) {
482 c += absl::Substitute(
483 R"( FLT4 f$1 = args.weights0.Read(dst_s + $0, x_c);
484 FLT4 f$2 = args.weights1.Read(dst_s + $0, x_c);
485 FLT4 f$3 = args.weights2.Read(dst_s + $0, x_c);
486 FLT4 f$4 = args.weights3.Read(dst_s + $0, x_c);
487 )",
488 s, s * 4 + 0, s * 4 + 1, s * 4 + 2, s * 4 + 3);
489 }
490 c += " x_c++;\n";
491 }
492 for (int s = 0; s < block_size.w; ++s) {
493 const std::string sind = std::to_string(s);
494 for (int z = 0; z < block_size.z; ++z) {
495 const std::string zind = std::to_string(z);
496 for (int y = 0; y < block_size.y; ++y) {
497 const std::string yind = std::to_string(y);
498 for (int x = 0; x < block_size.x; ++x) {
499 const std::string xind = std::to_string(x);
500 const std::string id = generate_id(xind, yind, zind);
501 const std::string full_id = generate_id_full(xind, yind, zind, sind);
502 c += " CONV" + sind + "(r" + full_id + ", src" + id + ");\n";
503 }
504 }
505 }
506 }
507 c += " }\n";
508 c += " }\n";
509 c += " }\n";
510 if (src_def.HasAxis(Axis::DEPTH)) {
511 c += " }\n";
512 }
513 for (int s = 0; s < block_size.w; ++s) {
514 const std::string sind = std::to_string(s);
515 c += " if (dst_s < args.dst_tensor.Slices()) {\n";
516 c += " FLT4 bias_val = args.biases.Read(dst_s);\n";
517 for (int z = 0; z < block_size.z; ++z) {
518 const std::string zind = std::to_string(z);
519 for (int y = 0; y < block_size.y; ++y) {
520 const std::string yind = std::to_string(y);
521 for (int x = 0; x < block_size.x; ++x) {
522 const std::string xind = std::to_string(x);
523 const std::string id = generate_id_full(xind, yind, zind, sind);
524 std::string checks =
525 "xc < args.dst_tensor.Width() && yc < args.dst_tensor.Height()";
526 std::string coords = "xc, yc";
527 c += " {\n";
528 c += " int xc = dst_x + args.stride_x * " + xind + ";\n";
529 c += " int yc = dst_y + args.stride_y * " + yind + ";\n";
530 if (src_def.HasAxis(Axis::DEPTH)) {
531 c += " int zc = dst_z + args.stride_z * " + zind + ";\n";
532 checks += " && zc < args.dst_tensor.Depth()";
533 coords += ", zc";
534 }
535 c += " if (" + checks + ") {\n";
536 c += " FLT4 res = TO_FLT4(r" + id + ") + bias_val;\n";
537 c += " args.dst_tensor.Write(res, " + coords + ", dst_s);\n";
538 c += " }\n";
539 c += " }\n";
540 }
541 }
542 }
543 c += " }\n";
544 c += " dst_s++;\n";
545 }
546 c += "}\n";
547 return c;
548 }
549
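// grid_size_y tells the kernel how many GLOBAL_ID_1 values belong to the Y
// axis before the Z axis starts; it is only needed when the source tensor has
// a DEPTH axis.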
absl::Status ConvolutionTransposed::BindArguments(ArgumentsBinder* args) {
  if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) {
    const int aligned_h =
        AlignByN(dst_[0]->Height(), stride_.y * block_size_.y);
    RETURN_IF_ERROR(
        args->SetInt("grid_size_y", DivideRoundUp(aligned_h, block_size_.y)));
  }
  return absl::OkStatus();
}

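// The grid packs width blocks (times batch) into X, height blocks times depth
// blocks into Y, and slice blocks into Z; spatial extents are first aligned
// to stride * block size.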
int3 ConvolutionTransposed::GetGridSize() const {
  const int aligned_w = AlignByN(dst_[0]->Width(), stride_.x * block_size_.x);
  const int aligned_h = AlignByN(dst_[0]->Height(), stride_.y * block_size_.y);
  const int aligned_d = AlignByN(dst_[0]->Depth(), stride_.z * block_size_.z);
  const int grid_x = DivideRoundUp(aligned_w, block_size_.x) * dst_[0]->Batch();
  const int grid_y = DivideRoundUp(aligned_h, block_size_.y) *
                     DivideRoundUp(aligned_d, block_size_.z);
  const int grid_z = DivideRoundUp(dst_[0]->Slices(), block_size_.w);
  return int3(grid_x, grid_y, grid_z);
}

void ConvolutionTransposed::GetPossibleKernelWorkGroups(
    TuningType tuning_type, const GpuInfo& gpu_info,
    const KernelInfo& kernel_info, std::vector<int3>* work_groups) const {
  GetPossibleWorkGroupsConv(tuning_type, gpu_info, kernel_info, grid_size_,
                            work_groups);
}

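// Factory helpers: build the operation, upload the constant weights (as a
// buffer on Mali/Apple, otherwise as textures) and attach the bias as a
// linear tensor object.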
ConvolutionTransposed CreateConvolutionTransposed(
    const GpuInfo& gpu_info, const OperationDef& definition,
    const ConvolutionTransposedAttributes& attr) {
  const bool weights_are_buffer = gpu_info.IsMali() || gpu_info.IsApple();
  ConvolutionTransposed result(definition, attr, gpu_info, weights_are_buffer);
  result.UploadWeights(attr.weights, weights_are_buffer);

  TensorLinearDescriptor desc;
  desc.storage_type =
      DeduceLinearStorageType(definition.GetPrimaryStorageType());
  desc.element_type = definition.GetDataType();
  desc.UploadLinearData(attr.bias);
  result.args_.AddObject(
      "biases", absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
  return result;
}

ConvolutionTransposed CreateConvolutionTransposed3D(
    const GpuInfo& gpu_info, const OperationDef& definition,
    const ConvolutionTransposed3DAttributes& attr) {
  const bool weights_are_buffer = gpu_info.IsMali() || gpu_info.IsApple();
  ConvolutionTransposed result(definition, attr, gpu_info, weights_are_buffer);
  result.UploadWeights(attr.weights, weights_are_buffer);

  TensorLinearDescriptor desc;
  desc.storage_type =
      DeduceLinearStorageType(definition.GetPrimaryStorageType());
  desc.element_type = definition.GetDataType();
  desc.UploadLinearData(attr.bias);
  result.args_.AddObject(
      "biases", absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
  return result;
}

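// Variant with runtime weights: nothing is uploaded here; instead extra
// source tensors (one buffer on Mali, four 2D textures elsewhere) are added
// to the operation definition and filled at runtime.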
ConvolutionTransposed CreateConvolutionTransposedDynamicWeights(
    const GpuInfo& gpu_info, const OperationDef& definition,
    const ConvolutionTransposedAttributes& attr) {
  const bool weights_are_buffer = gpu_info.IsMali();
  OperationDef new_def = definition;
  new_def.src_tensors = {
      definition.src_tensors[0]};  // leaving only src_tensor def, weights defs
                                   // will be added later
  const DataType weights_type = definition.GetDataType();
  if (weights_are_buffer) {
    // add 1 src_tensor(buffer) for weights
    new_def.src_tensors.push_back(
        {weights_type, TensorStorageType::BUFFER, Layout::HWC});
  } else {
    // add 4 src_tensors(4X textures 2d) for weights
    new_def.src_tensors.push_back(
        {weights_type, TensorStorageType::TEXTURE_2D, Layout::HWC});
    new_def.src_tensors.push_back(
        {weights_type, TensorStorageType::TEXTURE_2D, Layout::HWC});
    new_def.src_tensors.push_back(
        {weights_type, TensorStorageType::TEXTURE_2D, Layout::HWC});
    new_def.src_tensors.push_back(
        {weights_type, TensorStorageType::TEXTURE_2D, Layout::HWC});
  }
  ConvolutionTransposed result(new_def, attr, gpu_info, weights_are_buffer);

  TensorLinearDescriptor desc;
  desc.storage_type = DeduceLinearStorageType(new_def.GetPrimaryStorageType());
  desc.element_type = new_def.GetDataType();
  desc.UploadLinearData(attr.bias);
  result.args_.AddObject(
      "biases", absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
  return result;
}

}  // namespace gpu
}  // namespace tflite