/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/gl/kernels/concat.h"

#include <algorithm>
#include <any>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/gl/variable.h"
namespace tflite {
namespace gpu {
namespace gl {
namespace {

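// Concatenates exactly two tensors along the channel axis when both channel
// counts are multiples of 4, so every output vec4 slice maps to a whole vec4
// slice of exactly one input.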
class AlignedConcatByChannels : public NodeShader {
 public:
  static bool IsSupported(const GenerationContext& ctx) {
    const auto& attr = std::any_cast<const ConcatAttributes&>(ctx.op_attr);

    // Implementation supports concatenation by channels only.
    if (attr.axis != Axis::CHANNELS) return false;

    // Implementation supports concatenation of 2 tensors only.
    if (ctx.input_shapes.size() != 2) return false;

    // H and W must be the same for every concatenated tensor.
    for (int i = 1; i < ctx.input_shapes.size(); i++) {
      if (ctx.input_shapes[0][1] != ctx.input_shapes[i][1] ||
          ctx.input_shapes[0][2] != ctx.input_shapes[i][2]) {
        return false;
      }
    }

    // Channels must be aligned by 4 for every concatenated tensor.
    for (const auto& shape : ctx.input_shapes) {
      if (shape[3] % 4 != 0) return false;
    }

    return true;
  }

  absl::Status GenerateCode(const GenerationContext& ctx,
                            GeneratedCode* generated_code) const final {
    if (!IsSupported(ctx)) {
      return absl::InvalidArgumentError(
          "This case is not supported by aligned concat");
    }

    // The shader below concatenates two tensors whose channels are aligned
    // by 4, so each output slice is read from exactly one input.
    std::string source = R"(
      if (gid.z < $border$) {
        value_0 = $input_data_0[gid.x, gid.y, gid.z]$;
      } else {
        int z = gid.z - $border$;
        value_0 = $input_data_1[gid.x, gid.y, z]$;
      }
)";
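    // "border" is the number of 4-channel slices in the first input: gid.z
    // values below it read from input 0, the rest read from input 1.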
    *generated_code = {
        /*parameters=*/{
            {"border", static_cast<int>(ctx.input_shapes[0][3]) / 4}},
        /*objects=*/{},
        /*shared_variables=*/{},
        /*workload=*/uint3(),
        /*workgroup=*/uint3(),
        /*source_code=*/std::move(source),
        /*input=*/IOStructure::ONLY_DEFINITIONS,
        /*output=*/IOStructure::AUTO,
    };
    return absl::OkStatus();
  }
};

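// Concatenates any number of tensors along the channel axis, copying values
// cell by cell when channel counts are not multiples of 4.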
class ConcatByAnyChannel : public NodeShader {
 public:
  static bool IsSupported(const GenerationContext& ctx) {
    const auto& attr = std::any_cast<const ConcatAttributes&>(ctx.op_attr);

    // Implementation supports concatenation by channels only.
    if (attr.axis != Axis::CHANNELS) return false;

    // Implementation supports concatenation of more than one tensor only.
    if (ctx.input_shapes.size() <= 1) return false;

    // H and W must be the same for every concatenated tensor.
    for (int i = 1; i < ctx.input_shapes.size(); i++) {
      if (ctx.input_shapes[0][1] != ctx.input_shapes[i][1] ||
          ctx.input_shapes[0][2] != ctx.input_shapes[i][2]) {
        return false;
      }
    }

    return true;
  }

  absl::Status GenerateCode(const GenerationContext& ctx,
                            GeneratedCode* generated_code) const final {
    if (!IsSupported(ctx)) {
      return absl::UnimplementedError("This case is not supported by concat");
    }

    std::string code = DeclareVariables();

    // "already_written" tracks the number of channels joined so far.
    int already_written = 0;
    // "t" is the id of the next temp* variable. temp* variables hold vec4
    // values read from the input buffers before their components are copied
    // into "val".
    int t = 0;
    for (int current_input_id = 0; current_input_id < ctx.input_shapes.size();
         current_input_id++) {
      // Start joining the next input tensor.

      // Get the number of channels of this input.
      int in_ch = ctx.input_shapes[current_input_id][3];
      code += PrintStartMessage(current_input_id, in_ch, already_written);

      // Construct the buffer name associated with this tensor.
      std::string input = "input_data_" + std::to_string(current_input_id);

      // "remainder" is the number of cells of the current 4-element vector
      // that are already filled. For example, if we join two tensors with
      // 3 channels each, after joining the first one we come to this line
      // again and, when joining the second tensor, the remainder value
      // will be 3.
      int remainder = already_written % 4;

      if (remainder == 0) {
        code += AlignedCase(in_ch, input);
      } else {
        code += UnalignedCase(remainder, in_ch, input, &t);
      }
      already_written += in_ch;
    }

    *generated_code = {
        /*parameters=*/{},
        /*objects=*/{},
        /*shared_variables=*/{},
        /*workload=*/
        uint3(static_cast<int>(ctx.output_shapes[0][2]),
              static_cast<int>(ctx.output_shapes[0][1]), 1),
        /*workgroup=*/uint3(),
        /*source_code=*/std::move(code),
        /*input=*/IOStructure::ONLY_DEFINITIONS,
        /*output=*/IOStructure::ONLY_DEFINITIONS,
    };
    return absl::OkStatus();
  }

 private:
  // Returns the name of the temp variable with the given id.
  std::string temp(int t) const { return "temp" + std::to_string(t); }

  std::string DeclareVariables() const {
    // "val" accumulates the components of the next output vector before it
    // is written out.
    return R"(
      int z = gid.z;
      vec4 val = vec4(0.0f);

)";
  }

  std::string PrintStartMessage(int current_input_id, int in_ch,
                                int already_written) const {
    return "// Joining " + std::to_string(current_input_id) +
           " tensor with " + std::to_string(in_ch) +
           " channels\n// * * * *\n// Already wrote " +
           std::to_string(already_written) + " elements\n\n";
  }

  std::string AlignedCase(int in_ch, const std::string& input) const {
    std::string code;
    // This branch is for aligned reading and writing, when we can copy
    // all 4 components at once. The address of the first element to write
    // must be aligned.
    // Visual examples:
    // 1) when copying input_data_0:
    //
    //    | * * * * | * * * @ | @ @ . . .
    //      ^
    // 2) when in the middle of the joining process:
    //
    //    | X X X X | * * * @ | @ @ . . .
    //                ^
    // Note that the number of *'s equals in_ch.
    //
    // X - cells that were written before
    // * - cells you are going to write into now
    // @ - cells you will fill on the next cycles
    // ^ - first element you start writing from
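    //
    // For illustration: with in_ch == 8 and input == "input_data_1"
    // (hypothetical values), the emitted GLSL is roughly:
    //
    //   val = $input_data_1[gid.x, gid.y, 0]$;
    //   $output_data_0[gid.x, gid.y, z] = val$;
    //   z++;
    //   val = $input_data_1[gid.x, gid.y, 1]$;
    //   $output_data_0[gid.x, gid.y, z] = val$;
    //   z++;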
    int blocks_amount = DivideRoundUp<int>(in_ch, 4);
    code += "// Aligned case\n";
    code += "// I'm going to make " + std::to_string(blocks_amount) +
            " write(s)\n\n";
    for (int block = 0; block < blocks_amount; block++) {
      // Copy a full 4-element vector.
      code += "val = $" + input + "[gid.x, gid.y, " + std::to_string(block) +
              "]$;\n" +
              "$output_data_0[gid.x, gid.y, z] = val$;\n"
              // Calculate the next address to write.
              + "z++; \n\n";
    }
    return code;
  }

  std::string UnalignedCase(int remainder, int in_ch, const std::string& input,
                            int* t) const {
    // This branch copies cell by cell. It never starts from the first
    // tensor, input_data_0. The work splits into two stages:
    // 1) Copy the "leftovers" into the cells of the previous vector
    // 2) Copy everything else
    // Visual examples:
    //
    //            Stage 1           Stage 2
    //            -------   -------------------------
    // . . X | X X X *1 | *2 *2 *2 @  | @  @  . . .
    //               ^
    // . . X | X X *1 *1 | *2 *2 *2 *2 | *2 *2 . . .
    //             ^
    // . . X | X *1 *1 *1 | *2 @  @  @  | @  @  . . .
    //           ^
    // Note that the number of *'s equals in_ch.
    //
    // X  - cells that were written before
    // *1 - cells written in Stage 1
    // *2 - cells written in Stage 2
    // @  - cells you will fill on the next cycles
    // ^  - first element you start writing from
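    //
    // For illustration: with remainder == 2 and in_ch == 2 (so everything
    // fits into the leftover cells), assuming input == "input_data_1", the
    // emitted GLSL is roughly:
    //
    //   vec4 temp0 = $input_data_1[gid.x, gid.y, 0]$;
    //   val[2] = temp0[0];
    //   val[3] = temp0[1];
    //   $output_data_0[gid.x, gid.y, z - 1] = val$;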

    std::string code = "// Unaligned case\n";

    // "shift" is the number of empty cells left in the current 4-element
    // vector after the previous write. Remember that this case is
    // unaligned, so at this point shift can only be 1, 2 or 3.
    int shift = 4 - remainder;
    if (shift > in_ch) {
      shift = in_ch;
    }
    code += "\n// Stage 1\n";
    code += "vec4 " + temp(*t) + " = $" + input + "[gid.x, gid.y, 0]$;\n";
    for (int i = 0; i < shift; i++) {
      // Note that "remainder + i" starts writing one past the last occupied
      // cell, because remainder is a count, not an index.
      code += "val[" + std::to_string(remainder + i) + "] = " + temp(*t) +
              "[" + std::to_string(i) + "];\n";
    }
    // Rewrite the previous value with the updated last cells.
    code += "$output_data_0[gid.x, gid.y, z - 1] = val$;\n";
    (*t)++;

    // "left_blocks" is the number of output vector writes that are left
    // for this input to be fully copied.
    int left_blocks = (in_ch - shift) / 4;
    if ((in_ch - shift) % 4 != 0) {
      left_blocks++;
    }
    if (left_blocks) {
      code += "\n// Stage 2\n";
      for (int block = 0; block < left_blocks; block++) {
        for (int elem = 0; elem < 4; elem++) {
          if (shift % 4 == 0) {
            code += "vec4 " + temp(*t) + " = $" + input + "[gid.x, gid.y, " +
                    std::to_string(block + 1) + "]$;\n";
            (*t)++;
          }
          code += "val[" + std::to_string(elem) + "] = " + temp(*t - 1) + "[" +
                  std::to_string(shift % 4) + "];\n";
          if (shift == in_ch) {
            break;
          }
          shift++;
        }
        code += "$output_data_0[gid.x, gid.y, z] = val$;\n";
        code += "z++;\n";
      }
    } else {
      code += "// No Stage 2\n";
    }
    return code;
  }
};

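// Concatenates tensors along the height axis by selecting, for each output
// row, the input that covers it.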
class FlatConcatByHeight : public NodeShader {
 public:
  static bool IsSupported(const GenerationContext& ctx) {
    const auto& attr = std::any_cast<const ConcatAttributes&>(ctx.op_attr);

    // Implementation supports concatenation by height only.
    if (attr.axis != Axis::HEIGHT) return false;

    // Implementation supports concatenation of more than one tensor only.
    if (ctx.input_shapes.size() <= 1) return false;

    // C and W must be the same for every concatenated tensor.
    for (int i = 1; i < ctx.input_shapes.size(); i++) {
      if (ctx.input_shapes[0][3] != ctx.input_shapes[i][3] ||
          ctx.input_shapes[0][2] != ctx.input_shapes[i][2]) {
        return false;
      }
    }

    return true;
  }

  absl::Status GenerateCode(const GenerationContext& ctx,
                            GeneratedCode* generated_code) const final {
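    // Emits a chain of if/else blocks; input i handles output rows
    // [shift, shift + H_i). For two inputs with heights H_0 and H_1
    // (hypothetical values), the emitted GLSL looks roughly like:
    //
    //   if (gid.y < H_0) {
    //     ... read from input_data_0 ...
    //   } else if ($input_data_0_h$ <= gid.y && gid.y < H_0 + H_1) {
    //     ... read from input_data_1 ...
    //   }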
    std::string code;
    std::vector<Variable> params;
    for (int i = 0, shift = 0; i < ctx.input_shapes.size();
         shift += ctx.input_shapes[i][1], i++) {
      code += "if (";
      if (i != 0) {
        code += "$input_data_" + std::to_string(i - 1) + "_h$ <= gid.y && ";
      }
      code +=
          "gid.y < " + std::to_string(shift + ctx.input_shapes[i][1]) + ") {\n";
      code += "if (gid.y - " + std::to_string(shift) + " >= $input_data_" +
              std::to_string(i) + "_h$) return;\n";
      code += "value_0 = $input_data_" + std::to_string(i) +
              "[gid.x, gid.y - " + std::to_string(shift) + ", gid.z]$;\n}\n";
      if (i != ctx.input_shapes.size() - 1) {
        code += " else ";
      }
      params.push_back({"input_data_" + std::to_string(i) + "_h",
                        static_cast<int>(ctx.input_shapes[i][1])});
    }

    *generated_code = {
        /*parameters=*/std::move(params),
        /*objects=*/{},
        /*shared_variables=*/{},
        /*workload=*/uint3(),
        /*workgroup=*/uint3(),
        /*source_code=*/std::move(code),
        /*input=*/IOStructure::ONLY_DEFINITIONS,
        /*output=*/IOStructure::AUTO,
    };
    return absl::OkStatus();
  }
};

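// Concatenates tensors along the width axis; the mirror image of
// FlatConcatByHeight over gid.x and tensor widths.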
class FlatConcatByWidth : public NodeShader {
 public:
  static bool IsSupported(const GenerationContext& ctx) {
    const auto& attr = std::any_cast<const ConcatAttributes&>(ctx.op_attr);

    // Implementation supports concatenation by width only.
    if (attr.axis != Axis::WIDTH) return false;

    // Implementation supports concatenation of more than one tensor only.
    if (ctx.input_shapes.size() <= 1) return false;

    // C and H must be the same for every concatenated tensor.
    for (int i = 1; i < ctx.input_shapes.size(); i++) {
      if (ctx.input_shapes[0][3] != ctx.input_shapes[i][3] ||
          ctx.input_shapes[0][1] != ctx.input_shapes[i][1]) {
        return false;
      }
    }

    return true;
  }

  absl::Status GenerateCode(const GenerationContext& ctx,
                            GeneratedCode* generated_code) const final {
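    // Same if/else chain as FlatConcatByHeight::GenerateCode, but selecting
    // by gid.x against accumulated input widths.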
    std::string code;
    std::vector<Variable> params;
    for (int i = 0, shift = 0; i < ctx.input_shapes.size();
         shift += ctx.input_shapes[i][2], i++) {
      code += "if (";
      if (i != 0) {
        code += "$input_data_" + std::to_string(i - 1) + "_w$ <= gid.x && ";
      }
      code +=
          "gid.x < " + std::to_string(shift + ctx.input_shapes[i][2]) + ") {\n";
      code += "if (gid.x - " + std::to_string(shift) + " >= $input_data_" +
              std::to_string(i) + "_w$) return;\n";
      code += "value_0 = $input_data_" + std::to_string(i) + "[gid.x - " +
              std::to_string(shift) + ", gid.y, gid.z]$;\n}\n";
      if (i != ctx.input_shapes.size() - 1) {
        code += " else ";
      }
      params.push_back({"input_data_" + std::to_string(i) + "_w",
                        static_cast<int>(ctx.input_shapes[i][2])});
    }

    *generated_code = {
        /*parameters=*/std::move(params),
        /*objects=*/{},
        /*shared_variables=*/{},
        /*workload=*/uint3(),
        /*workgroup=*/uint3(),
        /*source_code=*/std::move(code),
        /*input=*/IOStructure::ONLY_DEFINITIONS,
        /*output=*/IOStructure::AUTO,
    };
    return absl::OkStatus();
  }
};

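// Dispatches to the height or width implementation, whichever supports the
// given attributes.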
class FlatConcat : public NodeShader {
 public:
  absl::Status GenerateCode(const GenerationContext& ctx,
                            GeneratedCode* generated_code) const final {
    if (FlatConcatByHeight::IsSupported(ctx)) {
      return flat_concat_by_height_.GenerateCode(ctx, generated_code);
    }
    if (FlatConcatByWidth::IsSupported(ctx)) {
      return flat_concat_by_width_.GenerateCode(ctx, generated_code);
    }
    return absl::InvalidArgumentError(
        "This case is not supported by flat concat");
  }

 private:
  FlatConcatByHeight flat_concat_by_height_;
  FlatConcatByWidth flat_concat_by_width_;
};

}  // namespace

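// Factory functions: the aligned shader is the fast path for exactly two
// channel-aligned tensors, the generic shader handles any concatenation by
// channels, and the flat shader covers height/width concatenation.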
std::unique_ptr<NodeShader> NewAlignedConcatNodeShader() {
  return std::make_unique<AlignedConcatByChannels>();
}

std::unique_ptr<NodeShader> NewConcatNodeShader() {
  return std::make_unique<ConcatByAnyChannel>();
}

std::unique_ptr<NodeShader> NewFlatConcatNodeShader() {
  return std::make_unique<FlatConcat>();
}

}  // namespace gl
}  // namespace gpu
}  // namespace tflite