/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/gl/kernels/concat.h"

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/gl/variable.h"

namespace tflite {
namespace gpu {
namespace gl {
namespace {

class AlignedConcatByChannels : public NodeShader {
 public:
  static bool IsSupported(const GenerationContext& ctx) {
    const auto& attr = absl::any_cast<const ConcatAttributes&>(ctx.op_attr);

    // Implementation supports concatenation by channels only.
    if (attr.axis != Axis::CHANNELS) return false;

    // Implementation supports concatenation of 2 tensors only.
    if (ctx.input_shapes.size() != 2) return false;

    // H and W must be the same for every concatenated tensor.
    for (int i = 1; i < ctx.input_shapes.size(); i++) {
      if (ctx.input_shapes[0][1] != ctx.input_shapes[i][1] ||
          ctx.input_shapes[0][2] != ctx.input_shapes[i][2]) {
        return false;
      }
    }

    // Channels must be aligned by 4 for every concatenated tensor.
    for (const auto& shape : ctx.input_shapes) {
      if (shape[3] % 4 != 0) return false;
    }

    return true;
  }

  absl::Status GenerateCode(const GenerationContext& ctx,
                            GeneratedCode* generated_code) const final {
    if (!IsSupported(ctx)) {
      return absl::InvalidArgumentError(
          "This case is not supported by aligned concat");
    }
    // The shader below concatenates 2 tensors whose channels are aligned by 4.
    std::string source = R"(
      if (gid.z < $border$) {
        value_0 = $input_data_0[gid.x, gid.y, gid.z]$;
      } else {
        int z = gid.z - $border$;
        value_0 = $input_data_1[gid.x, gid.y, z]$;
      }
)";
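    // Worked example (an illustration, not literal generated output): for two
    // inputs of shape HxWx8 and HxWx4, the tensors are stored as 4-channel
    // slices along Z, so $border$ below resolves to 8 / 4 = 2. Threads with
    // gid.z in [0, 2) read input_data_0 directly; threads with gid.z == 2
    // read input_data_1 at z = gid.z - 2.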
    *generated_code = {
        /*parameters=*/{
            {"border", static_cast<int>(ctx.input_shapes[0][3]) / 4}},
        /*objects=*/{},
        /*shared_variables=*/{},
        /*workload=*/uint3(),
        /*workgroup=*/uint3(),
        /*source_code=*/std::move(source),
        /*input=*/IOStructure::ONLY_DEFINITIONS,
        /*output=*/IOStructure::AUTO,
    };
    return absl::OkStatus();
  }
};

class ConcatByAnyChannel : public NodeShader {
 public:
  static bool IsSupported(const GenerationContext& ctx) {
    const auto& attr = absl::any_cast<const ConcatAttributes&>(ctx.op_attr);

    // Implementation supports concatenation by channels only.
    if (attr.axis != Axis::CHANNELS) return false;

    // Implementation supports concatenation of more than one tensor only.
    if (ctx.input_shapes.size() <= 1) return false;

    // H and W must be the same for every concatenated tensor.
    for (int i = 1; i < ctx.input_shapes.size(); i++) {
      if (ctx.input_shapes[0][1] != ctx.input_shapes[i][1] ||
          ctx.input_shapes[0][2] != ctx.input_shapes[i][2]) {
        return false;
      }
    }

    return true;
  }

  absl::Status GenerateCode(const GenerationContext& ctx,
                            GeneratedCode* generated_code) const final {
    if (!IsSupported(ctx)) {
      return absl::UnimplementedError("This case is not supported by concat");
    }

    std::string code = DeclareVariables();

122 // "already_written" is used to keep the amount of already joined channels
123 int already_written = 0;
124 // "t" is an id of the next temp* variable.
125 // Generally, temp* variables are used in macros
126 // READ_BUFFER_VEC4(buff, addr, var).
127 // This macros instantiate the variable "var" and
128 // reads the value from buffer "buff" by address "addr"
129 int t = 0;
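    // Illustrative walk-through (hypothetical shapes): when concatenating two
    // inputs with 6 channels each, the first input starts at
    // already_written == 0 and takes the aligned path (two vec4 writes, the
    // second one half filled). The second input then arrives with
    // already_written == 6, i.e. remainder == 2, and takes the unaligned
    // path: Stage 1 tops up the half-filled vector, Stage 2 copies the rest.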
    for (int current_input_id = 0; current_input_id < ctx.input_shapes.size();
         current_input_id++) {
      // Start joining the next input tensor.

      // Grab the channel count.
      int in_ch = ctx.input_shapes[current_input_id][3];
      code += PrintStartMessage(current_input_id, in_ch, already_written);

      // Construct the buffer name associated with this tensor.
      std::string input = "input_data_" + std::to_string(current_input_id);

141 // "reminder" shows us how many cells in 4-element vector are left after
142 // the last write. As example, if we join two tensors both with
143 // 3 channels, after joining the first one we come to this line again
144 // and, when joining the second tensor, the reminder value
145 // will be equal to 1
146 int reminder = already_written % 4;
147
148 if (reminder == 0) {
149 code += AlignedCase(in_ch, input);
150 } else {
151 code += UnalignedCase(reminder, in_ch, input, &t);
152 }
153 already_written += in_ch;
154 }

    *generated_code = {
        /*parameters=*/{},
        /*objects=*/{},
        /*shared_variables=*/{},
        /*workload=*/
        uint3(static_cast<int>(ctx.output_shapes[0][2]),
              static_cast<int>(ctx.output_shapes[0][1]), 1),
        /*workgroup=*/uint3(),
        /*source_code=*/std::move(code),
        /*input=*/IOStructure::ONLY_DEFINITIONS,
        /*output=*/IOStructure::ONLY_DEFINITIONS,
    };
    return absl::OkStatus();
  }

 private:
  // Utility function.
  std::string temp(int t) const { return "temp" + std::to_string(t); }

  std::string DeclareVariables() const {
    // "val" accumulates the components for the next upcoming write.
    return R"(
int z = gid.z;
vec4 val = vec4(0.0f);

)";
  }

  std::string PrintStartMessage(int current_input_id, int in_ch,
                                int already_written) const {
    return "// Joining " + std::to_string(current_input_id) +
           " tensor with " + std::to_string(in_ch) +
           " channels\n// * * * *\n// Already wrote " +
           std::to_string(already_written) + " elements\n\n";
  }

  std::string AlignedCase(int in_ch, const std::string& input) const {
    std::string code;
    // This branch handles aligned reading and writing, where we can copy
    // all 4 components at once. The address of the first element to write
    // must be aligned.
    // Visual examples:
    // 1) when copying input_data_0:
    //
    //    | * * * * | * * * @ | @ @ . . .
    //      ^
    // 2) when in the middle of the joining process:
    //
    //    | X X X X | * * * @ | @ @ . . .
    //                ^
    // Note that the number of * equals in_ch.
    //
    // X - cells that were written before
    // * - cells you are going to write into
    // @ - cells to be filled on the next cycles
    // ^ - first element you start writing from
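    //
    // Worked example (an illustration, assuming in_ch == 8): the loop below
    // then emits two whole-vector copies, roughly:
    //   val = $input_data_k[gid.x, gid.y, 0]$;
    //   $output_data_0[gid.x, gid.y, z] = val$; z++;
    //   val = $input_data_k[gid.x, gid.y, 1]$;
    //   $output_data_0[gid.x, gid.y, z] = val$; z++;
    // where input_data_k stands for the buffer named by "input".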
    int blocks_amount = DivideRoundUp<int>(in_ch, 4);
    code += "// Aligned case\n";
    code += "// I'm going to make " + std::to_string(blocks_amount) +
            " write(s)\n\n";
    for (int block = 0; block < blocks_amount; block++) {
      // Copy the full 4-element vector.
      code += "val = $" + input + "[gid.x, gid.y, " + std::to_string(block) +
              "]$;\n" +
              "$output_data_0[gid.x, gid.y, z] = val$;\n"
              // Calculate the next address to write to.
              + "z++; \n\n";
    }
    return code;
  }

  std::string UnalignedCase(int remainder, int in_ch, const std::string& input,
                            int* t) const {
    // This branch copies cell by cell. It never starts from the first
    // tensor input_data_0. The work splits into two stages:
    // 1) Copy the "leftovers" into the free cells of the previous vector.
    // 2) Copy all remaining channels.
    // Visual examples:
    //
    //        Stage 1              Stage 2
    //        -----------          -------------------------
    // . . X | X  X  X  *1 | *2 *2 *2 @  | @  @  . . .
    //                  ^
    // . . X | X  X  *1 *1 | *2 *2 *2 *2 | *2 *2 . . .
    //               ^
    // . . X | X  *1 *1 *1 | *2 @  @  @  | @  @  . . .
    //            ^
    // Note that the number of * equals in_ch.
    //
    // X  - cells that were written before
    // *1 - cells written at Stage 1
    // *2 - cells written at Stage 2
    // @  - cells to be filled on the next cycles
    // ^  - first element you start writing from
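    //
    // Worked example (an illustration, assuming remainder == 3 and
    // in_ch == 3 as in the first diagram): shift == 1, so Stage 1 reads one
    // vec4 from the input and emits roughly
    //   vec4 temp0 = $input_data_k[gid.x, gid.y, 0]$;
    //   val[3] = temp0[0];
    //   $output_data_0[gid.x, gid.y, z - 1] = val$;
    // and Stage 2 then copies the remaining channels into a fresh vector
    // that is written at z. Here input_data_k stands for the buffer named by
    // "input".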

    std::string code = "// Unaligned case\n";

    // "shift" shows how many "empty" cells are left in the vector after the
    // previous write. Because this case is unaligned, shift can only be
    // 1, 2 or 3 here.
    int shift = 4 - remainder;
    if (shift > in_ch) {
      shift = in_ch;
    }
    code += "\n// Stage 1\n";
    code += "vec4 " + temp(*t) + " = $" + input + "[gid.x, gid.y, 0]$;\n";
    for (int i = 0; i < shift; i++) {
      // Note that remainder is a count, not an index, so "remainder + i"
      // already points at the first free cell.
      code += "val[" + std::to_string(remainder + i) + "] = " + temp(*t) +
              "[" + std::to_string(i) + "];\n";
    }
    // Rewrite the previous value with the updated last cells.
    code += "$output_data_0[gid.x, gid.y, z - 1] = val$;\n";
    (*t)++;

    // "left_blocks" is the number of vec4 writes that remain for this input
    // to be fully copied.
    int left_blocks = (in_ch - shift) / 4;
    if ((in_ch - shift) % 4 != 0) {
      left_blocks++;
    }
    if (left_blocks) {
      code += "\n// Stage 2\n";
      for (int block = 0; block < left_blocks; block++) {
        for (int elem = 0; elem < 4; elem++) {
          if (shift % 4 == 0) {
            code += "vec4 " + temp(*t) + " = $" + input + "[gid.x, gid.y, " +
                    std::to_string(block + 1) + "]$;\n";
            (*t)++;
          }
          code += "val[" + std::to_string(elem) + "] = " + temp(*t - 1) + "[" +
                  std::to_string(shift % 4) + "];\n";
          if (shift == in_ch) {
            break;
          }
          shift++;
        }
        code += "$output_data_0[gid.x, gid.y, z] = val$;\n";
        code += "z++;\n";
      }
    } else {
      code += "// No Stage 2\n";
    }
    return code;
  }
};

class FlatConcatByHeight : public NodeShader {
 public:
  static bool IsSupported(const GenerationContext& ctx) {
    const auto& attr = absl::any_cast<const ConcatAttributes&>(ctx.op_attr);

    // Implementation supports concatenation by height only.
    if (attr.axis != Axis::HEIGHT) return false;

    // Implementation supports concatenation of more than one tensor only.
    if (ctx.input_shapes.size() <= 1) return false;

    // C and W must be the same for every concatenated tensor.
    for (int i = 1; i < ctx.input_shapes.size(); i++) {
      if (ctx.input_shapes[0][3] != ctx.input_shapes[i][3] ||
          ctx.input_shapes[0][2] != ctx.input_shapes[i][2]) {
        return false;
      }
    }

    return true;
  }

  absl::Status GenerateCode(const GenerationContext& ctx,
                            GeneratedCode* generated_code) const final {
    std::string code;
    std::vector<Variable> params;
    for (int i = 0, shift = 0; i < ctx.input_shapes.size();
         shift += ctx.input_shapes[i][1], i++) {
      code += "if (";
      if (i != 0) {
        code += "$input_data_" + std::to_string(i - 1) + "_h$ <= gid.y && ";
      }
      code +=
          "gid.y < " + std::to_string(shift + ctx.input_shapes[i][1]) + ") {\n";
      code += "if (gid.y - " + std::to_string(shift) + " >= $input_data_" +
              std::to_string(i) + "_h$) return;\n";
      code += "value_0 = $input_data_" + std::to_string(i) +
              "[gid.x, gid.y - " + std::to_string(shift) + ", gid.z]$;\n}\n";
      if (i != ctx.input_shapes.size() - 1) {
        code += " else ";
      }
      params.push_back({"input_data_" + std::to_string(i) + "_h",
                        static_cast<int>(ctx.input_shapes[i][1])});
    }
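
    // For instance (an illustration, not verbatim output), with two inputs of
    // heights 4 and 6 the loop above produces shader code along the lines of:
    //   if (gid.y < 4) {
    //     if (gid.y - 0 >= $input_data_0_h$) return;
    //     value_0 = $input_data_0[gid.x, gid.y - 0, gid.z]$;
    //   } else if ($input_data_0_h$ <= gid.y && gid.y < 10) {
    //     if (gid.y - 4 >= $input_data_1_h$) return;
    //     value_0 = $input_data_1[gid.x, gid.y - 4, gid.z]$;
    //   }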

    *generated_code = {
        /*parameters=*/std::move(params),
        /*objects=*/{},
        /*shared_variables=*/{},
        /*workload=*/uint3(),
        /*workgroup=*/uint3(),
        /*source_code=*/std::move(code),
        /*input=*/IOStructure::ONLY_DEFINITIONS,
        /*output=*/IOStructure::AUTO,
    };
    return absl::OkStatus();
  }
};

class FlatConcatByWidth : public NodeShader {
 public:
  static bool IsSupported(const GenerationContext& ctx) {
    const auto& attr = absl::any_cast<const ConcatAttributes&>(ctx.op_attr);

    // Implementation supports concatenation by width only.
    if (attr.axis != Axis::WIDTH) return false;

    // Implementation supports concatenation of more than one tensor only.
    if (ctx.input_shapes.size() <= 1) return false;

    // C and H must be the same for every concatenated tensor.
    for (int i = 1; i < ctx.input_shapes.size(); i++) {
      if (ctx.input_shapes[0][3] != ctx.input_shapes[i][3] ||
          ctx.input_shapes[0][1] != ctx.input_shapes[i][1]) {
        return false;
      }
    }

    return true;
  }

  absl::Status GenerateCode(const GenerationContext& ctx,
                            GeneratedCode* generated_code) const final {
    std::string code;
    std::vector<Variable> params;
    for (int i = 0, shift = 0; i < ctx.input_shapes.size();
         shift += ctx.input_shapes[i][2], i++) {
      code += "if (";
      if (i != 0) {
        code += "$input_data_" + std::to_string(i - 1) + "_w$ <= gid.x && ";
      }
      code +=
          "gid.x < " + std::to_string(shift + ctx.input_shapes[i][2]) + ") {\n";
      code += "if (gid.x - " + std::to_string(shift) + " >= $input_data_" +
              std::to_string(i) + "_w$) return;\n";
      code += "value_0 = $input_data_" + std::to_string(i) + "[gid.x - " +
              std::to_string(shift) + ", gid.y, gid.z]$;\n}\n";
      if (i != ctx.input_shapes.size() - 1) {
        code += " else ";
      }
      params.push_back({"input_data_" + std::to_string(i) + "_w",
                        static_cast<int>(ctx.input_shapes[i][2])});
    }
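
    // The generated selection chain mirrors the FlatConcatByHeight example
    // above, with gid.x and the per-input _w parameters taking the place of
    // gid.y and the _h parameters.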

    *generated_code = {
        /*parameters=*/std::move(params),
        /*objects=*/{},
        /*shared_variables=*/{},
        /*workload=*/uint3(),
        /*workgroup=*/uint3(),
        /*source_code=*/std::move(code),
        /*input=*/IOStructure::ONLY_DEFINITIONS,
        /*output=*/IOStructure::AUTO,
    };
    return absl::OkStatus();
  }
};

class FlatConcat : public NodeShader {
 public:
  absl::Status GenerateCode(const GenerationContext& ctx,
                            GeneratedCode* generated_code) const final {
    if (FlatConcatByHeight::IsSupported(ctx)) {
      return flat_concat_by_height_.GenerateCode(ctx, generated_code);
    }
    if (FlatConcatByWidth::IsSupported(ctx)) {
      return flat_concat_by_width_.GenerateCode(ctx, generated_code);
    }
    return absl::InvalidArgumentError(
        "This case is not supported by flat concat");
  }

 private:
  FlatConcatByHeight flat_concat_by_height_;
  FlatConcatByWidth flat_concat_by_width_;
};

}  // namespace

std::unique_ptr<NodeShader> NewAlignedConcatNodeShader() {
  return absl::make_unique<AlignedConcatByChannels>();
}

std::unique_ptr<NodeShader> NewConcatNodeShader() {
  return absl::make_unique<ConcatByAnyChannel>();
}

std::unique_ptr<NodeShader> NewFlatConcatNodeShader() {
  return absl::make_unique<FlatConcat>();
}

}  // namespace gl
}  // namespace gpu
}  // namespace tflite