/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/gl/kernels/concat.h"

#include <algorithm>
#include <any>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/gl/variable.h"

namespace tflite {
namespace gpu {
namespace gl {
namespace {

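// Fast path: concatenates exactly two tensors along the channel axis when
// both channel counts are divisible by 4, so every 4-channel slice of the
// output maps onto a whole slice of one of the inputs.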
class AlignedConcatByChannels : public NodeShader {
 public:
  static bool IsSupported(const GenerationContext& ctx) {
    const auto& attr = std::any_cast<const ConcatAttributes&>(ctx.op_attr);

    // Implementation supports concatenation by channels only.
    if (attr.axis != Axis::CHANNELS) return false;

    // Implementation supports concatenation of 2 tensors only.
    if (ctx.input_shapes.size() != 2) return false;

    // H and W must be the same for every concatenated tensor.
    for (int i = 1; i < ctx.input_shapes.size(); i++) {
      if (ctx.input_shapes[0][1] != ctx.input_shapes[i][1] ||
          ctx.input_shapes[0][2] != ctx.input_shapes[i][2]) {
        return false;
      }
    }

    // Channels must be aligned by 4 for every concatenated tensor.
    for (const auto& shape : ctx.input_shapes) {
      if (shape[3] % 4 != 0) return false;
    }

    return true;
  }

  absl::Status GenerateCode(const GenerationContext& ctx,
                            GeneratedCode* generated_code) const final {
    if (!IsSupported(ctx)) {
      return absl::InvalidArgumentError(
          "This case is not supported by aligned concat");
    }

    // The shader below concatenates 2 tensors whose channels are aligned by 4.
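    // "border" is the number of vec4 slices taken from the first input: for
    // gid.z below it we read input 0, otherwise input 1 shifted by "border".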
    std::string source = R"(
      if (gid.z < $border$) {
        value_0 = $input_data_0[gid.x, gid.y, gid.z]$;
      } else {
        int z = gid.z - $border$;
        value_0 = $input_data_1[gid.x, gid.y, z]$;
      }
)";
    *generated_code = {
        /*parameters=*/{
            {"border", static_cast<int>(ctx.input_shapes[0][3]) / 4}},
        /*objects=*/{},
        /*shared_variables=*/{},
        /*workload=*/uint3(),
        /*workgroup=*/uint3(),
        /*source_code=*/std::move(source),
        /*input=*/IOStructure::ONLY_DEFINITIONS,
        /*output=*/IOStructure::AUTO,
    };
    return absl::OkStatus();
  }
};

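// General case: concatenates any number of tensors along the channel axis.
// Channel counts do not have to be multiples of 4; the helpers below pack
// leftover channels across vec4 boundaries in the generated shader.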
class ConcatByAnyChannel : public NodeShader {
 public:
  static bool IsSupported(const GenerationContext& ctx) {
    const auto& attr = std::any_cast<const ConcatAttributes&>(ctx.op_attr);

    // Implementation supports concatenation by channels only.
    if (attr.axis != Axis::CHANNELS) return false;

    // Implementation supports concatenation of more than 1 tensor only.
    if (ctx.input_shapes.size() <= 1) return false;

    // H and W must be the same for every concatenated tensor.
    for (int i = 1; i < ctx.input_shapes.size(); i++) {
      if (ctx.input_shapes[0][1] != ctx.input_shapes[i][1] ||
          ctx.input_shapes[0][2] != ctx.input_shapes[i][2]) {
        return false;
      }
    }

    return true;
  }

  absl::Status GenerateCode(const GenerationContext& ctx,
                            GeneratedCode* generated_code) const final {
    if (!IsSupported(ctx)) {
      return absl::UnimplementedError("This case is not supported by concat");
    }

    std::string code = DeclareVariables();

    // "already_written" keeps track of how many channels have been joined so
    // far.
    int already_written = 0;
    // "t" is the id of the next temp* variable.
    // Generally, temp* variables are used in macros
    // READ_BUFFER_VEC4(buff, addr, var).
    // This macro instantiates the variable "var" and
    // reads the value from buffer "buff" at address "addr".
    int t = 0;
    for (int current_input_id = 0; current_input_id < ctx.input_shapes.size();
         current_input_id++) {
      // Start joining the next input tensor.

      // Grab the number of channels.
      int in_ch = ctx.input_shapes[current_input_id][3];
      code += PrintStartMessage(current_input_id, in_ch, already_written);

      // Construct the buffer name associated with this tensor.
      std::string input = "input_data_" + std::to_string(current_input_id);

      // "reminder" is the number of cells already occupied in the current
      // 4-element vector after the last write. For example, if we join two
      // tensors with 3 channels each, then when joining the second tensor
      // the reminder will be 3, leaving a single free cell in the vector.
      int reminder = already_written % 4;

      if (reminder == 0) {
        code += AlignedCase(in_ch, input);
      } else {
        code += UnalignedCase(reminder, in_ch, input, &t);
      }
      already_written += in_ch;
    }

    *generated_code = {
        /*parameters=*/{},
        /*objects=*/{},
        /*shared_variables=*/{},
        /*workload=*/
        uint3(static_cast<int>(ctx.output_shapes[0][2]),
              static_cast<int>(ctx.output_shapes[0][1]), 1),
        /*workgroup=*/uint3(),
        /*source_code=*/std::move(code),
        /*input=*/IOStructure::ONLY_DEFINITIONS,
        /*output=*/IOStructure::ONLY_DEFINITIONS,
    };
    return absl::OkStatus();
  }

 private:
  // Utility function: returns the name of the t-th temp* variable used in the
  // generated code.
  std::string temp(int t) const { return "temp" + std::to_string(t); }

  std::string DeclareVariables() const {
    // "val" accumulates the cells of the current output vector before the
    // next write.
    return R"(
int z = gid.z;
vec4 val = vec4(0.0f);

)";
  }

  std::string PrintStartMessage(int current_input_id, int in_ch,
                                int already_written) const {
    return "//              Joining " + std::to_string(current_input_id) +
           " tensor with " + std::to_string(in_ch) +
           " channels\n//  * * * *\\n// Already wrote " +
           std::to_string(already_written) + " elements\n\n";
  }

  std::string AlignedCase(int in_ch, const std::string& input) const {
    std::string code;
    // This branch is for aligned reading and writing, when we can copy
    // all 4 components at once. The address of the first element to write
    // must be aligned.
    // Visual examples:
    // 1) when copying input_data_0
    //
    //       | * * * * | * * * @ | @ @ . . .
    //         ^
    // 2) when in the middle of the joining process:
    //
    //       | X X X X | * * * @ | @ @ . . .
    //                   ^
    // Note that the number of * symbols equals in_ch
    //
    // X - cells that were written before
    // * - cells you are going to write into now
    // @ - cells you will fill in the next cycles
    // ^ - the first element you start writing from
    int blocks_amount = DivideRoundUp<int>(in_ch, 4);
    code += "// Aligned case\n";
    code += "// I'm going to make " + std::to_string(blocks_amount) +
            " write(s)\n\n";
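    // Each loop iteration below emits, roughly:
    //   val = $input_data_N[gid.x, gid.y, <block>]$;
    //   $output_data_0[gid.x, gid.y, z] = val$;
    //   z++;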
    for (int block = 0; block < blocks_amount; block++) {
      // Copy full 4-element vector
      code += "val = $" + input + "[gid.x, gid.y, " + std::to_string(block) +
              "]$;\n" +
              "$output_data_0[gid.x, gid.y, z] = val$;\n"
              // calculate next address to write
              + "z++; \n\n";
    }
    return code;
  }

  std::string UnalignedCase(int reminder, int in_ch, const std::string& input,
                            int* t) const {
    // This branch is for copying cell-by-cell. It will never start from the
    // first tensor input_data_0. This function is split into two stages:
    // 1) Fill the free cells left over from the previous write.
    // 2) Copy all remaining cells.
    // Visual examples:
    //
    //        Stage 1       Stage 2
    //        -----------   -------------------------
    // . . X | X  X  X *1 | *2 *2 *2  @ | @  @  . . .
    //               ^
    // . . X | X  X *1 *1 | *2 *2 *2 *2 | *2 *2 . . .
    //             ^
    // . . X | X *1 *1 *1 | *2  @  @  @ | @  @  . . .
    //           ^
    // Note that the number of * symbols equals in_ch
    //
    // X - cells that were written before
    // *1 - cells written at Stage 1
    // *2 - cells written at Stage 2
    // @ - cells you will fill in the next cycles
    // ^ - the first element you start writing from

    std::string code = "// Unaligned case\n";

    // Variable "shift" shows how many "empty" cells are left after the
    // previous write. Remember that this case is unaligned, so shift here
    // can only be 1, 2 or 3.
    int shift = 4 - reminder;
    if (shift > in_ch) {
      shift = in_ch;
    }
    code += "\n// Stage 1\n";
    code += "vec4 " + temp(*t) + " = $" + input + "[gid.x, gid.y, 0]$;\n";
    for (int i = 0; i < shift; i++) {
      // Note that "reminder + i" already points one past the last written
      // cell, because reminder is a count of cells, not an index.
      code += "val[" + std::to_string(reminder + i) + "] = " + temp(*t) + "[" +
              std::to_string(i) + "];\n";
    }
    // Rewrite the previous value with the updated last cells.
    code += "$output_data_0[gid.x, gid.y, z - 1] = val$;\n";
    (*t)++;

    // "left_blocks" is the number of whole-vector writes that are still left
    // to finish copying this input.
    int left_blocks = (in_ch - shift) / 4;
    if ((in_ch - shift) % 4 != 0) {
      left_blocks++;
    }
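    // Stage 2 emits cell-by-cell copies: a fresh temp vec4 is loaded from the
    // input whenever the read position crosses a vec4 boundary
    // (shift % 4 == 0), and each cell is placed into "val" before the whole
    // vector is written to the output and z is advanced.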
    if (left_blocks) {
      code += "\n// Stage 2\n";
      for (int block = 0; block < left_blocks; block++) {
        for (int elem = 0; elem < 4; elem++) {
          if (shift % 4 == 0) {
            code += "vec4 " + temp(*t) + " = $" + input + "[gid.x, gid.y, " +
                    std::to_string(block + 1) + "]$;\n";
            (*t)++;
          }
          code += "val[" + std::to_string(elem) + "] = " + temp(*t - 1) + "[" +
                  std::to_string(shift % 4) + "];\n";
          if (shift == in_ch) {
            break;
          }
          shift++;
        }
        code += "$output_data_0[gid.x, gid.y, z] = val$;\n";
        code += "z++;\n";
      }
    } else {
      code += "// No Stage 2\n";
    }
    return code;
  }
};

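// Concatenates tensors along the height axis: gid.y is offset into the
// matching input. Channel and width dimensions must match across inputs.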
class FlatConcatByHeight : public NodeShader {
 public:
  static bool IsSupported(const GenerationContext& ctx) {
    const auto& attr = std::any_cast<const ConcatAttributes&>(ctx.op_attr);

    // Implementation supports concatenation by height only.
    if (attr.axis != Axis::HEIGHT) return false;

    // Implementation supports concatenation of more than 1 tensor only.
    if (ctx.input_shapes.size() <= 1) return false;

    // C and W must be the same for every concatenated tensor.
    for (int i = 1; i < ctx.input_shapes.size(); i++) {
      if (ctx.input_shapes[0][3] != ctx.input_shapes[i][3] ||
          ctx.input_shapes[0][2] != ctx.input_shapes[i][2]) {
        return false;
      }
    }

    return true;
  }

  absl::Status GenerateCode(const GenerationContext& ctx,
                            GeneratedCode* generated_code) const final {
    std::string code;
    std::vector<Variable> params;
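    // The loop below chains "if / else if" branches, one per input. For two
    // inputs with heights h0 and h1 it generates, roughly:
    //   if (gid.y < h0) {
    //     if (gid.y - 0 >= $input_data_0_h$) return;
    //     value_0 = $input_data_0[gid.x, gid.y - 0, gid.z]$;
    //   } else if ($input_data_0_h$ <= gid.y && gid.y < h0 + h1) {
    //     if (gid.y - h0 >= $input_data_1_h$) return;
    //     value_0 = $input_data_1[gid.x, gid.y - h0, gid.z]$;
    //   }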
    for (int i = 0, shift = 0; i < ctx.input_shapes.size();
         shift += ctx.input_shapes[i][1], i++) {
      code += "if (";
      if (i != 0) {
        code += "$input_data_" + std::to_string(i - 1) + "_h$ <= gid.y && ";
      }
      code +=
          "gid.y < " + std::to_string(shift + ctx.input_shapes[i][1]) + ") {\n";
      code += "if (gid.y - " + std::to_string(shift) + " >= $input_data_" +
              std::to_string(i) + "_h$) return;\n";
      code += "value_0 = $input_data_" + std::to_string(i) +
              "[gid.x, gid.y - " + std::to_string(shift) + ", gid.z]$;\n}\n";
      if (i != ctx.input_shapes.size() - 1) {
        code += " else ";
      }
      params.push_back({"input_data_" + std::to_string(i) + "_h",
                        static_cast<int>(ctx.input_shapes[i][1])});
    }

    *generated_code = {
        /*parameters=*/std::move(params),
        /*objects=*/{},
        /*shared_variables=*/{},
        /*workload=*/uint3(),
        /*workgroup=*/uint3(),
        /*source_code=*/std::move(code),
        /*input=*/IOStructure::ONLY_DEFINITIONS,
        /*output=*/IOStructure::AUTO,
    };
    return absl::OkStatus();
  }
};

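// Same as FlatConcatByHeight, but along the width axis: gid.x is offset into
// the matching input. Channel and height dimensions must match across inputs.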
class FlatConcatByWidth : public NodeShader {
 public:
  static bool IsSupported(const GenerationContext& ctx) {
    const auto& attr = std::any_cast<const ConcatAttributes&>(ctx.op_attr);

    // Implementation supports concatenation by width only.
    if (attr.axis != Axis::WIDTH) return false;

    // Implementation supports concatenation of more than 1 tensor only.
    if (ctx.input_shapes.size() <= 1) return false;

    // C and H must be the same for every concatenated tensor.
    for (int i = 1; i < ctx.input_shapes.size(); i++) {
      if (ctx.input_shapes[0][3] != ctx.input_shapes[i][3] ||
          ctx.input_shapes[0][1] != ctx.input_shapes[i][1]) {
        return false;
      }
    }

    return true;
  }

  absl::Status GenerateCode(const GenerationContext& ctx,
                            GeneratedCode* generated_code) const final {
    std::string code;
    std::vector<Variable> params;
    for (int i = 0, shift = 0; i < ctx.input_shapes.size();
         shift += ctx.input_shapes[i][2], i++) {
      code += "if (";
      if (i != 0) {
        code += "$input_data_" + std::to_string(i - 1) + "_w$ <= gid.x && ";
      }
      code +=
          "gid.x < " + std::to_string(shift + ctx.input_shapes[i][2]) + ") {\n";
      code += "if (gid.x - " + std::to_string(shift) + " >= $input_data_" +
              std::to_string(i) + "_w$) return;\n";
      code += "value_0 = $input_data_" + std::to_string(i) + "[gid.x - " +
              std::to_string(shift) + ", gid.y, gid.z]$;\n}\n";
      if (i != ctx.input_shapes.size() - 1) {
        code += " else ";
      }
      params.push_back({"input_data_" + std::to_string(i) + "_w",
                        static_cast<int>(ctx.input_shapes[i][2])});
    }

    *generated_code = {
        /*parameters=*/std::move(params),
        /*objects=*/{},
        /*shared_variables=*/{},
        /*workload=*/uint3(),
        /*workgroup=*/uint3(),
        /*source_code=*/std::move(code),
        /*input=*/IOStructure::ONLY_DEFINITIONS,
        /*output=*/IOStructure::AUTO,
    };
    return absl::OkStatus();
  }
};

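// Dispatches to the height or width variant, depending on which one supports
// the given concat attributes and input shapes.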
class FlatConcat : public NodeShader {
 public:
  absl::Status GenerateCode(const GenerationContext& ctx,
                            GeneratedCode* generated_code) const final {
    if (FlatConcatByHeight::IsSupported(ctx)) {
      return flat_concat_by_height_.GenerateCode(ctx, generated_code);
    }
    if (FlatConcatByWidth::IsSupported(ctx)) {
      return flat_concat_by_width_.GenerateCode(ctx, generated_code);
    }
    return absl::InvalidArgumentError(
        "This case is not supported by flat concat");
  }

 private:
  FlatConcatByHeight flat_concat_by_height_;
  FlatConcatByWidth flat_concat_by_width_;
};

}  // namespace

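// Factory functions exposing the shaders defined above.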
std::unique_ptr<NodeShader> NewAlignedConcatNodeShader() {
  return std::make_unique<AlignedConcatByChannels>();
}

std::unique_ptr<NodeShader> NewConcatNodeShader() {
  return std::make_unique<ConcatByAnyChannel>();
}

std::unique_ptr<NodeShader> NewFlatConcatNodeShader() {
  return std::make_unique<FlatConcat>();
}

}  // namespace gl
}  // namespace gpu
}  // namespace tflite